hdmf-3.1.1/

hdmf-3.1.1/Legal.txt

“hdmf” Copyright (c) 2017-2021, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved.

If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Innovation & Partnerships Office at IPO@lbl.gov.

NOTICE. This Software was developed under funding from the U.S. Department of Energy and the U.S. Government consequently retains certain rights. As such, the U.S. Government has been granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, distribute copies to the public, prepare derivative works, and perform publicly and display publicly, and to permit other to do so.

hdmf-3.1.1/MANIFEST.in

include license.txt Legal.txt versioneer.py src/hdmf/_version.py src/hdmf/_due.py
include requirements.txt requirements-dev.txt requirements-doc.txt requirements-min.txt
include test.py tox.ini
graft tests

hdmf-3.1.1/PKG-INFO

Metadata-Version: 2.1
Name: hdmf
Version: 3.1.1
Summary: A package for standardizing hierarchical object data
Home-page: https://github.com/hdmf-dev/hdmf
Author: Andrew Tritt
Author-email: ajtritt@lbl.gov
License: BSD
Keywords: python HDF HDF5 cross-platform open-data data-format open-source open-science reproducible-research
Platform: UNKNOWN
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: License :: OSI Approved :: BSD License
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: Microsoft :: Windows
Classifier: Operating System :: MacOS
Classifier: Operating System :: Unix
Classifier: Topic :: Scientific/Engineering :: Medical Science Apps.
Requires-Python: >=3.7
Description-Content-Type: text/x-rst; charset=UTF-8

========================================
The Hierarchical Data Modeling Framework
========================================

The Hierarchical Data Modeling Framework, or *HDMF*, is a Python package for working with hierarchical data. It provides APIs for specifying data models, reading and writing data to different storage backends, and representing data with Python objects.

Documentation of HDMF can be found at https://hdmf.readthedocs.io

Latest Release
==============

.. image:: https://badge.fury.io/py/hdmf.svg
     :target: https://badge.fury.io/py/hdmf

.. image:: https://anaconda.org/conda-forge/hdmf/badges/version.svg
     :target: https://anaconda.org/conda-forge/hdmf

Build Status
============

..
table:: +---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ | Linux | Windows and macOS | +=====================================================================+==================================================================================================+ | .. image:: https://circleci.com/gh/hdmf-dev/hdmf.svg?style=shield | .. image:: https://dev.azure.com/hdmf-dev/hdmf/_apis/build/status/hdmf-dev.hdmf?branchName=dev | | :target: https://circleci.com/gh/hdmf-dev/hdmf | :target: https://dev.azure.com/hdmf-dev/hdmf/_build/latest?definitionId=1&branchName=dev | +---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ **Conda** .. image:: https://circleci.com/gh/conda-forge/hdmf-feedstock.svg?style=shield :target: https://circleci.com/gh/conda-forge/hdmf-feedstock Overall Health ============== .. image:: https://github.com/hdmf-dev/hdmf/workflows/Run%20coverage/badge.svg :target: https://github.com/hdmf-dev/hdmf/actions?query=workflow%3A%22Run+coverage%22 .. image:: https://codecov.io/gh/hdmf-dev/hdmf/branch/dev/graph/badge.svg :target: https://codecov.io/gh/hdmf-dev/hdmf .. image:: https://requires.io/github/hdmf-dev/hdmf/requirements.svg?branch=dev :target: https://requires.io/github/hdmf-dev/hdmf/requirements/?branch=dev :alt: Requirements Status .. image:: https://readthedocs.org/projects/hdmf/badge/?version=latest :target: https://hdmf.readthedocs.io/en/latest/?badge=latest :alt: Documentation Status Installation ============ See the HDMF documentation for details http://hdmf.readthedocs.io/en/latest/getting_started.html#installation Code of Conduct =============== This project and everyone participating in it is governed by our `code of conduct guidelines <.github/CODE_OF_CONDUCT.md>`_. By participating, you are expected to uphold this code. Contributing ============ For details on how to contribute to HDMF see our `contribution guidelines `_. Citing HDMF =========== * **Manuscript:** .. code-block:: bibtex @INPROCEEDINGS{9005648, author={A. J. {Tritt} and O. {Rübel} and B. {Dichter} and R. {Ly} and D. {Kang} and E. F. {Chang} and L. M. {Frank} and K. {Bouchard}}, booktitle={2019 IEEE International Conference on Big Data (Big Data)}, title={HDMF: Hierarchical Data Modeling Framework for Modern Science Data Standards}, year={2019}, volume={}, number={}, pages={165-179}, doi={10.1109/BigData47090.2019.9005648}, note={}} * **RRID:** (Hierarchical Data Modeling Framework, RRID:SCR_021303) LICENSE ======= "hdmf" Copyright (c) 2017-2021, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: (1) Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. (2) Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. (3) Neither the name of the University of California, Lawrence Berkeley National Laboratory, U.S. Dept. 
of Energy nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. You are under no obligation whatsoever to provide any bug fixes, patches, or upgrades to the features, functionality or performance of the source code ("Enhancements") to anyone; however, if you choose to make your Enhancements available either publicly, or directly to Lawrence Berkeley National Laboratory, without imposing a separate written license agreement for such Enhancements, then you hereby grant the following license: a non-exclusive, royalty-free perpetual license to install, use, modify, prepare derivative works, incorporate into other computer software, distribute, and sublicense such enhancements or derivative works thereof, in binary and source code form. COPYRIGHT ========= "hdmf" Copyright (c) 2017-2021, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Innovation & Partnerships Office at IPO@lbl.gov. NOTICE. This Software was developed under funding from the U.S. Department of Energy and the U.S. Government consequently retains certain rights. As such, the U.S. Government has been granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, distribute copies to the public, prepare derivative works, and perform publicly and display publicly, and to permit other to do so. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/README.rst0000644000655200065520000001546400000000000014262 0ustar00circlecicircleci======================================== The Hierarchical Data Modeling Framework ======================================== The Hierarchical Data Modeling Framework, or *HDMF*, is a Python package for working with hierarchical data. It provides APIs for specifying data models, reading and writing data to different storage backends, and representing data with Python object. Documentation of HDMF can be found at https://hdmf.readthedocs.io Latest Release ============== .. image:: https://badge.fury.io/py/hdmf.svg :target: https://badge.fury.io/py/hdmf .. image:: https://anaconda.org/conda-forge/hdmf/badges/version.svg :target: https://anaconda.org/conda-forge/hdmf Build Status ============ .. 
table:: +---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ | Linux | Windows and macOS | +=====================================================================+==================================================================================================+ | .. image:: https://circleci.com/gh/hdmf-dev/hdmf.svg?style=shield | .. image:: https://dev.azure.com/hdmf-dev/hdmf/_apis/build/status/hdmf-dev.hdmf?branchName=dev | | :target: https://circleci.com/gh/hdmf-dev/hdmf | :target: https://dev.azure.com/hdmf-dev/hdmf/_build/latest?definitionId=1&branchName=dev | +---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ **Conda** .. image:: https://circleci.com/gh/conda-forge/hdmf-feedstock.svg?style=shield :target: https://circleci.com/gh/conda-forge/hdmf-feedstock Overall Health ============== .. image:: https://github.com/hdmf-dev/hdmf/workflows/Run%20coverage/badge.svg :target: https://github.com/hdmf-dev/hdmf/actions?query=workflow%3A%22Run+coverage%22 .. image:: https://codecov.io/gh/hdmf-dev/hdmf/branch/dev/graph/badge.svg :target: https://codecov.io/gh/hdmf-dev/hdmf .. image:: https://requires.io/github/hdmf-dev/hdmf/requirements.svg?branch=dev :target: https://requires.io/github/hdmf-dev/hdmf/requirements/?branch=dev :alt: Requirements Status .. image:: https://readthedocs.org/projects/hdmf/badge/?version=latest :target: https://hdmf.readthedocs.io/en/latest/?badge=latest :alt: Documentation Status Installation ============ See the HDMF documentation for details http://hdmf.readthedocs.io/en/latest/getting_started.html#installation Code of Conduct =============== This project and everyone participating in it is governed by our `code of conduct guidelines <.github/CODE_OF_CONDUCT.md>`_. By participating, you are expected to uphold this code. Contributing ============ For details on how to contribute to HDMF see our `contribution guidelines `_. Citing HDMF =========== * **Manuscript:** .. code-block:: bibtex @INPROCEEDINGS{9005648, author={A. J. {Tritt} and O. {Rübel} and B. {Dichter} and R. {Ly} and D. {Kang} and E. F. {Chang} and L. M. {Frank} and K. {Bouchard}}, booktitle={2019 IEEE International Conference on Big Data (Big Data)}, title={HDMF: Hierarchical Data Modeling Framework for Modern Science Data Standards}, year={2019}, volume={}, number={}, pages={165-179}, doi={10.1109/BigData47090.2019.9005648}, note={}} * **RRID:** (Hierarchical Data Modeling Framework, RRID:SCR_021303) LICENSE ======= "hdmf" Copyright (c) 2017-2021, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: (1) Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. (2) Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. (3) Neither the name of the University of California, Lawrence Berkeley National Laboratory, U.S. Dept. 
of Energy nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. You are under no obligation whatsoever to provide any bug fixes, patches, or upgrades to the features, functionality or performance of the source code ("Enhancements") to anyone; however, if you choose to make your Enhancements available either publicly, or directly to Lawrence Berkeley National Laboratory, without imposing a separate written license agreement for such Enhancements, then you hereby grant the following license: a non-exclusive, royalty-free perpetual license to install, use, modify, prepare derivative works, incorporate into other computer software, distribute, and sublicense such enhancements or derivative works thereof, in binary and source code form. COPYRIGHT ========= "hdmf" Copyright (c) 2017-2021, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Innovation & Partnerships Office at IPO@lbl.gov. NOTICE. This Software was developed under funding from the U.S. Department of Energy and the U.S. Government consequently retains certain rights. As such, the U.S. Government has been granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, distribute copies to the public, prepare derivative works, and perform publicly and display publicly, and to permit other to do so. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/license.txt0000644000655200065520000000454700000000000014756 0ustar00circlecicircleci“hdmf” Copyright (c) 2017-2021, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: (1) Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. (2) Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. (3) Neither the name of the University of California, Lawrence Berkeley National Laboratory, U.S. Dept. 
of Energy nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. You are under no obligation whatsoever to provide any bug fixes, patches, or upgrades to the features, functionality or performance of the source code ("Enhancements") to anyone; however, if you choose to make your Enhancements available either publicly, or directly to Lawrence Berkeley National Laboratory, without imposing a separate written license agreement for such Enhancements, then you hereby grant the following license: a non-exclusive, royalty-free perpetual license to install, use, modify, prepare derivative works, incorporate into other computer software, distribute, and sublicense such enhancements or derivative works thereof, in binary and source code form. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/requirements-dev.txt0000644000655200065520000000047100000000000016623 0ustar00circlecicircleci# pinned dependencies to reproduce an entire development environment to use HDMF, run HDMF tests, check code style, # compute coverage, and create test environments codecov==2.1.11 coverage==5.5 flake8==3.9.2 flake8-debugger==4.0.0 flake8-print==4.0.0 importlib-metadata==4.6.1 python-dateutil==2.8.2 tox==3.24.0 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/requirements-doc.txt0000644000655200065520000000015000000000000016604 0ustar00circlecicircleci# dependencies to generate the documentation for HDMF sphinx sphinx_rtd_theme sphinx-gallery matplotlib ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/requirements-min.txt0000644000655200065520000000034600000000000016631 0ustar00circlecicircleci# minimum versions of package dependencies for installing HDMF h5py==2.10 # support for selection of datasets with list of indices added in 2.10 numpy==1.16 scipy==1.1 pandas==1.0.5 ruamel.yaml==0.16 jsonschema==2.6.0 setuptools ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/requirements.txt0000644000655200065520000000030000000000000016036 0ustar00circlecicircleci# pinned dependencies to reproduce an entire development environment to use HDMF h5py==3.3.0 numpy==1.21.1 scipy==1.7.0 pandas==1.3.1 ruamel.yaml==0.17.10 jsonschema==3.2.0 setuptools==57.4.0 ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1886272 hdmf-3.1.1/setup.cfg0000644000655200065520000000136500000000000014407 0ustar00circlecicircleci[versioneer] vcs = git versionfile_source = src/hdmf/_version.py versionfile_build = hdmf/_version.py style = pep440-pre tag_prefix = *.*.* 
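# The [flake8] section below enforces a 120-character line limit and a complexity
# ceiling of 17, excludes generated and vendored files (versioneer.py, _version.py,
# _due.py, the hdmf-common-schema submodule, docs/source/conf.py), and adds per-file
# ignores for F401 (unused imports) in re-exporting __init__ modules and T001
# (print statements) in the docs galleries, setup.py, and test.py.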
[flake8]
max-line-length = 120
max-complexity = 17
exclude =
    .git,
    .tox,
    __pycache__,
    build/,
    dist/,
    src/hdmf/common/hdmf-common-schema,
    docs/source/conf.py
    versioneer.py
    src/hdmf/_version.py
    src/hdmf/_due.py
per-file-ignores =
    docs/gallery/*:E402,T001
    docs/source/tutorials/*:E402,T001
    src/hdmf/__init__.py:F401
    src/hdmf/backends/__init__.py:F401
    src/hdmf/backends/hdf5/__init__.py:F401
    src/hdmf/build/__init__.py:F401
    src/hdmf/spec/__init__.py:F401
    src/hdmf/validate/__init__.py:F401
    setup.py:T001
    test.py:T001

[metadata]
description-file = README.rst

[egg_info]
tag_build =
tag_date = 0

hdmf-3.1.1/setup.py

# -*- coding: utf-8 -*-
from setuptools import setup, find_packages
import versioneer

with open('README.rst', 'r') as fp:
    readme = fp.read()

pkgs = find_packages('src', exclude=['data'])
print('found these packages:', pkgs)

schema_dir = 'common/hdmf-common-schema/common'

reqs = [
    'h5py>=2.10,<4',
    'numpy>=1.16,<1.22',
    'scipy>=1.1,<2',
    'pandas>=1.0.5,<2',
    'ruamel.yaml>=0.16,<1',
    'jsonschema>=2.6.0,<4',
    'setuptools',
]

print(reqs)

setup_args = {
    'name': 'hdmf',
    'version': versioneer.get_version(),
    'cmdclass': versioneer.get_cmdclass(),
    'description': 'A package for standardizing hierarchical object data',
    'long_description': readme,
    'long_description_content_type': 'text/x-rst; charset=UTF-8',
    'author': 'Andrew Tritt',
    'author_email': 'ajtritt@lbl.gov',
    'url': 'https://github.com/hdmf-dev/hdmf',
    'license': "BSD",
    'install_requires': reqs,
    'packages': pkgs,
    'package_dir': {'': 'src'},
    'package_data': {'hdmf': ["%s/*.yaml" % schema_dir, "%s/*.json" % schema_dir]},
    'python_requires': '>=3.7',
    'classifiers': [
        "Programming Language :: Python",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "License :: OSI Approved :: BSD License",
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "Operating System :: Microsoft :: Windows",
        "Operating System :: MacOS",
        "Operating System :: Unix",
        "Topic :: Scientific/Engineering :: Medical Science Apps."
    ],
    'keywords': 'python '
                'HDF '
                'HDF5 '
                'cross-platform '
                'open-data '
                'data-format '
                'open-source '
                'open-science '
                'reproducible-research ',
    'zip_safe': False,
    'entry_points': {
        'console_scripts': ['validate_hdmf_spec=hdmf.testing.validate_spec:main'],
    }
}

if __name__ == '__main__':
    setup(**setup_args)

hdmf-3.1.1/src/

hdmf-3.1.1/src/hdmf/

hdmf-3.1.1/src/hdmf/__init__.py

from .
import query # noqa: F401 from .container import Container, Data, DataRegion from .utils import docval, getargs from .region import ListSlicer from .backends.hdf5.h5_utils import H5RegionSlicer, H5Dataset @docval({'name': 'dataset', 'type': None, 'doc': 'the HDF5 dataset to slice'}, {'name': 'region', 'type': None, 'doc': 'the region reference to use to slice'}, is_method=False) def get_region_slicer(**kwargs): import warnings # noqa: E402 warnings.warn('get_region_slicer is deprecated and will be removed in HDMF 3.0.', DeprecationWarning) dataset, region = getargs('dataset', 'region', kwargs) if isinstance(dataset, (list, tuple, Data)): return ListSlicer(dataset, region) elif isinstance(dataset, H5Dataset): return H5RegionSlicer(dataset, region) return None from ._version import get_versions # noqa: E402 __version__ = get_versions()['version'] del get_versions from ._due import due, BibTeX # noqa: E402 due.cite(BibTeX(""" @INPROCEEDINGS{9005648, author={A. J. {Tritt} and O. {Rübel} and B. {Dichter} and R. {Ly} and D. {Kang} and E. F. {Chang} and L. M. {Frank} and K. {Bouchard}}, booktitle={2019 IEEE International Conference on Big Data (Big Data)}, title={HDMF: Hierarchical Data Modeling Framework for Modern Science Data Standards}, year={2019}, volume={}, number={}, pages={165-179}, doi={10.1109/BigData47090.2019.9005648}} """), description="HDMF: Hierarchical Data Modeling Framework for Modern Science Data Standards", # noqa: E501 path="hdmf/", version=__version__, cite_module=True) del due, BibTeX ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/_due.py0000644000655200065520000000374200000000000015562 0ustar00circlecicircleci# emacs: at the end of the file # ex: set sts=4 ts=4 sw=4 et: # ## ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### # """ Stub file for a guaranteed safe import of duecredit constructs: if duecredit is not available. To use it, place it into your project codebase to be imported, e.g. copy as cp stub.py /path/tomodule/module/due.py Note that it might be better to avoid naming it duecredit.py to avoid shadowing installed duecredit. Then use in your code as from .due import due, Doi, BibTeX, Text See https://github.com/duecredit/duecredit/blob/master/README.md for examples. Origin: Originally a part of the duecredit Copyright: 2015-2019 DueCredit developers License: BSD-2 """ __version__ = '0.0.8' class InactiveDueCreditCollector(object): """Just a stub at the Collector which would not do anything""" def _donothing(self, *args, **kwargs): """Perform no good and no bad""" pass def dcite(self, *args, **kwargs): """If I could cite I would""" def nondecorating_decorator(func): return func return nondecorating_decorator active = False activate = add = cite = dump = load = _donothing def __repr__(self): return self.__class__.__name__ + '()' def _donothing_func(*args, **kwargs): """Perform no good and no bad""" pass try: from duecredit import due, BibTeX, Doi, Url, Text if 'due' in locals() and not hasattr(due, 'cite'): raise RuntimeError( "Imported due lacks .cite. 
DueCredit is now disabled") except Exception as e: if not isinstance(e, ImportError): import logging logging.getLogger("duecredit").error( "Failed to import duecredit due to %s" % str(e)) # Initiate due stub due = InactiveDueCreditCollector() BibTeX = Doi = Url = Text = _donothing_func # Emacs mode definitions # Local Variables: # mode: python # py-indent-offset: 4 # tab-width: 4 # indent-tabs-mode: nil # End: ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1886272 hdmf-3.1.1/src/hdmf/_version.py0000644000655200065520000000076100000000000016470 0ustar00circlecicircleci # This file was generated by 'versioneer.py' (0.18) from # revision-control system data, or from the parent directory name of an # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. import json version_json = ''' { "date": "2021-07-29T16:55:01-0700", "dirty": false, "error": null, "full-revisionid": "df31c59aa396a9920077eb3970d966e9d0f7a75b", "version": "3.1.1" } ''' # END VERSION_JSON def get_versions(): return json.loads(version_json) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/array.py0000644000655200065520000001245400000000000015764 0ustar00circlecicirclecifrom abc import abstractmethod, ABCMeta import numpy as np class Array: def __init__(self, data): self.__data = data if hasattr(data, 'dtype'): self.dtype = data.dtype else: tmp = data while isinstance(tmp, (list, tuple)): tmp = tmp[0] self.dtype = type(tmp) @property def data(self): return self.__data def __len__(self): return len(self.__data) def get_data(self): return self.__data def __getidx__(self, arg): return self.__data[arg] def __sliceiter(self, arg): return (x for x in range(*arg.indices(len(self)))) def __getitem__(self, arg): if isinstance(arg, list): idx = list() for i in arg: if isinstance(i, slice): idx.extend(x for x in self.__sliceiter(i)) else: idx.append(i) return np.fromiter((self.__getidx__(x) for x in idx), dtype=self.dtype) elif isinstance(arg, slice): return np.fromiter((self.__getidx__(x) for x in self.__sliceiter(arg)), dtype=self.dtype) elif isinstance(arg, tuple): return (self.__getidx__(arg[0]), self.__getidx__(arg[1])) else: return self.__getidx__(arg) class AbstractSortedArray(Array, metaclass=ABCMeta): ''' An abstract class for representing sorted array ''' @abstractmethod def find_point(self, val): pass def get_data(self): return self def __lower(self, other): ins = self.find_point(other) return ins def __upper(self, other): ins = self.__lower(other) while self[ins] == other: ins += 1 return ins def __lt__(self, other): ins = self.__lower(other) return slice(0, ins) def __le__(self, other): ins = self.__upper(other) return slice(0, ins) def __gt__(self, other): ins = self.__upper(other) return slice(ins, len(self)) def __ge__(self, other): ins = self.__lower(other) return slice(ins, len(self)) @staticmethod def __sort(a): if isinstance(a, tuple): return a[0] else: return a def __eq__(self, other): if isinstance(other, list): ret = list() for i in other: eq = self == i ret.append(eq) ret = sorted(ret, key=self.__sort) tmp = list() for i in range(1, len(ret)): a, b = ret[i - 1], ret[i] if isinstance(a, tuple): if isinstance(b, tuple): if a[1] >= b[0]: b[0] = a[0] else: tmp.append(slice(*a)) else: if b > a[1]: tmp.append(slice(*a)) elif b == a[1]: a[1] == b + 1 else: ret[i] = a else: if isinstance(b, tuple): if a < b[0]: tmp.append(a) else: if b - a == 1: ret[i] = (a, b) else: 
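# a and b are both scalar indices here and are not adjacent (b - a != 1), so keep a as an individual match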
tmp.append(a) if isinstance(ret[-1], tuple): tmp.append(slice(*ret[-1])) else: tmp.append(ret[-1]) ret = tmp return ret elif isinstance(other, tuple): ge = self >= other[0] ge = ge.start lt = self < other[1] lt = lt.stop if ge == lt: return ge else: return slice(ge, lt) else: lower = self.__lower(other) upper = self.__upper(other) d = upper - lower if d == 1: return lower elif d == 0: return None else: return slice(lower, upper) def __ne__(self, other): eq = self == other if isinstance(eq, tuple): return [slice(0, eq[0]), slice(eq[1], len(self))] else: return [slice(0, eq), slice(eq + 1, len(self))] class SortedArray(AbstractSortedArray): ''' A class for wrapping sorted arrays. This class overrides <,>,<=,>=,==, and != to leverage the sorted content for efficiency. ''' def __init__(self, array): super().__init__(array) def find_point(self, val): return np.searchsorted(self.data, val) class LinSpace(SortedArray): def __init__(self, start, stop, step): self.start = start self.stop = stop self.step = step self.dtype = float if any(isinstance(s, float) for s in (start, stop, step)) else int self.__len = int((stop - start) / step) def __len__(self): return self.__len def find_point(self, val): nsteps = (val - self.start) / self.step fl = int(nsteps) if fl == nsteps: return int(fl) else: return int(fl + 1) def __getidx__(self, arg): return self.start + self.step * arg ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1766272 hdmf-3.1.1/src/hdmf/backends/0000755000655200065520000000000000000000000016040 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/backends/__init__.py0000644000655200065520000000002300000000000020144 0ustar00circlecicirclecifrom . import hdf5 ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1766272 hdmf-3.1.1/src/hdmf/backends/hdf5/0000755000655200065520000000000000000000000016666 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/backends/hdf5/__init__.py0000644000655200065520000000020700000000000020776 0ustar00circlecicirclecifrom . 
import h5_utils, h5tools from .h5_utils import H5RegionSlicer, H5DataIO from .h5tools import HDF5IO, H5SpecWriter, H5SpecReader ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/backends/hdf5/h5_utils.py0000644000655200065520000005036700000000000021007 0ustar00circlecicircleciimport json import os import warnings from abc import ABCMeta, abstractmethod from collections.abc import Iterable from copy import copy import numpy as np from h5py import Group, Dataset, RegionReference, Reference, special_dtype from h5py import filters as h5py_filters from ...array import Array from ...data_utils import DataIO, AbstractDataChunkIterator from ...query import HDMFDataset, ReferenceResolver, ContainerResolver, BuilderResolver from ...region import RegionSlicer from ...spec import SpecWriter, SpecReader from ...utils import docval, getargs, popargs, call_docval_func, get_docval class H5Dataset(HDMFDataset): @docval({'name': 'dataset', 'type': (Dataset, Array), 'doc': 'the HDF5 file lazily evaluate'}, {'name': 'io', 'type': 'HDF5IO', 'doc': 'the IO object that was used to read the underlying dataset'}) def __init__(self, **kwargs): self.__io = popargs('io', kwargs) call_docval_func(super().__init__, kwargs) @property def io(self): return self.__io @property def regionref(self): return self.dataset.regionref @property def ref(self): return self.dataset.ref @property def shape(self): return self.dataset.shape class DatasetOfReferences(H5Dataset, ReferenceResolver, metaclass=ABCMeta): """ An extension of the base ReferenceResolver class to add more abstract methods for subclasses that will read HDF5 references """ @abstractmethod def get_object(self, h5obj): """ A class that maps an HDF5 object to a Builder or Container """ pass def invert(self): """ Return an object that defers reference resolution but in the opposite direction. 
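For example, a dataset whose references resolve to Containers inverts to the corresponding Builder-resolving class (as given by ``get_inverse_class``), and vice versa.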
""" if not hasattr(self, '__inverted'): cls = self.get_inverse_class() docval = get_docval(cls.__init__) kwargs = dict() for arg in docval: kwargs[arg['name']] = getattr(self, arg['name']) self.__inverted = cls(**kwargs) return self.__inverted def _get_ref(self, ref): return self.get_object(self.dataset.file[ref]) def __iter__(self): for ref in super().__iter__(): yield self._get_ref(ref) def __next__(self): return self._get_ref(super().__next__()) class BuilderResolverMixin(BuilderResolver): """ A mixin for adding to HDF5 reference-resolving types the get_object method that returns Builders """ def get_object(self, h5obj): """ A class that maps an HDF5 object to a Builder """ return self.io.get_builder(h5obj) class ContainerResolverMixin(ContainerResolver): """ A mixin for adding to HDF5 reference-resolving types the get_object method that returns Containers """ def get_object(self, h5obj): """ A class that maps an HDF5 object to a Container """ return self.io.get_container(h5obj) class AbstractH5TableDataset(DatasetOfReferences): @docval({'name': 'dataset', 'type': (Dataset, Array), 'doc': 'the HDF5 file lazily evaluate'}, {'name': 'io', 'type': 'HDF5IO', 'doc': 'the IO object that was used to read the underlying dataset'}, {'name': 'types', 'type': (list, tuple), 'doc': 'the IO object that was used to read the underlying dataset'}) def __init__(self, **kwargs): types = popargs('types', kwargs) call_docval_func(super().__init__, kwargs) self.__refgetters = dict() for i, t in enumerate(types): if t is RegionReference: self.__refgetters[i] = self.__get_regref elif t is Reference: self.__refgetters[i] = self._get_ref elif t is str: # we need this for when we read compound data types # that have unicode sub-dtypes since h5py does not # store UTF-8 in compound dtypes self.__refgetters[i] = self._get_utf self.__types = types tmp = list() for i in range(len(self.dataset.dtype)): sub = self.dataset.dtype[i] if sub.metadata: if 'vlen' in sub.metadata: t = sub.metadata['vlen'] if t is str: tmp.append('utf') elif t is bytes: tmp.append('ascii') elif 'ref' in sub.metadata: t = sub.metadata['ref'] if t is Reference: tmp.append('object') elif t is RegionReference: tmp.append('region') else: tmp.append(sub.type.__name__) self.__dtype = tmp @property def types(self): return self.__types @property def dtype(self): return self.__dtype def __getitem__(self, arg): rows = copy(super().__getitem__(arg)) if np.issubdtype(type(arg), np.integer): self.__swap_refs(rows) else: for row in rows: self.__swap_refs(row) return rows def __swap_refs(self, row): for i in self.__refgetters: getref = self.__refgetters[i] row[i] = getref(row[i]) def _get_utf(self, string): """ Decode a dataset element to unicode """ return string.decode('utf-8') if isinstance(string, bytes) else string def __get_regref(self, ref): obj = self._get_ref(ref) return obj[ref] def resolve(self, manager): return self[0:len(self)] def __iter__(self): for i in range(len(self)): yield self[i] class AbstractH5ReferenceDataset(DatasetOfReferences): def __getitem__(self, arg): ref = super().__getitem__(arg) if isinstance(ref, np.ndarray): return [self._get_ref(x) for x in ref] else: return self._get_ref(ref) @property def dtype(self): return 'object' class AbstractH5RegionDataset(AbstractH5ReferenceDataset): def __getitem__(self, arg): obj = super().__getitem__(arg) ref = self.dataset[arg] return obj[ref] @property def dtype(self): return 'region' class ContainerH5TableDataset(ContainerResolverMixin, AbstractH5TableDataset): """ A reference-resolving 
dataset for resolving references inside tables (i.e. compound dtypes) that returns resolved references as Containers """ @classmethod def get_inverse_class(cls): return BuilderH5TableDataset class BuilderH5TableDataset(BuilderResolverMixin, AbstractH5TableDataset): """ A reference-resolving dataset for resolving references inside tables (i.e. compound dtypes) that returns resolved references as Builders """ @classmethod def get_inverse_class(cls): return ContainerH5TableDataset class ContainerH5ReferenceDataset(ContainerResolverMixin, AbstractH5ReferenceDataset): """ A reference-resolving dataset for resolving object references that returns resolved references as Containers """ @classmethod def get_inverse_class(cls): return BuilderH5ReferenceDataset class BuilderH5ReferenceDataset(BuilderResolverMixin, AbstractH5ReferenceDataset): """ A reference-resolving dataset for resolving object references that returns resolved references as Builders """ @classmethod def get_inverse_class(cls): return ContainerH5ReferenceDataset class ContainerH5RegionDataset(ContainerResolverMixin, AbstractH5RegionDataset): """ A reference-resolving dataset for resolving region references that returns resolved references as Containers """ @classmethod def get_inverse_class(cls): return BuilderH5RegionDataset class BuilderH5RegionDataset(BuilderResolverMixin, AbstractH5RegionDataset): """ A reference-resolving dataset for resolving region references that returns resolved references as Builders """ @classmethod def get_inverse_class(cls): return ContainerH5RegionDataset class H5SpecWriter(SpecWriter): __str_type = special_dtype(vlen=str) @docval({'name': 'group', 'type': Group, 'doc': 'the HDF5 file to write specs to'}) def __init__(self, **kwargs): self.__group = getargs('group', kwargs) @staticmethod def stringify(spec): ''' Converts a spec into a JSON string to write to a dataset ''' return json.dumps(spec, separators=(',', ':')) def __write(self, d, name): data = self.stringify(d) # create spec group if it does not exist. 
otherwise, do not overwrite existing spec dset = self.__group.create_dataset(name, shape=tuple(), data=data, dtype=self.__str_type) return dset def write_spec(self, spec, path): return self.__write(spec, path) def write_namespace(self, namespace, path): return self.__write({'namespaces': [namespace]}, path) class H5SpecReader(SpecReader): """Class that reads cached JSON-formatted namespace and spec data from an HDF5 group.""" @docval({'name': 'group', 'type': Group, 'doc': 'the HDF5 group to read specs from'}) def __init__(self, **kwargs): self.__group = getargs('group', kwargs) super_kwargs = {'source': "%s:%s" % (os.path.abspath(self.__group.file.name), self.__group.name)} call_docval_func(super().__init__, super_kwargs) self.__cache = None def __read(self, path): s = self.__group[path][()] if isinstance(s, np.ndarray) and s.shape == (1,): # unpack scalar spec dataset s = s[0] if isinstance(s, bytes): s = s.decode('UTF-8') d = json.loads(s) return d def read_spec(self, spec_path): return self.__read(spec_path) def read_namespace(self, ns_path): if self.__cache is None: self.__cache = self.__read(ns_path) ret = self.__cache['namespaces'] return ret class H5RegionSlicer(RegionSlicer): @docval({'name': 'dataset', 'type': (Dataset, H5Dataset), 'doc': 'the HDF5 dataset to slice'}, {'name': 'region', 'type': RegionReference, 'doc': 'the region reference to use to slice'}) def __init__(self, **kwargs): self.__dataset = getargs('dataset', kwargs) self.__regref = getargs('region', kwargs) self.__len = self.__dataset.regionref.selection(self.__regref)[0] self.__region = None def __read_region(self): if self.__region is None: self.__region = self.__dataset[self.__regref] def __getitem__(self, idx): self.__read_region() return self.__region[idx] def __len__(self): return self.__len class H5DataIO(DataIO): """ Wrap data arrays for write via HDF5IO to customize I/O behavior, such as compression and chunking for data arrays. """ @docval({'name': 'data', 'type': (np.ndarray, list, tuple, Dataset, Iterable), 'doc': 'the data to be written. NOTE: If an h5py.Dataset is used, all other settings but link_data' + ' will be ignored as the dataset will either be linked to or copied as is in H5DataIO.', 'default': None}, {'name': 'maxshape', 'type': tuple, 'doc': 'Dataset will be resizable up to this shape (Tuple). Automatically enables chunking.' + 'Use None for the axes you want to be unlimited.', 'default': None}, {'name': 'chunks', 'type': (bool, tuple), 'doc': 'Chunk shape or True to enable auto-chunking', 'default': None}, {'name': 'compression', 'type': (str, bool, int), 'doc': 'Compression strategy. If a bool is given, then gzip compression will be used by default.' + 'http://docs.h5py.org/en/latest/high/dataset.html#dataset-compression', 'default': None}, {'name': 'compression_opts', 'type': (int, tuple), 'doc': 'Parameter for compression filter', 'default': None}, {'name': 'fillvalue', 'type': None, 'doc': 'Value to be returned when reading uninitialized parts of the dataset', 'default': None}, {'name': 'shuffle', 'type': bool, 'doc': 'Enable shuffle I/O filter. http://docs.h5py.org/en/latest/high/dataset.html#dataset-shuffle', 'default': None}, {'name': 'fletcher32', 'type': bool, 'doc': 'Enable fletcher32 checksum. http://docs.h5py.org/en/latest/high/dataset.html#dataset-fletcher32', 'default': None}, {'name': 'link_data', 'type': bool, 'doc': 'If data is an h5py.Dataset should it be linked to or copied. 
NOTE: This parameter is only ' + 'allowed if data is an h5py.Dataset', 'default': False}, {'name': 'allow_plugin_filters', 'type': bool, 'doc': 'Enable passing dynamically loaded filters as compression parameter', 'default': False} ) def __init__(self, **kwargs): # Get the list of I/O options that user has passed in ioarg_names = [name for name in kwargs.keys() if name not in ['data', 'link_data', 'allow_plugin_filters']] # Remove the ioargs from kwargs ioarg_values = [popargs(argname, kwargs) for argname in ioarg_names] # Consume link_data parameter self.__link_data = popargs('link_data', kwargs) # Consume allow_plugin_filters parameter self.__allow_plugin_filters = popargs('allow_plugin_filters', kwargs) # Check for possible collision with other parameters if not isinstance(getargs('data', kwargs), Dataset) and self.__link_data: self.__link_data = False warnings.warn('link_data parameter in H5DataIO will be ignored') # Call the super constructor and consume the data parameter call_docval_func(super().__init__, kwargs) # Construct the dict with the io args, ignoring all options that were set to None self.__iosettings = {k: v for k, v in zip(ioarg_names, ioarg_values) if v is not None} # Set io_properties for DataChunkIterators if isinstance(self.data, AbstractDataChunkIterator): # Define the chunking options if the user has not set them explicitly. if 'chunks' not in self.__iosettings and self.data.recommended_chunk_shape() is not None: self.__iosettings['chunks'] = self.data.recommended_chunk_shape() # Define the maxshape of the data if not provided by the user if 'maxshape' not in self.__iosettings: self.__iosettings['maxshape'] = self.data.maxshape # Make default settings when compression set to bool (True/False) if isinstance(self.__iosettings.get('compression', None), bool): if self.__iosettings['compression']: self.__iosettings['compression'] = 'gzip' else: self.__iosettings.pop('compression', None) if 'compression_opts' in self.__iosettings: warnings.warn('Compression disabled by compression=False setting. ' + 'compression_opts parameter will, therefore, be ignored.') self.__iosettings.pop('compression_opts', None) # Validate the compression options used self._check_compression_options() # Confirm that the compressor is supported by h5py if not self.filter_available(self.__iosettings.get('compression', None), self.__allow_plugin_filters): msg = "%s compression may not be supported by this version of h5py." % str(self.__iosettings['compression']) if not self.__allow_plugin_filters: msg += " Set `allow_plugin_filters=True` to enable the use of dynamically-loaded plugin filters." raise ValueError(msg) # Check possible parameter collisions if isinstance(self.data, Dataset): for k in self.__iosettings.keys(): warnings.warn("%s in H5DataIO will be ignored with H5DataIO.data being an HDF5 dataset" % k) def get_io_params(self): """ Returns a dict with the I/O parameters specifiedin in this DataIO. """ ret = dict(self.__iosettings) ret['link_data'] = self.__link_data return ret def _check_compression_options(self): """ Internal helper function used to check if compression options are compliant with the compression filter used. 
:raises ValueError: If incompatible options are detected """ if 'compression' in self.__iosettings: if 'compression_opts' in self.__iosettings: if self.__iosettings['compression'] == 'gzip': if self.__iosettings['compression_opts'] not in range(10): raise ValueError("GZIP compression_opts setting must be an integer from 0-9, " "not " + str(self.__iosettings['compression_opts'])) elif self.__iosettings['compression'] == 'lzf': if self.__iosettings['compression_opts'] is not None: raise ValueError("LZF compression filter accepts no compression_opts") elif self.__iosettings['compression'] == 'szip': szip_opts_error = False # Check that we have a tuple szip_opts_error |= not isinstance(self.__iosettings['compression_opts'], tuple) # Check that we have a tuple of the right length and correct settings if not szip_opts_error: try: szmethod, szpix = self.__iosettings['compression_opts'] szip_opts_error |= (szmethod not in ('ec', 'nn')) szip_opts_error |= (not (0 < szpix <= 32 and szpix % 2 == 0)) except ValueError: # ValueError is raised if tuple does not have the right length to unpack szip_opts_error = True if szip_opts_error: raise ValueError("SZIP compression filter compression_opts" " must be a 2-tuple ('ec'|'nn', even integer 0-32).") # Warn if compressor other than gzip is being used if self.__iosettings['compression'] not in ['gzip', h5py_filters.h5z.FILTER_DEFLATE]: warnings.warn(str(self.__iosettings['compression']) + " compression may not be available " "on all installations of HDF5. Use of gzip is recommended to ensure portability of " "the generated HDF5 files.") @staticmethod def filter_available(filter, allow_plugin_filters): """ Check if a given I/O filter is available :param filter: String with the name of the filter, e.g., gzip, szip etc. int with the registered filter ID, e.g. 
307 :type filter: String, int :param allow_plugin_filters: bool indicating whether the given filter can be dynamically loaded :return: bool indicating wether the given filter is available """ if filter is not None: if filter in h5py_filters.encode: return True elif allow_plugin_filters is True: if type(filter) == int: if h5py_filters.h5z.filter_avail(filter): filter_info = h5py_filters.h5z.get_filter_info(filter) if filter_info == (h5py_filters.h5z.FILTER_CONFIG_DECODE_ENABLED + h5py_filters.h5z.FILTER_CONFIG_ENCODE_ENABLED): return True return False else: return True @property def link_data(self): return self.__link_data @property def io_settings(self): return self.__iosettings @property def valid(self): if isinstance(self.data, Dataset) and not self.data.id.valid: return False return super().valid ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/backends/hdf5/h5tools.py0000644000655200065520000021516000000000000020642 0ustar00circlecicircleciimport logging import os.path import warnings from collections import deque from functools import partial from pathlib import Path import numpy as np import h5py from h5py import File, Group, Dataset, special_dtype, SoftLink, ExternalLink, Reference, RegionReference, check_dtype from .h5_utils import (BuilderH5ReferenceDataset, BuilderH5RegionDataset, BuilderH5TableDataset, H5DataIO, H5SpecReader, H5SpecWriter) from ..io import HDMFIO, UnsupportedOperation from ..warnings import BrokenLinkWarning from ...build import (Builder, GroupBuilder, DatasetBuilder, LinkBuilder, BuildManager, RegionBuilder, ReferenceBuilder, TypeMap, ObjectMapper) from ...container import Container from ...data_utils import AbstractDataChunkIterator from ...spec import RefSpec, DtypeSpec, NamespaceCatalog, GroupSpec, NamespaceBuilder from ...utils import docval, getargs, popargs, call_docval_func, get_data_shape, fmt_docval_args, get_docval, StrDataset ROOT_NAME = 'root' SPEC_LOC_ATTR = '.specloc' H5_TEXT = special_dtype(vlen=str) H5_BINARY = special_dtype(vlen=bytes) H5_REF = special_dtype(ref=Reference) H5_REGREF = special_dtype(ref=RegionReference) H5PY_3 = h5py.__version__.startswith('3') class HDF5IO(HDMFIO): __ns_spec_path = 'namespace' # path to the namespace dataset within a namespace group @docval({'name': 'path', 'type': (str, Path), 'doc': 'the path to the HDF5 file'}, {'name': 'manager', 'type': (TypeMap, BuildManager), 'doc': 'the BuildManager or a TypeMap to construct a BuildManager to use for I/O', 'default': None}, {'name': 'mode', 'type': str, 'doc': ('the mode to open the HDF5 file with, one of ("w", "r", "r+", "a", "w-", "x"). ' 'See `h5py.File `_ for ' 'more details.')}, {'name': 'comm', 'type': 'Intracomm', 'doc': 'the MPI communicator to use for parallel I/O', 'default': None}, {'name': 'file', 'type': File, 'doc': 'a pre-existing h5py.File object', 'default': None}, {'name': 'driver', 'type': str, 'doc': 'driver for h5py to use when opening HDF5 file', 'default': None}) def __init__(self, **kwargs): """Open an HDF5 file for IO. 
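A minimal usage sketch (``example.h5`` and ``manager`` are placeholder names; ``manager`` is assumed to be a :py:class:`~hdmf.build.BuildManager` configured for the namespaces used in the file)::

    from hdmf.backends.hdf5 import HDF5IO

    # open the file read-only, read the root container, then close the file
    with HDF5IO('example.h5', manager=manager, mode='r') as io:
        container = io.read()
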
""" self.logger = logging.getLogger('%s.%s' % (self.__class__.__module__, self.__class__.__qualname__)) path, manager, mode, comm, file_obj, driver = popargs('path', 'manager', 'mode', 'comm', 'file', 'driver', kwargs) if isinstance(path, Path): path = str(path) if file_obj is not None and os.path.abspath(file_obj.filename) != os.path.abspath(path): msg = 'You argued %s as this object\'s path, ' % path msg += 'but supplied a file with filename: %s' % file_obj.filename raise ValueError(msg) if file_obj is None and not os.path.exists(path) and (mode == 'r' or mode == 'r+') and driver != 'ros3': msg = "Unable to open file %s in '%s' mode. File does not exist." % (path, mode) raise UnsupportedOperation(msg) if file_obj is None and os.path.exists(path) and (mode == 'w-' or mode == 'x'): msg = "Unable to open file %s in '%s' mode. File already exists." % (path, mode) raise UnsupportedOperation(msg) if manager is None: manager = BuildManager(TypeMap(NamespaceCatalog())) elif isinstance(manager, TypeMap): manager = BuildManager(manager) self.__driver = driver self.__comm = comm self.__mode = mode self.__file = file_obj super().__init__(manager, source=path) self.__built = dict() # keep track of each builder for each dataset/group/link for each file self.__read = dict() # keep track of which files have been read. Key is the filename value is the builder self.__ref_queue = deque() # a queue of the references that need to be added self.__dci_queue = deque() # a queue of DataChunkIterators that need to be exhausted ObjectMapper.no_convert(Dataset) self._written_builders = dict() # keep track of which builders were written (or read) by this IO object self.__open_links = [] # keep track of other files opened from links in this file @property def comm(self): """The MPI communicator to use for parallel I/O.""" return self.__comm @property def _file(self): return self.__file @property def driver(self): return self.__driver @staticmethod def __resolve_file_obj(path, file_obj, driver): if isinstance(path, Path): path = str(path) if path is None and file_obj is None: raise ValueError("Either the 'path' or 'file' argument must be supplied.") if path is not None and file_obj is not None: # consistency check if os.path.abspath(file_obj.filename) != os.path.abspath(path): msg = ("You argued '%s' as this object's path, but supplied a file with filename: %s" % (path, file_obj.filename)) raise ValueError(msg) if file_obj is None: file_kwargs = dict() if driver is not None: file_kwargs.update(driver=driver) file_obj = File(path, 'r', **file_kwargs) return file_obj @classmethod @docval({'name': 'namespace_catalog', 'type': (NamespaceCatalog, TypeMap), 'doc': 'the NamespaceCatalog or TypeMap to load namespaces into'}, {'name': 'path', 'type': (str, Path), 'doc': 'the path to the HDF5 file', 'default': None}, {'name': 'namespaces', 'type': list, 'doc': 'the namespaces to load', 'default': None}, {'name': 'file', 'type': File, 'doc': 'a pre-existing h5py.File object', 'default': None}, {'name': 'driver', 'type': str, 'doc': 'driver for h5py to use when opening HDF5 file', 'default': None}, returns=("dict mapping the names of the loaded namespaces to a dict mapping included namespace names and " "the included data types"), rtype=dict) def load_namespaces(cls, **kwargs): """Load cached namespaces from a file. If `file` is not supplied, then an :py:class:`h5py.File` object will be opened for the given `path`, the namespaces will be read, and the File object will be closed. 
If `file` is supplied, then the given File object will be read from and not closed. :raises ValueError: if both `path` and `file` are supplied but `path` is not the same as the path of `file`. """ namespace_catalog, path, namespaces, file_obj, driver = popargs( 'namespace_catalog', 'path', 'namespaces', 'file', 'driver', kwargs) open_file_obj = cls.__resolve_file_obj(path, file_obj, driver) if file_obj is None: # need to close the file object that we just opened with open_file_obj: return cls.__load_namespaces(namespace_catalog, namespaces, open_file_obj) return cls.__load_namespaces(namespace_catalog, namespaces, open_file_obj) @classmethod def __load_namespaces(cls, namespace_catalog, namespaces, file_obj): d = {} if not cls.__check_specloc(file_obj): return d namespace_versions = cls.__get_namespaces(file_obj) spec_group = file_obj[file_obj.attrs[SPEC_LOC_ATTR]] if namespaces is None: namespaces = list(spec_group.keys()) readers = dict() deps = dict() for ns in namespaces: latest_version = namespace_versions[ns] ns_group = spec_group[ns][latest_version] reader = H5SpecReader(ns_group) readers[ns] = reader # for each namespace in the 'namespace' dataset, track all included namespaces (dependencies) for spec_ns in reader.read_namespace(cls.__ns_spec_path): deps[ns] = list() for s in spec_ns['schema']: dep = s.get('namespace') if dep is not None: deps[ns].append(dep) order = cls._order_deps(deps) for ns in order: reader = readers[ns] d.update(namespace_catalog.load_namespaces(cls.__ns_spec_path, reader=reader)) return d @classmethod def __check_specloc(cls, file_obj): if SPEC_LOC_ATTR not in file_obj.attrs: # this occurs in legacy files msg = "No cached namespaces found in %s" % file_obj.filename warnings.warn(msg) return False return True @classmethod @docval({'name': 'path', 'type': (str, Path), 'doc': 'the path to the HDF5 file', 'default': None}, {'name': 'file', 'type': File, 'doc': 'a pre-existing h5py.File object', 'default': None}, {'name': 'driver', 'type': str, 'doc': 'driver for h5py to use when opening HDF5 file', 'default': None}, returns="dict mapping names to versions of the namespaces in the file", rtype=dict) def get_namespaces(cls, **kwargs): """Get the names and versions of the cached namespaces from a file. If `file` is not supplied, then an :py:class:`h5py.File` object will be opened for the given `path`, the namespaces will be read, and the File object will be closed. If `file` is supplied, then the given File object will be read from and not closed. If there are multiple versions of a namespace cached in the file, then only the latest one (using alphanumeric ordering) is returned. This is the version of the namespace that is loaded by HDF5IO.load_namespaces(...). :raises ValueError: if both `path` and `file` are supplied but `path` is not the same as the path of `file`. """ path, file_obj, driver = popargs('path', 'file', 'driver', kwargs) open_file_obj = cls.__resolve_file_obj(path, file_obj, driver) if file_obj is None: # need to close the file object that we just opened with open_file_obj: return cls.__get_namespaces(open_file_obj) return cls.__get_namespaces(open_file_obj) @classmethod def __get_namespaces(cls, file_obj): """Return a dict mapping namespace name to version string for the latest version of that namespace in the file. If there are multiple versions of a namespace cached in the file, then only the latest one (using alphanumeric ordering) is returned. This is the version of the namespace that is loaded by HDF5IO.load_namespaces(...). 
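For example, if versions ``1.4.0`` and ``1.5.0`` of a namespace ``my-ns`` (hypothetical name) are both cached, the returned dict contains ``{'my-ns': '1.5.0'}``.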
""" used_version_names = dict() if not cls.__check_specloc(file_obj): return used_version_names spec_group = file_obj[file_obj.attrs[SPEC_LOC_ATTR]] namespaces = list(spec_group.keys()) for ns in namespaces: ns_group = spec_group[ns] # NOTE: by default, objects within groups are iterated in alphanumeric order version_names = list(ns_group.keys()) if len(version_names) > 1: # prior to HDMF 1.6.1, extensions without a version were written under the group name "unversioned" # make sure that if there is another group representing a newer version, that is read instead if 'unversioned' in version_names: version_names.remove('unversioned') if len(version_names) > 1: # as of HDMF 1.6.1, extensions without a version are written under the group name "None" # make sure that if there is another group representing a newer version, that is read instead if 'None' in version_names: version_names.remove('None') used_version_names[ns] = version_names[-1] # save the largest in alphanumeric order return used_version_names @classmethod def _order_deps(cls, deps): """ Order namespaces according to dependency for loading into a NamespaceCatalog Args: deps (dict): a dictionary that maps a namespace name to a list of name of the namespaces on which the namespace is directly dependent Example: {'a': ['b', 'c'], 'b': ['d'], 'c': ['d'], 'd': []} Expected output: ['d', 'b', 'c', 'a'] """ order = list() keys = list(deps.keys()) deps = dict(deps) for k in keys: if k in deps: cls.__order_deps_aux(order, deps, k) return order @classmethod def __order_deps_aux(cls, order, deps, key): """ A recursive helper function for _order_deps """ if key not in deps: return subdeps = deps.pop(key) for subk in subdeps: cls.__order_deps_aux(order, deps, subk) order.append(key) @classmethod def __convert_namespace(cls, ns_catalog, namespace): ns = ns_catalog.get_namespace(namespace) builder = NamespaceBuilder(ns.doc, ns.name, full_name=ns.full_name, version=ns.version, author=ns.author, contact=ns.contact) for elem in ns.schema: if 'namespace' in elem: inc_ns = elem['namespace'] builder.include_namespace(inc_ns) else: source = elem['source'] for dt in ns_catalog.get_types(source): spec = ns_catalog.get_spec(namespace, dt) if spec.parent is not None: continue h5_source = cls.__get_name(source) spec = cls.__copy_spec(spec) builder.add_spec(h5_source, spec) return builder @classmethod def __get_name(cls, path): return os.path.splitext(path)[0] @classmethod def __copy_spec(cls, spec): kwargs = dict() kwargs['attributes'] = cls.__get_new_specs(spec.attributes, spec) to_copy = ['doc', 'name', 'default_name', 'linkable', 'quantity', spec.inc_key(), spec.def_key()] if isinstance(spec, GroupSpec): kwargs['datasets'] = cls.__get_new_specs(spec.datasets, spec) kwargs['groups'] = cls.__get_new_specs(spec.groups, spec) kwargs['links'] = cls.__get_new_specs(spec.links, spec) else: to_copy.append('dtype') to_copy.append('shape') to_copy.append('dims') for key in to_copy: val = getattr(spec, key) if val is not None: kwargs[key] = val ret = spec.build_spec(kwargs) return ret @classmethod def __get_new_specs(cls, subspecs, spec): ret = list() for subspec in subspecs: if not spec.is_inherited_spec(subspec) or spec.is_overridden_spec(subspec): ret.append(subspec) return ret @classmethod @docval({'name': 'source_filename', 'type': str, 'doc': 'the path to the HDF5 file to copy'}, {'name': 'dest_filename', 'type': str, 'doc': 'the name of the destination file'}, {'name': 'expand_external', 'type': bool, 'doc': 'expand external links into new objects', 
'default': True}, {'name': 'expand_refs', 'type': bool, 'doc': 'copy objects which are pointed to by reference', 'default': False}, {'name': 'expand_soft', 'type': bool, 'doc': 'expand soft links into new objects', 'default': False} ) def copy_file(self, **kwargs): """ Convenience function to copy an HDF5 file while allowing external links to be resolved. .. warning:: As of HDMF 2.0, this method is no longer supported and may be removed in a future version. Please use the export method or h5py.File.copy method instead. .. note:: The source file will be opened in 'r' mode and the destination file will be opened in 'w' mode using h5py. To avoid possible collisions, care should be taken that, e.g., the source file is not opened already when calling this function. """ warnings.warn("The copy_file class method is no longer supported and may be removed in a future version of " "HDMF. Please use the export method or h5py.File.copy method instead.", DeprecationWarning) source_filename, dest_filename, expand_external, expand_refs, expand_soft = getargs('source_filename', 'dest_filename', 'expand_external', 'expand_refs', 'expand_soft', kwargs) source_file = File(source_filename, 'r') dest_file = File(dest_filename, 'w') for objname in source_file["/"].keys(): source_file.copy(source=objname, dest=dest_file, name=objname, expand_external=expand_external, expand_refs=expand_refs, expand_soft=expand_soft, shallow=False, without_attrs=False, ) for objname in source_file['/'].attrs: dest_file['/'].attrs[objname] = source_file['/'].attrs[objname] source_file.close() dest_file.close() @docval({'name': 'container', 'type': Container, 'doc': 'the Container object to write'}, {'name': 'cache_spec', 'type': bool, 'doc': ('If True (default), cache specification to file (highly recommended). If False, do not cache ' 'specification to file. The appropriate specification will then need to be loaded prior to ' 'reading the file.'), 'default': True}, {'name': 'link_data', 'type': bool, 'doc': 'If True (default), create external links to HDF5 Datasets. If False, copy HDF5 Datasets.', 'default': True}, {'name': 'exhaust_dci', 'type': bool, 'doc': 'If True (default), exhaust DataChunkIterators one at a time. If False, exhaust them concurrently.', 'default': True}) def write(self, **kwargs): """Write the container to an HDF5 file.""" if self.__mode == 'r': raise UnsupportedOperation(("Cannot write to file %s in mode '%s'. 
" "Please use mode 'r+', 'w', 'w-', 'x', or 'a'") % (self.source, self.__mode)) cache_spec = popargs('cache_spec', kwargs) call_docval_func(super().write, kwargs) if cache_spec: self.__cache_spec() def __cache_spec(self): ref = self.__file.attrs.get(SPEC_LOC_ATTR) spec_group = None if ref is not None: spec_group = self.__file[ref] else: path = 'specifications' # do something to figure out where the specifications should go spec_group = self.__file.require_group(path) self.__file.attrs[SPEC_LOC_ATTR] = spec_group.ref ns_catalog = self.manager.namespace_catalog for ns_name in ns_catalog.namespaces: ns_builder = self.__convert_namespace(ns_catalog, ns_name) namespace = ns_catalog.get_namespace(ns_name) group_name = '%s/%s' % (ns_name, namespace.version) if group_name in spec_group: continue ns_group = spec_group.create_group(group_name) writer = H5SpecWriter(ns_group) ns_builder.export(self.__ns_spec_path, writer=writer) _export_args = ( {'name': 'src_io', 'type': 'HDMFIO', 'doc': 'the HDMFIO object for reading the data to export'}, {'name': 'container', 'type': Container, 'doc': ('the Container object to export. If None, then the entire contents of the HDMFIO object will be ' 'exported'), 'default': None}, {'name': 'write_args', 'type': dict, 'doc': 'arguments to pass to :py:meth:`write_builder`', 'default': dict()}, {'name': 'cache_spec', 'type': bool, 'doc': 'whether to cache the specification to file', 'default': True} ) @docval(*_export_args) def export(self, **kwargs): """Export data read from a file from any backend to HDF5. See :py:meth:`hdmf.backends.io.HDMFIO.export` for more details. """ if self.__mode != 'w': raise UnsupportedOperation("Cannot export to file %s in mode '%s'. Please use mode 'w'." % (self.source, self.__mode)) src_io = getargs('src_io', kwargs) write_args, cache_spec = popargs('write_args', 'cache_spec', kwargs) if not isinstance(src_io, HDF5IO) and write_args.get('link_data', True): raise UnsupportedOperation("Cannot export from non-HDF5 backend %s to HDF5 with write argument " "link_data=True." % src_io.__class__.__name__) write_args['export_source'] = src_io.source # pass export_source=src_io.source to write_builder ckwargs = kwargs.copy() ckwargs['write_args'] = write_args call_docval_func(super().export, ckwargs) if cache_spec: self.__cache_spec() @classmethod @docval({'name': 'path', 'type': str, 'doc': 'the path to the destination HDF5 file'}, {'name': 'comm', 'type': 'Intracomm', 'doc': 'the MPI communicator to use for parallel I/O', 'default': None}, *_export_args) # NOTE: src_io is required and is the second positional argument def export_io(self, **kwargs): """Export from one backend to HDF5 (class method). Convenience function for :py:meth:`export` where you do not need to instantiate a new `HDF5IO` object for writing. An `HDF5IO` object is created with mode 'w' and the given arguments. Example usage: .. code-block:: python old_io = HDF5IO('old.h5', 'r') HDF5IO.export_io(path='new_copy.h5', src_io=old_io) See :py:meth:`export` for more details. """ path, comm = popargs('path', 'comm', kwargs) with HDF5IO(path=path, comm=comm, mode='w') as write_io: write_io.export(**kwargs) def read(self, **kwargs): if self.__mode == 'w' or self.__mode == 'w-' or self.__mode == 'x': raise UnsupportedOperation("Cannot read from file %s in mode '%s'. Please use mode 'r', 'r+', or 'a'." % (self.source, self.__mode)) try: return call_docval_func(super().read, kwargs) except UnsupportedOperation as e: if str(e) == 'Cannot build data. 
There are no values.': # pragma: no cover raise UnsupportedOperation("Cannot read data from file %s in mode '%s'. There are no values." % (self.source, self.__mode)) @docval(returns='a GroupBuilder representing the data object', rtype='GroupBuilder') def read_builder(self): if not self.__file: raise UnsupportedOperation("Cannot read data from closed HDF5 file '%s'" % self.source) f_builder = self.__read.get(self.__file) # ignore cached specs when reading builder ignore = set() specloc = self.__file.attrs.get(SPEC_LOC_ATTR) if specloc is not None: ignore.add(self.__file[specloc].name) if f_builder is None: f_builder = self.__read_group(self.__file, ROOT_NAME, ignore=ignore) self.__read[self.__file] = f_builder return f_builder def __set_written(self, builder): """ Mark this builder as written. :param builder: Builder object to be marked as written :type builder: Builder """ builder_id = self.__builderhash(builder) self._written_builders[builder_id] = builder def get_written(self, builder): """Return True if this builder has been written to (or read from) disk by this IO object, False otherwise. :param builder: Builder object to get the written flag for :type builder: Builder :return: True if the builder is found in self._written_builders using the builder ID, False otherwise """ builder_id = self.__builderhash(builder) return builder_id in self._written_builders def __builderhash(self, obj): """Return the ID of a builder for use as a unique hash.""" return id(obj) def __set_built(self, fpath, id, builder): """ Update self.__built to cache the given builder for the given file and id. :param fpath: Path to the HDF5 file containing the object :type fpath: str :param id: ID of the HDF5 object in the path :type id: h5py GroupID object :param builder: The builder to be cached """ self.__built.setdefault(fpath, dict()).setdefault(id, builder) def __get_built(self, fpath, id): """ Look up a builder for the given file and id in self.__built cache :param fpath: Path to the HDF5 file containing the object :type fpath: str :param id: ID of the HDF5 object in the path :type id: h5py GroupID object :return: Builder in the self.__built cache or None """ fdict = self.__built.get(fpath) if fdict: return fdict.get(id) else: return None @docval({'name': 'h5obj', 'type': (Dataset, Group), 'doc': 'the HDF5 object to the corresponding Builder object for'}) def get_builder(self, **kwargs): """ Get the builder for the corresponding h5py Group or Dataset :raises ValueError: When no builder has been constructed yet for the given h5py object """ h5obj = getargs('h5obj', kwargs) fpath = h5obj.file.filename builder = self.__get_built(fpath, h5obj.id) if builder is None: msg = '%s:%s has not been built' % (fpath, h5obj.name) raise ValueError(msg) return builder @docval({'name': 'h5obj', 'type': (Dataset, Group), 'doc': 'the HDF5 object to the corresponding Container/Data object for'}) def get_container(self, **kwargs): """ Get the container for the corresponding h5py Group or Dataset :raises ValueError: When no builder has been constructed yet for the given h5py object """ h5obj = getargs('h5obj', kwargs) builder = self.get_builder(h5obj) container = self.manager.construct(builder) return container def __read_group(self, h5obj, name=None, ignore=set()): kwargs = { "attributes": self.__read_attrs(h5obj), "groups": dict(), "datasets": dict(), "links": dict() } for key, val in kwargs['attributes'].items(): if isinstance(val, bytes): kwargs['attributes'][key] = val.decode('UTF-8') if name is None: name = 
str(os.path.basename(h5obj.name)) for k in h5obj: sub_h5obj = h5obj.get(k) if not (sub_h5obj is None): if sub_h5obj.name in ignore: continue link_type = h5obj.get(k, getlink=True) if isinstance(link_type, SoftLink) or isinstance(link_type, ExternalLink): # Reading links might be better suited in its own function # get path of link (the key used for tracking what's been built) target_path = link_type.path target_obj = sub_h5obj.file[target_path] builder_name = os.path.basename(target_path) parent_loc = os.path.dirname(target_path) # get builder if already read, else build it builder = self.__get_built(sub_h5obj.file.filename, target_obj.id) if builder is None: # NOTE: all links must have absolute paths if isinstance(target_obj, Dataset): builder = self.__read_dataset(target_obj, builder_name) else: builder = self.__read_group(target_obj, builder_name, ignore=ignore) self.__set_built(sub_h5obj.file.filename, target_obj.id, builder) builder.location = parent_loc link_builder = LinkBuilder(builder, k, source=h5obj.file.filename) self.__set_written(link_builder) kwargs['links'][builder_name] = link_builder if isinstance(link_type, ExternalLink): self.__open_links.append(sub_h5obj) else: builder = self.__get_built(sub_h5obj.file.filename, sub_h5obj.id) obj_type = None read_method = None if isinstance(sub_h5obj, Dataset): read_method = self.__read_dataset obj_type = kwargs['datasets'] else: read_method = partial(self.__read_group, ignore=ignore) obj_type = kwargs['groups'] if builder is None: builder = read_method(sub_h5obj) self.__set_built(sub_h5obj.file.filename, sub_h5obj.id, builder) obj_type[builder.name] = builder else: warnings.warn(os.path.join(h5obj.name, k), BrokenLinkWarning) kwargs['datasets'][k] = None continue kwargs['source'] = h5obj.file.filename ret = GroupBuilder(name, **kwargs) self.__set_written(ret) return ret def __read_dataset(self, h5obj, name=None): kwargs = { "attributes": self.__read_attrs(h5obj), "dtype": h5obj.dtype, "maxshape": h5obj.maxshape } for key, val in kwargs['attributes'].items(): if isinstance(val, bytes): kwargs['attributes'][key] = val.decode('UTF-8') if name is None: name = str(os.path.basename(h5obj.name)) kwargs['source'] = h5obj.file.filename ndims = len(h5obj.shape) if ndims == 0: # read scalar scalar = h5obj[()] if isinstance(scalar, bytes): scalar = scalar.decode('UTF-8') if isinstance(scalar, Reference): # TODO (AJTRITT): This should call __read_ref to support Group references target = h5obj.file[scalar] target_builder = self.__read_dataset(target) self.__set_built(target.file.filename, target.id, target_builder) if isinstance(scalar, RegionReference): d = RegionBuilder(scalar, target_builder) else: d = ReferenceBuilder(target_builder) kwargs['data'] = d kwargs['dtype'] = d.dtype else: kwargs["data"] = scalar else: d = None if h5obj.dtype.kind == 'O' and len(h5obj) > 0: elem1 = h5obj[tuple([0] * (h5obj.ndim - 1) + [0])] if isinstance(elem1, (str, bytes)): d = self._check_str_dtype(h5obj) elif isinstance(elem1, RegionReference): # read list of references d = BuilderH5RegionDataset(h5obj, self) kwargs['dtype'] = d.dtype elif isinstance(elem1, Reference): d = BuilderH5ReferenceDataset(h5obj, self) kwargs['dtype'] = d.dtype elif h5obj.dtype.kind == 'V': # table / compound data type cpd_dt = h5obj.dtype ref_cols = [check_dtype(ref=cpd_dt[i]) or check_dtype(vlen=cpd_dt[i]) for i in range(len(cpd_dt))] d = BuilderH5TableDataset(h5obj, self, ref_cols) kwargs['dtype'] = HDF5IO.__compound_dtype_to_list(h5obj.dtype, d.dtype) else: d = h5obj kwargs["data"] = 
d ret = DatasetBuilder(name, **kwargs) self.__set_written(ret) return ret def _check_str_dtype(self, h5obj): dtype = h5obj.dtype if dtype.kind == 'O': if dtype.metadata.get('vlen') == str and H5PY_3: return StrDataset(h5obj, None) return h5obj @classmethod def __compound_dtype_to_list(cls, h5obj_dtype, dset_dtype): ret = [] for name, dtype in zip(h5obj_dtype.fields, dset_dtype): ret.append({'name': name, 'dtype': dtype}) return ret def __read_attrs(self, h5obj): ret = dict() for k, v in h5obj.attrs.items(): if k == SPEC_LOC_ATTR: # ignore cached spec continue if isinstance(v, RegionReference): raise ValueError("cannot read region reference attributes yet") elif isinstance(v, Reference): ret[k] = self.__read_ref(h5obj.file[v]) else: ret[k] = v return ret def __read_ref(self, h5obj): ret = None ret = self.__get_built(h5obj.file.filename, h5obj.id) if ret is None: if isinstance(h5obj, Dataset): ret = self.__read_dataset(h5obj) elif isinstance(h5obj, Group): ret = self.__read_group(h5obj) else: raise ValueError("h5obj must be a Dataset or a Group - got %s" % str(h5obj)) self.__set_built(h5obj.file.filename, h5obj.id, ret) return ret def open(self): if self.__file is None: open_flag = self.__mode kwargs = dict() if self.comm: kwargs.update(driver='mpio', comm=self.comm) if self.driver is not None: kwargs.update(driver=self.driver) self.__file = File(self.source, open_flag, **kwargs) def close(self): if self.__file is not None: self.__file.close() def close_linked_files(self): """Close all opened, linked-to files. MacOS and Linux automatically releases the linked-to file after the linking file is closed, but Windows does not, which prevents the linked-to file from being deleted or truncated. Use this method to close all opened, linked-to files. """ for obj in self.__open_links: if obj: obj.file.close() self.__open_links = [] @docval({'name': 'builder', 'type': GroupBuilder, 'doc': 'the GroupBuilder object representing the HDF5 file'}, {'name': 'link_data', 'type': bool, 'doc': 'If not specified otherwise link (True) or copy (False) HDF5 Datasets', 'default': True}, {'name': 'exhaust_dci', 'type': bool, 'doc': 'exhaust DataChunkIterators one at a time. If False, exhaust them concurrently', 'default': True}, {'name': 'export_source', 'type': str, 'doc': 'The source of the builders when exporting', 'default': None}) def write_builder(self, **kwargs): f_builder = popargs('builder', kwargs) link_data, exhaust_dci, export_source = getargs('link_data', 'exhaust_dci', 'export_source', kwargs) self.logger.debug("Writing GroupBuilder '%s' to path '%s' with kwargs=%s" % (f_builder.name, self.source, kwargs)) for name, gbldr in f_builder.groups.items(): self.write_group(self.__file, gbldr, **kwargs) for name, dbldr in f_builder.datasets.items(): self.write_dataset(self.__file, dbldr, **kwargs) for name, lbldr in f_builder.links.items(): self.write_link(self.__file, lbldr) self.set_attributes(self.__file, f_builder.attributes) self.__add_refs() self.__exhaust_dcis() self.__set_written(f_builder) self.logger.debug("Done writing GroupBuilder '%s' to path '%s'" % (f_builder.name, self.source)) def __add_refs(self): ''' Add all references in the file. References get queued to be added at the end of write. This is because the current traversal algorithm (i.e. iterating over GroupBuilder items) does not happen in a guaranteed order. We need to figure out what objects will be references, and then write them after we write everything else. 
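Each queued item is a zero-argument callable. If a callable raises a KeyError because its
reference target has not been written yet, it is re-queued; a second failure of the same
callable raises a RuntimeError.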
''' failed = set() while len(self.__ref_queue) > 0: call = self.__ref_queue.popleft() self.logger.debug("Adding reference with call id %d from queue (length %d)" % (id(call), len(self.__ref_queue))) try: call() except KeyError: if id(call) in failed: raise RuntimeError('Unable to resolve reference') self.logger.debug("Adding reference with call id %d failed. Appending call to queue" % id(call)) failed.add(id(call)) self.__ref_queue.append(call) def __exhaust_dcis(self): """ Read and write from any queued DataChunkIterators in a round-robin fashion """ while len(self.__dci_queue) > 0: self.logger.debug("Exhausting DataChunkIterator from queue (length %d)" % len(self.__dci_queue)) dset, data = self.__dci_queue.popleft() if self.__write_chunk__(dset, data): self.__dci_queue.append((dset, data)) @classmethod def get_type(cls, data): if isinstance(data, str): return H5_TEXT elif isinstance(data, Container): return H5_REF elif not hasattr(data, '__len__'): return type(data) else: if len(data) == 0: if hasattr(data, 'dtype'): return data.dtype else: raise ValueError('cannot determine type for empty data') return cls.get_type(data[0]) __dtypes = { "float": np.float32, "float32": np.float32, "double": np.float64, "float64": np.float64, "long": np.int64, "int64": np.int64, "int": np.int32, "int32": np.int32, "short": np.int16, "int16": np.int16, "int8": np.int8, "uint64": np.uint64, "uint": np.uint32, "uint32": np.uint32, "uint16": np.uint16, "uint8": np.uint8, "bool": np.bool_, "text": H5_TEXT, "utf": H5_TEXT, "utf8": H5_TEXT, "utf-8": H5_TEXT, "ascii": H5_BINARY, "bytes": H5_BINARY, "ref": H5_REF, "reference": H5_REF, "object": H5_REF, "region": H5_REGREF, "isodatetime": H5_TEXT, "datetime": H5_TEXT, } @classmethod def __resolve_dtype__(cls, dtype, data): # TODO: These values exist, but I haven't solved them yet # binary # number dtype = cls.__resolve_dtype_helper__(dtype) if dtype is None: dtype = cls.get_type(data) return dtype @classmethod def __resolve_dtype_helper__(cls, dtype): if dtype is None: return None elif isinstance(dtype, str): return cls.__dtypes.get(dtype) elif isinstance(dtype, dict): return cls.__dtypes.get(dtype['reftype']) elif isinstance(dtype, np.dtype): # NOTE: some dtypes may not be supported, but we need to support writing of read-in compound types return dtype else: return np.dtype([(x['name'], cls.__resolve_dtype_helper__(x['dtype'])) for x in dtype]) @docval({'name': 'obj', 'type': (Group, Dataset), 'doc': 'the HDF5 object to add attributes to'}, {'name': 'attributes', 'type': dict, 'doc': 'a dict containing the attributes on the Group or Dataset, indexed by attribute name'}) def set_attributes(self, **kwargs): obj, attributes = getargs('obj', 'attributes', kwargs) for key, value in attributes.items(): try: if isinstance(value, (set, list, tuple)): tmp = tuple(value) if len(tmp) > 0: if isinstance(tmp[0], (str, bytes)): value = np.array(value, dtype=special_dtype(vlen=type(tmp[0]))) elif isinstance(tmp[0], Container): # a list of references self.__queue_ref(self._make_attr_ref_filler(obj, key, tmp)) else: value = np.array(value) self.logger.debug("Setting %s '%s' attribute '%s' to %s" % (obj.__class__.__name__, obj.name, key, value.__class__.__name__)) obj.attrs[key] = value elif isinstance(value, (Container, Builder, ReferenceBuilder)): # a reference self.__queue_ref(self._make_attr_ref_filler(obj, key, value)) else: self.logger.debug("Setting %s '%s' attribute '%s' to %s" % (obj.__class__.__name__, obj.name, key, value.__class__.__name__)) if isinstance(value, 
np.ndarray) and value.dtype.kind == 'U': value = np.array(value, dtype=H5_TEXT) obj.attrs[key] = value # a regular scalar except Exception as e: msg = "unable to write attribute '%s' on object '%s'" % (key, obj.name) raise RuntimeError(msg) from e def _make_attr_ref_filler(self, obj, key, value): ''' Make the callable for setting references to attributes ''' self.logger.debug("Queueing set %s '%s' attribute '%s' to %s" % (obj.__class__.__name__, obj.name, key, value.__class__.__name__)) if isinstance(value, (tuple, list)): def _filler(): ret = list() for item in value: ret.append(self.__get_ref(item)) obj.attrs[key] = ret else: def _filler(): obj.attrs[key] = self.__get_ref(value) return _filler @docval({'name': 'parent', 'type': Group, 'doc': 'the parent HDF5 object'}, {'name': 'builder', 'type': GroupBuilder, 'doc': 'the GroupBuilder to write'}, {'name': 'link_data', 'type': bool, 'doc': 'If not specified otherwise link (True) or copy (False) HDF5 Datasets', 'default': True}, {'name': 'exhaust_dci', 'type': bool, 'doc': 'exhaust DataChunkIterators one at a time. If False, exhaust them concurrently', 'default': True}, {'name': 'export_source', 'type': str, 'doc': 'The source of the builders when exporting', 'default': None}, returns='the Group that was created', rtype='Group') def write_group(self, **kwargs): parent, builder = popargs('parent', 'builder', kwargs) self.logger.debug("Writing GroupBuilder '%s' to parent group '%s'" % (builder.name, parent.name)) if self.get_written(builder): self.logger.debug(" GroupBuilder '%s' is already written" % builder.name) group = parent[builder.name] else: self.logger.debug(" Creating group '%s'" % builder.name) group = parent.create_group(builder.name) # write all groups subgroups = builder.groups if subgroups: for subgroup_name, sub_builder in subgroups.items(): # do not create an empty group without attributes or links self.write_group(group, sub_builder, **kwargs) # write all datasets datasets = builder.datasets if datasets: for dset_name, sub_builder in datasets.items(): self.write_dataset(group, sub_builder, **kwargs) # write all links links = builder.links if links: for link_name, sub_builder in links.items(): self.write_link(group, sub_builder) attributes = builder.attributes self.set_attributes(group, attributes) self.__set_written(builder) return group def __get_path(self, builder): """Get the path to the builder. Note that the root of the file has no name - it is just "/". Thus, the name of the root container is ignored. 
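For example (with hypothetical builder names), a builder nested as ``root/foo/bar`` yields
the path ``/foo/bar``, and the root builder itself yields ``/``.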
""" curr = builder names = list() while curr.parent is not None: names.append(curr.name) curr = curr.parent delim = "/" path = "%s%s" % (delim, delim.join(reversed(names))) return path @docval({'name': 'parent', 'type': Group, 'doc': 'the parent HDF5 object'}, {'name': 'builder', 'type': LinkBuilder, 'doc': 'the LinkBuilder to write'}, returns='the Link that was created', rtype='Link') def write_link(self, **kwargs): parent, builder = getargs('parent', 'builder', kwargs) self.logger.debug("Writing LinkBuilder '%s' to parent group '%s'" % (builder.name, parent.name)) if self.get_written(builder): self.logger.debug(" LinkBuilder '%s' is already written" % builder.name) return None name = builder.name target_builder = builder.builder path = self.__get_path(target_builder) # source will indicate target_builder's location if builder.source == target_builder.source: link_obj = SoftLink(path) self.logger.debug(" Creating SoftLink '%s/%s' to '%s'" % (parent.name, name, link_obj.path)) elif target_builder.source is not None: target_filename = os.path.abspath(target_builder.source) parent_filename = os.path.abspath(parent.file.filename) relative_path = os.path.relpath(target_filename, os.path.dirname(parent_filename)) if target_builder.location is not None: path = target_builder.location + "/" + target_builder.name link_obj = ExternalLink(relative_path, path) self.logger.debug(" Creating ExternalLink '%s/%s' to '%s://%s'" % (parent.name, name, link_obj.filename, link_obj.path)) else: msg = 'cannot create external link to %s' % path raise ValueError(msg) parent[name] = link_obj self.__set_written(builder) return link_obj @docval({'name': 'parent', 'type': Group, 'doc': 'the parent HDF5 object'}, # noqa: C901 {'name': 'builder', 'type': DatasetBuilder, 'doc': 'the DatasetBuilder to write'}, {'name': 'link_data', 'type': bool, 'doc': 'If not specified otherwise link (True) or copy (False) HDF5 Datasets', 'default': True}, {'name': 'exhaust_dci', 'type': bool, 'doc': 'exhaust DataChunkIterators one at a time. If False, exhaust them concurrently', 'default': True}, {'name': 'export_source', 'type': str, 'doc': 'The source of the builders when exporting', 'default': None}, returns='the Dataset that was created', rtype=Dataset) def write_dataset(self, **kwargs): # noqa: C901 """ Write a dataset to HDF5 The function uses other dataset-dependent write functions, e.g, `__scalar_fill__`, `__list_fill__`, and `__setup_chunked_dset__` to write the data. 
""" parent, builder = popargs('parent', 'builder', kwargs) link_data, exhaust_dci, export_source = getargs('link_data', 'exhaust_dci', 'export_source', kwargs) self.logger.debug("Writing DatasetBuilder '%s' to parent group '%s'" % (builder.name, parent.name)) if self.get_written(builder): self.logger.debug(" DatasetBuilder '%s' is already written" % builder.name) return None name = builder.name data = builder.data options = dict() # dict with additional if isinstance(data, H5DataIO): options['io_settings'] = data.io_settings link_data = data.link_data data = data.data else: options['io_settings'] = {} attributes = builder.attributes options['dtype'] = builder.dtype dset = None link = None # The user provided an existing h5py dataset as input and asked to create a link to the dataset if isinstance(data, Dataset): data_filename = os.path.abspath(data.file.filename) if link_data: if export_source is None: # not exporting parent_filename = os.path.abspath(parent.file.filename) if data_filename != parent_filename: # create external link to data relative_path = os.path.relpath(data_filename, os.path.dirname(parent_filename)) link = ExternalLink(relative_path, data.name) self.logger.debug(" Creating ExternalLink '%s/%s' to '%s://%s'" % (parent.name, name, link.filename, link.path)) else: # create soft link to dataset already in this file -- possible if mode == 'r+' link = SoftLink(data.name) self.logger.debug(" Creating SoftLink '%s/%s' to '%s'" % (parent.name, name, link.path)) parent[name] = link else: # exporting export_source = os.path.abspath(export_source) parent_filename = os.path.abspath(parent.file.filename) if data_filename != export_source: # dataset is in different file than export source # possible if user adds a link to a dataset in a different file after reading export source # to memory relative_path = os.path.relpath(data_filename, os.path.dirname(parent_filename)) link = ExternalLink(relative_path, data.name) self.logger.debug(" Creating ExternalLink '%s/%s' to '%s://%s'" % (parent.name, name, link.filename, link.path)) parent[name] = link elif parent.name != data.parent.name: # dataset is in export source and has different path # so create a soft link to the dataset in this file # possible if user adds a link to a dataset in export source after reading to memory link = SoftLink(data.name) self.logger.debug(" Creating SoftLink '%s/%s' to '%s'" % (parent.name, name, link.path)) parent[name] = link else: # dataset is in export source and has same path as the builder, so copy the dataset self.logger.debug(" Copying data from '%s://%s' to '%s/%s'" % (data.file.filename, data.name, parent.name, name)) parent.copy(source=data, dest=parent, name=name, expand_soft=False, expand_external=False, expand_refs=False, without_attrs=True) dset = parent[name] else: # TODO add option for case where there are multiple links to the same dataset within a file: # instead of copying the dset N times, copy it once and create soft links to it within the file self.logger.debug(" Copying data from '%s://%s' to '%s/%s'" % (data.file.filename, data.name, parent.name, name)) parent.copy(source=data, dest=parent, name=name, expand_soft=False, expand_external=False, expand_refs=False, without_attrs=True) dset = parent[name] # Write a compound dataset, i.e, a dataset with compound data type elif isinstance(options['dtype'], list): # do some stuff to figure out what data is a reference refs = list() for i, dts in enumerate(options['dtype']): if self.__is_ref(dts): refs.append(i) # If one ore more of the parts 
of the compound data type are references then we need to deal with those if len(refs) > 0: try: _dtype = self.__resolve_dtype__(options['dtype'], data) except Exception as exc: msg = 'cannot add %s to %s - could not determine type' % (name, parent.name) raise Exception(msg) from exc dset = parent.require_dataset(name, shape=(len(data),), dtype=_dtype, **options['io_settings']) self.__set_written(builder) self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing " "object references. attributes: %s" % (name, list(attributes.keys()))) @self.__queue_ref def _filler(): self.logger.debug("Resolving object references and setting attribute on dataset '%s' " "containing attributes: %s" % (name, list(attributes.keys()))) ret = list() for item in data: new_item = list(item) for i in refs: new_item[i] = self.__get_ref(item[i]) ret.append(tuple(new_item)) dset = parent[name] dset[:] = ret self.set_attributes(dset, attributes) return # If the compound data type contains only regular data (i.e., no references) then we can write it as usual else: dset = self.__list_fill__(parent, name, data, options) # Write a dataset containing references, i.e., a region or object reference. # NOTE: we can ignore options['io_settings'] for scalar data elif self.__is_ref(options['dtype']): _dtype = self.__dtypes.get(options['dtype']) # Write a scalar data region reference dataset if isinstance(data, RegionBuilder): dset = parent.require_dataset(name, shape=(), dtype=_dtype) self.__set_written(builder) self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing a " "region reference. attributes: %s" % (name, list(attributes.keys()))) @self.__queue_ref def _filler(): self.logger.debug("Resolving region reference and setting attribute on dataset '%s' " "containing attributes: %s" % (name, list(attributes.keys()))) ref = self.__get_ref(data.builder, data.region) dset = parent[name] dset[()] = ref self.set_attributes(dset, attributes) # Write a scalar object reference dataset elif isinstance(data, ReferenceBuilder): dset = parent.require_dataset(name, dtype=_dtype, shape=()) self.__set_written(builder) self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing an " "object reference. attributes: %s" % (name, list(attributes.keys()))) @self.__queue_ref def _filler(): self.logger.debug("Resolving object reference and setting attribute on dataset '%s' " "containing attributes: %s" % (name, list(attributes.keys()))) ref = self.__get_ref(data.builder) dset = parent[name] dset[()] = ref self.set_attributes(dset, attributes) # Write an array dataset of references else: # Write a array of region references if options['dtype'] == 'region': dset = parent.require_dataset(name, dtype=_dtype, shape=(len(data),), **options['io_settings']) self.__set_written(builder) self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing " "region references. 
attributes: %s" % (name, list(attributes.keys()))) @self.__queue_ref def _filler(): self.logger.debug("Resolving region references and setting attribute on dataset '%s' " "containing attributes: %s" % (name, list(attributes.keys()))) refs = list() for item in data: refs.append(self.__get_ref(item.builder, item.region)) dset = parent[name] dset[()] = refs self.set_attributes(dset, attributes) # Write array of object references else: dset = parent.require_dataset(name, shape=(len(data),), dtype=_dtype, **options['io_settings']) self.__set_written(builder) self.logger.debug("Queueing reference resolution and set attribute on dataset '%s' containing " "object references. attributes: %s" % (name, list(attributes.keys()))) @self.__queue_ref def _filler(): self.logger.debug("Resolving object references and setting attribute on dataset '%s' " "containing attributes: %s" % (name, list(attributes.keys()))) refs = list() for item in data: refs.append(self.__get_ref(item)) dset = parent[name] dset[()] = refs self.set_attributes(dset, attributes) return # write a "regular" dataset else: # Write a scalar dataset containing a single string if isinstance(data, (str, bytes)): dset = self.__scalar_fill__(parent, name, data, options) # Iterative write of a data chunk iterator elif isinstance(data, AbstractDataChunkIterator): dset = self.__setup_chunked_dset__(parent, name, data, options) self.__dci_queue.append((dset, data)) # Write a regular in memory array (e.g., numpy array, list etc.) elif hasattr(data, '__len__'): dset = self.__list_fill__(parent, name, data, options) # Write a regular scalar dataset else: dset = self.__scalar_fill__(parent, name, data, options) # Create the attributes on the dataset only if we are the primary and not just a Soft/External link if link is None: self.set_attributes(dset, attributes) # Validate the attributes on the linked dataset elif len(attributes) > 0: pass self.__set_written(builder) if exhaust_dci: self.__exhaust_dcis() @classmethod def __scalar_fill__(cls, parent, name, data, options=None): dtype = None io_settings = {} if options is not None: dtype = options.get('dtype') io_settings = options.get('io_settings') if not isinstance(dtype, type): try: dtype = cls.__resolve_dtype__(dtype, data) except Exception as exc: msg = 'cannot add %s to %s - could not determine type' % (name, parent.name) raise Exception(msg) from exc try: dset = parent.create_dataset(name, data=data, shape=None, dtype=dtype, **io_settings) except Exception as exc: msg = "Could not create scalar dataset %s in %s" % (name, parent.name) raise Exception(msg) from exc return dset @classmethod def __setup_chunked_dset__(cls, parent, name, data, options=None): """ Setup a dataset for writing to one-chunk-at-a-time based on the given DataChunkIterator :param parent: The parent object to which the dataset should be added :type parent: h5py.Group, h5py.File :param name: The name of the dataset :type name: str :param data: The data to be written. :type data: DataChunkIterator :param options: Dict with options for creating a dataset. available options are 'dtype' and 'io_settings' :type options: dict """ io_settings = {} if options is not None: if 'io_settings' in options: io_settings = options.get('io_settings') # Define the chunking options if the user has not set them explicitly. We need chunking for the iterative write. 
if 'chunks' not in io_settings: recommended_chunks = data.recommended_chunk_shape() io_settings['chunks'] = True if recommended_chunks is None else recommended_chunks # Define the shape of the data if not provided by the user if 'shape' not in io_settings: io_settings['shape'] = data.recommended_data_shape() # Define the maxshape of the data if not provided by the user if 'maxshape' not in io_settings: io_settings['maxshape'] = data.maxshape if 'dtype' not in io_settings: if (options is not None) and ('dtype' in options): io_settings['dtype'] = options['dtype'] else: io_settings['dtype'] = data.dtype if isinstance(io_settings['dtype'], str): # map to real dtype if we were given a string io_settings['dtype'] = cls.__dtypes.get(io_settings['dtype']) try: dset = parent.create_dataset(name, **io_settings) except Exception as exc: raise Exception("Could not create dataset %s in %s" % (name, parent.name)) from exc return dset @classmethod def __write_chunk__(cls, dset, data): """ Read a chunk from the given DataChunkIterator and write it to the given Dataset :param dset: The Dataset to write to :type dset: Dataset :param data: The DataChunkIterator to read from :type data: DataChunkIterator :return: True if a chunk was written, False otherwise :rtype: bool """ try: chunk_i = next(data) except StopIteration: return False if isinstance(chunk_i.selection, tuple): # Determine the minimum array dimensions to fit the chunk selection max_bounds = tuple([x.stop or 0 if isinstance(x, slice) else x+1 for x in chunk_i.selection]) elif isinstance(chunk_i.selection, int): max_bounds = (chunk_i.selection+1, ) elif isinstance(chunk_i.selection, slice): max_bounds = (chunk_i.selection.stop or 0, ) else: msg = ("Chunk selection %s must be a single int, single slice, or tuple of slices " "and/or integers") % str(chunk_i.selection) raise TypeError(msg) # Expand the dataset if needed dset.id.extend(max_bounds) # Write the data dset[chunk_i.selection] = chunk_i.data return True @classmethod def __chunked_iter_fill__(cls, parent, name, data, options=None): """ Write data to a dataset one-chunk-at-a-time based on the given DataChunkIterator :param parent: The parent object to which the dataset should be added :type parent: h5py.Group, h5py.File :param name: The name of the dataset :type name: str :param data: The data to be written. :type data: DataChunkIterator :param options: Dict with options for creating a dataset. 
available options are 'dtype' and 'io_settings' :type options: dict """ dset = cls.__setup_chunked_dset__(parent, name, data, options=options) read = True while read: read = cls.__write_chunk__(dset, data) return dset @classmethod def __list_fill__(cls, parent, name, data, options=None): # define the io settings and data type if necessary io_settings = {} dtype = None if options is not None: dtype = options.get('dtype') io_settings = options.get('io_settings') if not isinstance(dtype, type): try: dtype = cls.__resolve_dtype__(dtype, data) except Exception as exc: msg = 'cannot add %s to %s - could not determine type' % (name, parent.name) raise Exception(msg) from exc # define the data shape if 'shape' in io_settings: data_shape = io_settings.pop('shape') elif hasattr(data, 'shape'): data_shape = data.shape elif isinstance(dtype, np.dtype): data_shape = (len(data),) else: data_shape = get_data_shape(data) # Create the dataset try: dset = parent.create_dataset(name, shape=data_shape, dtype=dtype, **io_settings) except Exception as exc: msg = "Could not create dataset %s in %s with shape %s, dtype %s, and iosettings %s. %s" % \ (name, parent.name, str(data_shape), str(dtype), str(io_settings), str(exc)) raise Exception(msg) from exc # Write the data if len(data) > dset.shape[0]: new_shape = list(dset.shape) new_shape[0] = len(data) dset.resize(new_shape) try: dset[:] = data except Exception as e: raise e return dset @docval({'name': 'container', 'type': (Builder, Container, ReferenceBuilder), 'doc': 'the object to reference', 'default': None}, {'name': 'region', 'type': (slice, list, tuple), 'doc': 'the region reference indexing object', 'default': None}, returns='the reference', rtype=Reference) def __get_ref(self, **kwargs): container, region = getargs('container', 'region', kwargs) if container is None: return None if isinstance(container, Builder): self.logger.debug("Getting reference for %s '%s'" % (container.__class__.__name__, container.name)) if isinstance(container, LinkBuilder): builder = container.target_builder else: builder = container elif isinstance(container, ReferenceBuilder): self.logger.debug("Getting reference for %s '%s'" % (container.__class__.__name__, container.builder.name)) builder = container.builder else: self.logger.debug("Getting reference for %s '%s'" % (container.__class__.__name__, container.name)) builder = self.manager.build(container) path = self.__get_path(builder) self.logger.debug("Getting reference at path '%s'" % path) if isinstance(container, RegionBuilder): region = container.region if region is not None: dset = self.__file[path] if not isinstance(dset, Dataset): raise ValueError('cannot create region reference without Dataset') return self.__file[path].regionref[region] else: return self.__file[path].ref def __is_ref(self, dtype): if isinstance(dtype, DtypeSpec): return self.__is_ref(dtype.dtype) if isinstance(dtype, RefSpec): return True if isinstance(dtype, dict): # may be dict from reading a compound dataset return self.__is_ref(dtype['dtype']) if isinstance(dtype, str): return dtype == DatasetBuilder.OBJECT_REF_TYPE or dtype == DatasetBuilder.REGION_REF_TYPE return False def __queue_ref(self, func): '''Set aside filling dset with references dest[sl] = func() Args: dset: the h5py.Dataset that the references need to be added to sl: the np.s_ (slice) object for indexing into dset func: a function to call to return the chunk of data, with references filled in ''' # TODO: come up with more intelligent way of # queueing reference resolution, based on 
reference # dependency self.__ref_queue.append(func) def __rec_get_ref(self, ref_list): ret = list() for elem in ref_list: if isinstance(elem, (list, tuple)): ret.append(self.__rec_get_ref(elem)) elif isinstance(elem, (Builder, Container)): ret.append(self.__get_ref(elem)) else: ret.append(elem) return ret @property def mode(self): """ Return the HDF5 file mode. One of ("w", "r", "r+", "a", "w-", "x"). """ return self.__mode @classmethod @docval(*get_docval(H5DataIO.__init__)) def set_dataio(cls, **kwargs): """ Wrap the given Data object with an H5DataIO. This method is provided merely for convenience. It is the equivalent of the following: ``` from hdmf.backends.hdf5 import H5DataIO data = ... data = H5DataIO(data) ``` """ cargs, ckwargs = fmt_docval_args(H5DataIO.__init__, kwargs) return H5DataIO(*cargs, **ckwargs) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/backends/io.py0000644000655200065520000001257500000000000017033 0ustar00circlecicirclecifrom abc import ABCMeta, abstractmethod from pathlib import Path from ..build import BuildManager, GroupBuilder from ..container import Container from ..utils import docval, getargs, popargs class HDMFIO(metaclass=ABCMeta): @docval({'name': 'manager', 'type': BuildManager, 'doc': 'the BuildManager to use for I/O', 'default': None}, {"name": "source", "type": (str, Path), "doc": "the source of container being built i.e. file path", 'default': None}) def __init__(self, **kwargs): manager, source = getargs('manager', 'source', kwargs) if isinstance(source, Path): source = str(source) self.__manager = manager self.__built = dict() self.__source = source self.open() @property def manager(self): '''The BuildManager this instance is using''' return self.__manager @property def source(self): '''The source of the container being read/written i.e. file path''' return self.__source @docval(returns='the Container object that was read in', rtype=Container) def read(self, **kwargs): """Read a container from the IO source.""" f_builder = self.read_builder() if all(len(v) == 0 for v in f_builder.values()): # TODO also check that the keys are appropriate. print a better error message raise UnsupportedOperation('Cannot build data. There are no values.') container = self.__manager.construct(f_builder) return container @docval({'name': 'container', 'type': Container, 'doc': 'the Container object to write'}, allow_extra=True) def write(self, **kwargs): """Write a container to the IO source.""" container = popargs('container', kwargs) f_builder = self.__manager.build(container, source=self.__source, root=True) self.write_builder(f_builder, **kwargs) @docval({'name': 'src_io', 'type': 'HDMFIO', 'doc': 'the HDMFIO object for reading the data to export'}, {'name': 'container', 'type': Container, 'doc': ('the Container object to export. If None, then the entire contents of the HDMFIO object will be ' 'exported'), 'default': None}, {'name': 'write_args', 'type': dict, 'doc': 'arguments to pass to :py:meth:`write_builder`', 'default': dict()}) def export(self, **kwargs): """Export from one backend to the backend represented by this class. If `container` is provided, then the build manager of `src_io` is used to build the container, and the resulting builder will be exported to the new backend. So if `container` is provided, `src_io` must have a non-None manager property. If `container` is None, then the contents of `src_io` will be read and exported to the new backend. 
The provided container must be the root of the hierarchy of the source used to read the container (i.e., you cannot read a file and export a part of that file. Arguments can be passed in for the `write_builder` method using `write_args`. Some arguments may not be supported during export. Example usage: .. code-block:: python old_io = HDF5IO('old.nwb', 'r') with HDF5IO('new_copy.nwb', 'w') as new_io: new_io.export(old_io) """ src_io, container, write_args = getargs('src_io', 'container', 'write_args', kwargs) if container is not None: # check that manager exists, container was built from manager, and container is root of hierarchy if src_io.manager is None: raise ValueError('When a container is provided, src_io must have a non-None manager (BuildManager) ' 'property.') old_bldr = src_io.manager.get_builder(container) if old_bldr is None: raise ValueError('The provided container must have been read by the provided src_io.') if old_bldr.parent is not None: raise ValueError('The provided container must be the root of the hierarchy of the ' 'source used to read the container.') # build any modified containers src_io.manager.purge_outdated() bldr = src_io.manager.build(container, source=self.__source, root=True, export=True) else: bldr = src_io.read_builder() self.write_builder(builder=bldr, **write_args) @abstractmethod @docval(returns='a GroupBuilder representing the read data', rtype='GroupBuilder') def read_builder(self): ''' Read data and return the GroupBuilder representing it ''' pass @abstractmethod @docval({'name': 'builder', 'type': GroupBuilder, 'doc': 'the GroupBuilder object representing the Container'}, allow_extra=True) def write_builder(self, **kwargs): ''' Write a GroupBuilder representing an Container object ''' pass @abstractmethod def open(self): ''' Open this HDMFIO object for writing of the builder ''' pass @abstractmethod def close(self): ''' Close this HDMFIO object to further reading/writing''' pass def __enter__(self): return self def __exit__(self, type, value, traceback): self.close() class UnsupportedOperation(ValueError): pass ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/backends/warnings.py0000644000655200065520000000016400000000000020243 0ustar00circlecicircleciclass BrokenLinkWarning(UserWarning): """ Raised when a group has a key with a None value. 
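This warning is emitted via :py:func:`warnings.warn` by the HDF5 reader when a link target
cannot be resolved while reading a group; the corresponding dataset entry in the resulting
GroupBuilder is recorded as ``None``.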
""" pass ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1766272 hdmf-3.1.1/src/hdmf/build/0000755000655200065520000000000000000000000015365 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/build/__init__.py0000644000655200065520000000113400000000000017475 0ustar00circlecicirclecifrom .builders import Builder, DatasetBuilder, GroupBuilder, LinkBuilder, ReferenceBuilder, RegionBuilder from .classgenerator import CustomClassGenerator, MCIClassGenerator from .errors import (BuildError, OrphanContainerBuildError, ReferenceTargetNotBuiltError, ContainerConfigurationError, ConstructError) from .manager import BuildManager, TypeMap from .objectmapper import ObjectMapper from .warnings import (BuildWarning, MissingRequiredBuildWarning, DtypeConversionWarning, IncorrectQuantityBuildWarning, MissingRequiredWarning, OrphanContainerWarning) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/build/builders.py0000644000655200065520000004105100000000000017551 0ustar00circlecicircleciimport copy as _copy import itertools as _itertools import posixpath as _posixpath from abc import ABCMeta from collections.abc import Iterable from datetime import datetime import numpy as np from h5py import RegionReference from ..utils import docval, getargs, get_docval class Builder(dict, metaclass=ABCMeta): @docval({'name': 'name', 'type': str, 'doc': 'the name of the group'}, {'name': 'parent', 'type': 'Builder', 'doc': 'the parent builder of this Builder', 'default': None}, {'name': 'source', 'type': str, 'doc': 'the source of the data in this builder e.g. file name', 'default': None}) def __init__(self, **kwargs): name, parent, source = getargs('name', 'parent', 'source', kwargs) super().__init__() self.__name = name self.__parent = parent if source is not None: self.__source = source elif parent is not None: self.__source = parent.source else: self.__source = None @property def path(self): """The path of this builder.""" s = list() c = self while c is not None: s.append(c.name) c = c.parent return "/".join(s[::-1]) @property def name(self): """The name of this builder.""" return self.__name @property def source(self): """The source of this builder.""" return self.__source @source.setter def source(self, s): if self.__source is not None: raise AttributeError('Cannot overwrite source.') self.__source = s @property def parent(self): """The parent builder of this builder.""" return self.__parent @parent.setter def parent(self, p): if self.__parent is not None: raise AttributeError('Cannot overwrite parent.') self.__parent = p if self.__source is None: self.source = p.source def __repr__(self): ret = "%s %s %s" % (self.path, self.__class__.__name__, super().__repr__()) return ret class BaseBuilder(Builder, metaclass=ABCMeta): __attribute = 'attributes' # self dictionary key for attributes @docval({'name': 'name', 'type': str, 'doc': 'The name of the builder.'}, {'name': 'attributes', 'type': dict, 'doc': 'A dictionary of attributes to create in this builder.', 'default': dict()}, {'name': 'parent', 'type': 'GroupBuilder', 'doc': 'The parent builder of this builder.', 'default': None}, {'name': 'source', 'type': str, 'doc': 'The source of the data represented in this builder', 'default': None}) def __init__(self, **kwargs): name, attributes, parent, source = getargs('name', 'attributes', 'parent', 'source', kwargs) 
super().__init__(name, parent, source) super().__setitem__(BaseBuilder.__attribute, dict()) for name, val in attributes.items(): self.set_attribute(name, val) self.__location = None @property def location(self): """The location of this Builder in its source.""" return self.__location @location.setter def location(self, val): self.__location = val @property def attributes(self): """The attributes stored in this Builder object.""" return super().__getitem__(BaseBuilder.__attribute) @docval({'name': 'name', 'type': str, 'doc': 'The name of the attribute.'}, {'name': 'value', 'type': None, 'doc': 'The attribute value.'}) def set_attribute(self, **kwargs): """Set an attribute for this group.""" name, value = getargs('name', 'value', kwargs) self.attributes[name] = value class GroupBuilder(BaseBuilder): # sub-dictionary keys. subgroups go in super().__getitem__(GroupBuilder.__group) __group = 'groups' __dataset = 'datasets' __link = 'links' __attribute = 'attributes' @docval({'name': 'name', 'type': str, 'doc': 'The name of the group.'}, {'name': 'groups', 'type': (dict, list), 'doc': ('A dictionary or list of subgroups to add to this group. If a dict is provided, only the ' 'values are used.'), 'default': dict()}, {'name': 'datasets', 'type': (dict, list), 'doc': ('A dictionary or list of datasets to add to this group. If a dict is provided, only the ' 'values are used.'), 'default': dict()}, {'name': 'attributes', 'type': dict, 'doc': 'A dictionary of attributes to create in this group.', 'default': dict()}, {'name': 'links', 'type': (dict, list), 'doc': ('A dictionary or list of links to add to this group. If a dict is provided, only the ' 'values are used.'), 'default': dict()}, {'name': 'parent', 'type': 'GroupBuilder', 'doc': 'The parent builder of this builder.', 'default': None}, {'name': 'source', 'type': str, 'doc': 'The source of the data represented in this builder.', 'default': None}) def __init__(self, **kwargs): """Create a builder object for a group.""" name, groups, datasets, links, attributes, parent, source = getargs( 'name', 'groups', 'datasets', 'links', 'attributes', 'parent', 'source', kwargs) # NOTE: if groups, datasets, or links are dicts, their keys are unused groups = self.__to_list(groups) datasets = self.__to_list(datasets) links = self.__to_list(links) # dictionary mapping subgroup/dataset/attribute/link name to the key that maps to the # subgroup/dataset/attribute/link sub-dictionary that maps the name to the builder self.obj_type = dict() super().__init__(name, attributes, parent, source) super().__setitem__(GroupBuilder.__group, dict()) super().__setitem__(GroupBuilder.__dataset, dict()) super().__setitem__(GroupBuilder.__link, dict()) for group in groups: self.set_group(group) for dataset in datasets: if dataset is not None: self.set_dataset(dataset) for link in links: self.set_link(link) def __to_list(self, d): if isinstance(d, dict): return list(d.values()) return d @property def source(self): ''' The source of this Builder ''' return super().source @source.setter def source(self, s): """Recursively set all subgroups/datasets/links source when this source is set.""" super(GroupBuilder, self.__class__).source.fset(self, s) for group in self.groups.values(): if group.source is None: group.source = s for dset in self.datasets.values(): if dset.source is None: dset.source = s for link in self.links.values(): if link.source is None: link.source = s @property def groups(self): """The subgroups contained in this group.""" return 
super().__getitem__(GroupBuilder.__group) @property def datasets(self): """The datasets contained in this group.""" return super().__getitem__(GroupBuilder.__dataset) @property def links(self): """The links contained in this group.""" return super().__getitem__(GroupBuilder.__link) @docval(*get_docval(BaseBuilder.set_attribute)) def set_attribute(self, **kwargs): """Set an attribute for this group.""" name, value = getargs('name', 'value', kwargs) self.__check_obj_type(name, GroupBuilder.__attribute) super().set_attribute(name, value) self.obj_type[name] = GroupBuilder.__attribute def __check_obj_type(self, name, obj_type): # check that the name is not associated with a different object type in this group if name in self.obj_type and self.obj_type[name] != obj_type: raise ValueError("'%s' already exists in %s.%s, cannot set in %s." % (name, self.name, self.obj_type[name], obj_type)) @docval({'name': 'builder', 'type': 'GroupBuilder', 'doc': 'The GroupBuilder to add to this group.'}) def set_group(self, **kwargs): """Add a subgroup to this group.""" builder = getargs('builder', kwargs) self.__set_builder(builder, GroupBuilder.__group) @docval({'name': 'builder', 'type': 'DatasetBuilder', 'doc': 'The DatasetBuilder to add to this group.'}) def set_dataset(self, **kwargs): """Add a dataset to this group.""" builder = getargs('builder', kwargs) self.__set_builder(builder, GroupBuilder.__dataset) @docval({'name': 'builder', 'type': 'LinkBuilder', 'doc': 'The LinkBuilder to add to this group.'}) def set_link(self, **kwargs): """Add a link to this group.""" builder = getargs('builder', kwargs) self.__set_builder(builder, GroupBuilder.__link) def __set_builder(self, builder, obj_type): name = builder.name self.__check_obj_type(name, obj_type) super().__getitem__(obj_type)[name] = builder self.obj_type[name] = obj_type if builder.parent is None: builder.parent = self def is_empty(self): """Returns true if there are no datasets, links, attributes, and non-empty subgroups. False otherwise.""" if len(self.datasets) or len(self.links) or len(self.attributes): return False elif len(self.groups): return all(g.is_empty() for g in self.groups.values()) else: return True def __getitem__(self, key): """Like dict.__getitem__, but looks in groups, datasets, attributes, and links sub-dictionaries. Key can be a posix path to a sub-builder. """ try: key_ar = _posixpath.normpath(key).split('/') return self.__get_rec(key_ar) except KeyError: raise KeyError(key) def get(self, key, default=None): """Like dict.get, but looks in groups, datasets, attributes, and links sub-dictionaries. Key can be a posix path to a sub-builder. 
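# Illustrative usage sketch (not from the original builders.py): dict-like, posix-path access
# on a GroupBuilder as implemented above. All object names here are arbitrary.
from hdmf.build import DatasetBuilder, GroupBuilder

outer = GroupBuilder(name='outer')
inner = GroupBuilder(name='inner')
outer.set_group(inner)
inner.set_dataset(DatasetBuilder(name='data', data=[0, 1]))
outer.set_attribute('comment', 'example')

assert outer['inner/data'].name == 'data'   # posix-path lookup across the sub-dictionaries
assert outer.get('missing') is None         # get() falls back to a default instead of raising
assert 'comment' in outer                   # membership checks groups, datasets, attributes, links
assert not outer.is_empty()
# End of illustrative sketch.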
""" try: key_ar = _posixpath.normpath(key).split('/') return self.__get_rec(key_ar) except KeyError: return default def __get_rec(self, key_ar): # recursive helper for __getitem__ and get if len(key_ar) == 1: # get the correct dictionary (groups, datasets, links, attributes) associated with the key # then look up the key within that dictionary to get the builder return super().__getitem__(self.obj_type[key_ar[0]])[key_ar[0]] else: if key_ar[0] in self.groups: return self.groups[key_ar[0]].__get_rec(key_ar[1:]) raise KeyError(key_ar[0]) def __setitem__(self, args, val): raise NotImplementedError('__setitem__') def __contains__(self, item): return self.obj_type.__contains__(item) def items(self): """Like dict.items, but iterates over items in groups, datasets, attributes, and links sub-dictionaries.""" return _itertools.chain(self.groups.items(), self.datasets.items(), self.attributes.items(), self.links.items()) def keys(self): """Like dict.keys, but iterates over keys in groups, datasets, attributes, and links sub-dictionaries.""" return _itertools.chain(self.groups.keys(), self.datasets.keys(), self.attributes.keys(), self.links.keys()) def values(self): """Like dict.values, but iterates over values in groups, datasets, attributes, and links sub-dictionaries.""" return _itertools.chain(self.groups.values(), self.datasets.values(), self.attributes.values(), self.links.values()) class DatasetBuilder(BaseBuilder): OBJECT_REF_TYPE = 'object' REGION_REF_TYPE = 'region' @docval({'name': 'name', 'type': str, 'doc': 'The name of the dataset.'}, {'name': 'data', 'type': ('array_data', 'scalar_data', 'data', 'DatasetBuilder', 'RegionBuilder', Iterable, datetime), 'doc': 'The data in this dataset.', 'default': None}, {'name': 'dtype', 'type': (type, np.dtype, str, list), 'doc': 'The datatype of this dataset.', 'default': None}, {'name': 'attributes', 'type': dict, 'doc': 'A dictionary of attributes to create in this dataset.', 'default': dict()}, {'name': 'maxshape', 'type': (int, tuple), 'doc': 'The shape of this dataset. 
Use None for scalars.', 'default': None}, {'name': 'chunks', 'type': bool, 'doc': 'Whether or not to chunk this dataset.', 'default': False}, {'name': 'parent', 'type': GroupBuilder, 'doc': 'The parent builder of this builder.', 'default': None}, {'name': 'source', 'type': str, 'doc': 'The source of the data in this builder.', 'default': None}) def __init__(self, **kwargs): """ Create a Builder object for a dataset """ name, data, dtype, attributes, maxshape, chunks, parent, source = getargs( 'name', 'data', 'dtype', 'attributes', 'maxshape', 'chunks', 'parent', 'source', kwargs) super().__init__(name, attributes, parent, source) self['data'] = data self['attributes'] = _copy.copy(attributes) self.__chunks = chunks self.__maxshape = maxshape if isinstance(data, BaseBuilder): if dtype is None: dtype = self.OBJECT_REF_TYPE self.__dtype = dtype self.__name = name @property def data(self): """The data stored in the dataset represented by this builder.""" return self['data'] @data.setter def data(self, val): if self['data'] is not None: raise AttributeError("Cannot overwrite data.") self['data'] = val @property def chunks(self): """Whether or not this dataset is chunked.""" return self.__chunks @property def maxshape(self): """The max shape of this dataset.""" return self.__maxshape @property def dtype(self): """The data type of this dataset.""" return self.__dtype @dtype.setter def dtype(self, val): if self.__dtype is not None: raise AttributeError("Cannot overwrite dtype.") self.__dtype = val class LinkBuilder(Builder): @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), 'doc': 'The target group or dataset of this link.'}, {'name': 'name', 'type': str, 'doc': 'The name of the link', 'default': None}, {'name': 'parent', 'type': GroupBuilder, 'doc': 'The parent builder of this builder', 'default': None}, {'name': 'source', 'type': str, 'doc': 'The source of the data in this builder', 'default': None}) def __init__(self, **kwargs): """Create a builder object for a link.""" name, builder, parent, source = getargs('name', 'builder', 'parent', 'source', kwargs) if name is None: name = builder.name super().__init__(name, parent, source) self['builder'] = builder @property def builder(self): """The target builder object.""" return self['builder'] class ReferenceBuilder(dict): @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), 'doc': 'The group or dataset this reference applies to.'}) def __init__(self, **kwargs): """Create a builder object for a reference.""" builder = getargs('builder', kwargs) self['builder'] = builder @property def builder(self): """The target builder object.""" return self['builder'] class RegionBuilder(ReferenceBuilder): @docval({'name': 'region', 'type': (slice, tuple, list, RegionReference), 'doc': 'The region, i.e. 
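# Illustrative usage sketch (not from the original builders.py): linking and referencing a
# dataset builder, and the write-once behavior of DatasetBuilder.data. Names are arbitrary.
from hdmf.build import DatasetBuilder, GroupBuilder, LinkBuilder, ReferenceBuilder

grp = GroupBuilder(name='grp')
dset = DatasetBuilder(name='values', data=[1.0, 2.0], dtype='float32')
grp.set_dataset(dset)
grp.set_link(LinkBuilder(builder=dset, name='values_link'))  # name defaults to the target's name
assert grp.links['values_link'].builder is dset
ref = ReferenceBuilder(dset)                                 # wraps a builder as a reference target
assert ref.builder is dset
try:
    dset.data = [3.0]                                        # data may only be set once
except AttributeError:
    pass
# End of illustrative sketch.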
slice or indices, into the target dataset.'}, {'name': 'builder', 'type': DatasetBuilder, 'doc': 'The dataset this region reference applies to.'}) def __init__(self, **kwargs): """Create a builder object for a region reference.""" region, builder = getargs('region', 'builder', kwargs) super().__init__(builder) self['region'] = region @property def region(self): """The selected region of the target dataset.""" return self['region'] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/build/classgenerator.py0000644000655200065520000004171600000000000020764 0ustar00circlecicirclecifrom copy import deepcopy from datetime import datetime import numpy as np from ..container import Container, Data, DataRegion, MultiContainerInterface from ..spec import AttributeSpec, LinkSpec, RefSpec, GroupSpec from ..spec.spec import BaseStorageSpec, ZERO_OR_MANY, ONE_OR_MANY from ..utils import docval, getargs, ExtenderMeta, get_docval, fmt_docval_args class ClassGenerator: def __init__(self): self.__custom_generators = [] @property def custom_generators(self): return self.__custom_generators @docval({'name': 'generator', 'type': type, 'doc': 'the CustomClassGenerator class to register'}) def register_generator(self, **kwargs): """Add a custom class generator to this ClassGenerator. Generators added later are run first. Duplicates are moved to the top of the list. """ generator = getargs('generator', kwargs) if not issubclass(generator, CustomClassGenerator): raise ValueError('Generator %s must be a subclass of CustomClassGenerator.' % generator) if generator in self.__custom_generators: self.__custom_generators.remove(generator) self.__custom_generators.insert(0, generator) @docval({'name': 'data_type', 'type': str, 'doc': 'the data type to create a AbstractContainer class for'}, {'name': 'spec', 'type': BaseStorageSpec, 'doc': ''}, {'name': 'parent_cls', 'type': type, 'doc': ''}, {'name': 'attr_names', 'type': dict, 'doc': ''}, {'name': 'type_map', 'type': 'TypeMap', 'doc': ''}, returns='the class for the given namespace and data_type', rtype=type) def generate_class(self, **kwargs): """Get the container class from data type specification. If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically created and returned. 
""" data_type, spec, parent_cls, attr_names, type_map = getargs('data_type', 'spec', 'parent_cls', 'attr_names', 'type_map', kwargs) not_inherited_fields = dict() for k, field_spec in attr_names.items(): if k == 'help': # pragma: no cover # (legacy) do not add field named 'help' to any part of class object continue if isinstance(field_spec, GroupSpec) and field_spec.data_type is None: # skip named, untyped groups continue if not spec.is_inherited_spec(field_spec): not_inherited_fields[k] = field_spec try: classdict = dict() bases = [parent_cls] docval_args = list(deepcopy(get_docval(parent_cls.__init__))) for attr_name, field_spec in not_inherited_fields.items(): for class_generator in self.__custom_generators: # pragma: no branch # each generator can update classdict and docval_args if class_generator.apply_generator_to_field(field_spec, bases, type_map): class_generator.process_field_spec(classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec) break # each field_spec should be processed by only one generator for class_generator in self.__custom_generators: class_generator.post_process(classdict, bases, docval_args, spec) for class_generator in reversed(self.__custom_generators): # go in reverse order so that base init is added first and # later class generators can modify or overwrite __init__ set by an earlier class generator class_generator.set_init(classdict, bases, docval_args, not_inherited_fields, spec.name) except TypeDoesNotExistError as e: # pragma: no cover # this error should never happen after hdmf#322 name = spec.data_type_def if name is None: name = 'Unknown' raise ValueError("Cannot dynamically generate class for type '%s'. " % name + str(e) + " Please define that type before defining '%s'." % name) cls = ExtenderMeta(data_type, tuple(bases), classdict) return cls class TypeDoesNotExistError(Exception): # pragma: no cover pass class CustomClassGenerator: """Subclass this class and register an instance to alter how classes are auto-generated.""" def __new__(cls, *args, **kwargs): # pragma: no cover raise TypeError('Cannot instantiate class %s' % cls.__name__) # mapping from spec types to allowable python types for docval for fields during dynamic class generation # e.g., if a dataset/attribute spec has dtype int32, then get_class should generate a docval for the class' # __init__ method that allows the types (int, np.int32, np.int64) for the corresponding field. # passing an np.int16 would raise a docval error. # passing an int64 to __init__ would result in the field storing the value as an int64 (and subsequently written # as an int64). 
no upconversion or downconversion happens as a result of this map _spec_dtype_map = { 'float32': (float, np.float32, np.float64), 'float': (float, np.float32, np.float64), 'float64': (float, np.float64), 'double': (float, np.float64), 'int8': (np.int8, np.int16, np.int32, np.int64, int), 'int16': (np.int16, np.int32, np.int64, int), 'short': (np.int16, np.int32, np.int64, int), 'int32': (int, np.int32, np.int64), 'int': (int, np.int32, np.int64), 'int64': np.int64, 'long': np.int64, 'uint8': (np.uint8, np.uint16, np.uint32, np.uint64), 'uint16': (np.uint16, np.uint32, np.uint64), 'uint32': (np.uint32, np.uint64), 'uint64': np.uint64, 'numeric': (float, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, int, np.uint8, np.uint16, np.uint32, np.uint64), 'text': str, 'utf': str, 'utf8': str, 'utf-8': str, 'ascii': bytes, 'bytes': bytes, 'bool': (bool, np.bool_), 'isodatetime': datetime, 'datetime': datetime } @classmethod def _get_type_from_spec_dtype(cls, spec_dtype): """Get the Python type associated with the given spec dtype string. Raises ValueError if the given dtype has no mapping to a Python type. """ dtype = cls._spec_dtype_map.get(spec_dtype) if dtype is None: # pragma: no cover # this should not happen as long as _spec_dtype_map is kept up to date with # hdmf.spec.spec.DtypeHelper.valid_primary_dtypes raise ValueError("Spec dtype '%s' cannot be mapped to a Python type." % spec_dtype) return dtype @classmethod def _get_container_type(cls, type_name, type_map): """Search all namespaces for the container class associated with the given data type. Raises TypeDoesNotExistError if type is not found in any namespace. """ container_type = type_map.get_dt_container_cls(type_name) if container_type is None: # pragma: no cover # this should never happen after hdmf#322 raise TypeDoesNotExistError("Type '%s' does not exist." % type_name) return container_type @classmethod def _get_type(cls, spec, type_map): """Get the type of a spec for use in docval. Returns a container class, a type, a tuple of types, ('array_data', 'data') for specs with non-scalar shape, or (Data, Container) when an attribute reference target has not been mapped to a container class. """ if isinstance(spec, AttributeSpec): if isinstance(spec.dtype, RefSpec): try: container_type = cls._get_container_type(spec.dtype.target_type, type_map) return container_type except TypeDoesNotExistError: # TODO what happens when the attribute ref target is not (or not yet) mapped to a container class? 
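# Illustrative sketch (not from the original classgenerator.py): how the _spec_dtype_map above
# is consulted; a spec dtype string maps to the Python/numpy types that docval will accept for
# the generated class. These are internal helpers, shown only for illustration.
import numpy as np
from hdmf.build.classgenerator import CustomClassGenerator

assert CustomClassGenerator._spec_dtype_map['int32'] == (int, np.int32, np.int64)
assert CustomClassGenerator._get_type_from_spec_dtype('text') is str
# End of illustrative sketch.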
# returning Data, Container works as a generic fallback for now but should be more specific return Data, Container elif spec.shape is None and spec.dims is None: return cls._get_type_from_spec_dtype(spec.dtype) else: return 'array_data', 'data' if isinstance(spec, LinkSpec): return cls._get_container_type(spec.target_type, type_map) if spec.data_type is not None: return cls._get_container_type(spec.data_type, type_map) if spec.shape is None and spec.dims is None: return cls._get_type_from_spec_dtype(spec.dtype) return 'array_data', 'data' @classmethod def _ischild(cls, dtype): """Check if dtype represents a type that is a child.""" ret = False if isinstance(dtype, tuple): for sub in dtype: ret = ret or cls._ischild(sub) elif isinstance(dtype, type) and issubclass(dtype, (Container, Data, DataRegion)): ret = True return ret @staticmethod def _set_default_name(docval_args, default_name): """Set the default value for the name docval argument.""" if default_name is not None: for x in docval_args: if x['name'] == 'name': x['default'] = default_name @classmethod def apply_generator_to_field(cls, field_spec, bases, type_map): """Return True to signal that this generator should return on all fields not yet processed.""" return True @classmethod def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): """Add __fields__ to the classdict and update the docval args for the field spec with the given attribute name. :param classdict: The dict to update with __fields__ (or a different parent_cls._fieldsname). :param docval_args: The list of docval arguments. :param parent_cls: The parent class. :param attr_name: The attribute name of the field spec for the container class to generate. :param not_inherited_fields: Dictionary of fields not inherited from the parent class. :param type_map: The type map to use. :param spec: The spec for the container class to generate. """ field_spec = not_inherited_fields[attr_name] dtype = cls._get_type(field_spec, type_map) fields_conf = {'name': attr_name, 'doc': field_spec['doc']} if cls._ischild(dtype) and issubclass(parent_cls, Container) and not isinstance(field_spec, LinkSpec): fields_conf['child'] = True # if getattr(field_spec, 'value', None) is not None: # TODO set the fixed value on the class? # fields_conf['settable'] = False classdict.setdefault(parent_cls._fieldsname, list()).append(fields_conf) docval_arg = dict( name=attr_name, doc=field_spec.doc, type=cls._get_type(field_spec, type_map) ) shape = getattr(field_spec, 'shape', None) if shape is not None: docval_arg['shape'] = shape if cls._check_spec_optional(field_spec, spec): docval_arg['default'] = getattr(field_spec, 'default_value', None) cls._add_to_docval_args(docval_args, docval_arg) @classmethod def _check_spec_optional(cls, field_spec, spec): """Returns True if the spec or any of its parents (up to the parent type spec) are optional.""" if not field_spec.required: return True if field_spec == spec: return False if field_spec.parent is not None: return cls._check_spec_optional(field_spec.parent, spec) @classmethod def _add_to_docval_args(cls, docval_args, arg): """Add the docval arg to the list if not present. 
If present, overwrite it in place.""" inserted = False for i, x in enumerate(docval_args): if x['name'] == arg['name']: docval_args[i] = arg inserted = True if not inserted: docval_args.append(arg) @classmethod def post_process(cls, classdict, bases, docval_args, spec): """Convert classdict['__fields__'] to tuple and update docval args for a fixed name and default name. :param classdict: The class dictionary to convert with '__fields__' key (or a different bases[0]._fieldsname) :param bases: The list of base classes. :param docval_args: The dict of docval arguments. :param spec: The spec for the container class to generate. """ # convert classdict['__fields__'] from list to tuple if present for b in bases: fields = classdict.get(b._fieldsname) if fields is not None and not isinstance(fields, tuple): classdict[b._fieldsname] = tuple(fields) # if spec provides a fixed name for this type, remove the 'name' arg from docval_args so that values cannot # be passed for a name positional or keyword arg if spec.name is not None: for arg in list(docval_args): if arg['name'] == 'name': docval_args.remove(arg) # set default name in docval args if provided cls._set_default_name(docval_args, spec.default_name) @classmethod def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): # get docval arg names from superclass base = bases[0] parent_docval_args = set(arg['name'] for arg in get_docval(base.__init__)) new_args = list() for attr_name, field_spec in not_inherited_fields.items(): # auto-initialize arguments not found in superclass if attr_name not in parent_docval_args: new_args.append(attr_name) @docval(*docval_args) def __init__(self, **kwargs): if name is not None: # force container name to be the fixed name in the spec kwargs.update(name=name) pargs, pkwargs = fmt_docval_args(base.__init__, kwargs) base.__init__(self, *pargs, **pkwargs) # special case: need to pass self to __init__ for f in new_args: arg_val = kwargs.get(f, None) setattr(self, f, arg_val) classdict['__init__'] = __init__ class MCIClassGenerator(CustomClassGenerator): @classmethod def apply_generator_to_field(cls, field_spec, bases, type_map): """Return True if the field spec has quantity * or +, False otherwise.""" return getattr(field_spec, 'quantity', None) in (ZERO_OR_MANY, ONE_OR_MANY) @classmethod def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): """Add __clsconf__ to the classdict and update the docval args for the field spec with the given attribute name. :param classdict: The dict to update with __clsconf__. :param docval_args: The list of docval arguments. :param parent_cls: The parent class. :param attr_name: The attribute name of the field spec for the container class to generate. :param not_inherited_fields: Dictionary of fields not inherited from the parent class. :param type_map: The type map to use. :param spec: The spec for the container class to generate. 
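# Illustrative sketch (not from the original classgenerator.py): registering a custom class
# generator so it is consulted before the built-in ones. MyTextFieldGenerator is hypothetical;
# it only overrides the predicate and inherits the default field handling shown above.
from hdmf.build import CustomClassGenerator, TypeMap

class MyTextFieldGenerator(CustomClassGenerator):
    @classmethod
    def apply_generator_to_field(cls, field_spec, bases, type_map):
        # claim only text-typed fields; all other fields fall through to other generators
        return getattr(field_spec, 'dtype', None) == 'text'

type_map = TypeMap()
type_map.register_generator(MyTextFieldGenerator)   # generators added later run first
# End of illustrative sketch.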
""" field_spec = not_inherited_fields[attr_name] field_clsconf = dict( attr=attr_name, type=cls._get_type(field_spec, type_map), add='add_{}'.format(attr_name), get='get_{}'.format(attr_name), create='create_{}'.format(attr_name) ) classdict.setdefault('__clsconf__', list()).append(field_clsconf) # add a specialized docval arg for __init__ docval_arg = dict( name=attr_name, doc=field_spec.doc, type=(list, tuple, dict, cls._get_type(field_spec, type_map)) ) if cls._check_spec_optional(field_spec, spec): docval_arg['default'] = getattr(field_spec, 'default_value', None) cls._add_to_docval_args(docval_args, docval_arg) @classmethod def post_process(cls, classdict, bases, docval_args, spec): """Add MultiContainerInterface to the list of base classes. :param classdict: The class dictionary. :param bases: The list of base classes. :param docval_args: The dict of docval arguments. :param spec: The spec for the container class to generate. """ if '__clsconf__' in classdict: # do not add MCI as a base if a base is already a subclass of MultiContainerInterface for b in bases: if issubclass(b, MultiContainerInterface): break else: bases.insert(0, MultiContainerInterface) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/build/errors.py0000644000655200065520000000416700000000000017263 0ustar00circlecicirclecifrom .builders import Builder from ..container import AbstractContainer from ..utils import docval, getargs class BuildError(Exception): """Error raised when building a container into a builder.""" @docval({'name': 'builder', 'type': Builder, 'doc': 'the builder that cannot be built'}, {'name': 'reason', 'type': str, 'doc': 'the reason for the error'}) def __init__(self, **kwargs): self.__builder = getargs('builder', kwargs) self.__reason = getargs('reason', kwargs) self.__message = "%s (%s): %s" % (self.__builder.name, self.__builder.path, self.__reason) super().__init__(self.__message) class OrphanContainerBuildError(BuildError): @docval({'name': 'builder', 'type': Builder, 'doc': 'the builder containing the broken link'}, {'name': 'container', 'type': AbstractContainer, 'doc': 'the container that has no parent'}) def __init__(self, **kwargs): builder = getargs('builder', kwargs) self.__container = getargs('container', kwargs) reason = ("Linked %s '%s' has no parent. Remove the link or ensure the linked container is added properly." 
% (self.__container.__class__.__name__, self.__container.name)) super().__init__(builder=builder, reason=reason) class ReferenceTargetNotBuiltError(BuildError): @docval({'name': 'builder', 'type': Builder, 'doc': 'the builder containing the reference that cannot be found'}, {'name': 'container', 'type': AbstractContainer, 'doc': 'the container that is not built yet'}) def __init__(self, **kwargs): builder = getargs('builder', kwargs) self.__container = getargs('container', kwargs) reason = ("Could not find already-built Builder for %s '%s' in BuildManager" % (self.__container.__class__.__name__, self.__container.name)) super().__init__(builder=builder, reason=reason) class ContainerConfigurationError(Exception): """Error raised when the container class is improperly configured.""" pass class ConstructError(Exception): """Error raised when constructing a container from a builder.""" ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/build/manager.py0000644000655200065520000011543700000000000017364 0ustar00circlecicircleciimport logging from collections import OrderedDict, deque from copy import copy from .builders import DatasetBuilder, GroupBuilder, LinkBuilder, Builder, BaseBuilder from .classgenerator import ClassGenerator, CustomClassGenerator, MCIClassGenerator from ..container import AbstractContainer, Container, Data from ..spec import DatasetSpec, GroupSpec, NamespaceCatalog, SpecReader from ..spec.spec import BaseStorageSpec from ..utils import docval, getargs, call_docval_func, ExtenderMeta class Proxy: """ A temporary object to represent a Container. This gets used when resolving the true location of a Container's parent. Proxy objects allow simple bookkeeping of all potential parents a Container may have. This object is used by providing all the necessary information for describing the object. This object gets passed around and candidates are accumulated. Upon calling resolve, all saved candidates are matched against the information (provided to the constructor). The candidate that has an exact match is returned. """ def __init__(self, manager, source, location, namespace, data_type): self.__source = source self.__location = location self.__namespace = namespace self.__data_type = data_type self.__manager = manager self.__candidates = list() @property def source(self): """The source of the object e.g. file source""" return self.__source @property def location(self): """The location of the object. 
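# Illustrative sketch (not from the original errors.py): the error classes above compose their
# message from the offending builder's name and path. The builder name and reason are arbitrary.
from hdmf.build import BuildError, GroupBuilder

try:
    raise BuildError(builder=GroupBuilder(name='acquisition'), reason='missing required dataset')
except BuildError as err:
    print(err)   # -> "acquisition (acquisition): missing required dataset"
# End of illustrative sketch.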
This can be thought of as a unique path""" return self.__location @property def namespace(self): """The namespace from which the data_type of this Proxy came from""" return self.__namespace @property def data_type(self): """The data_type of Container that should match this Proxy""" return self.__data_type @docval({"name": "object", "type": (BaseBuilder, Container), "doc": "the container or builder to get a proxy for"}) def matches(self, **kwargs): obj = getargs('object', kwargs) if not isinstance(obj, Proxy): obj = self.__manager.get_proxy(obj) return self == obj @docval({"name": "container", "type": Container, "doc": "the Container to add as a candidate match"}) def add_candidate(self, **kwargs): container = getargs('container', kwargs) self.__candidates.append(container) def resolve(self): for candidate in self.__candidates: if self.matches(candidate): return candidate raise ValueError("No matching candidate Container found for " + self) def __eq__(self, other): return self.data_type == other.data_type and \ self.location == other.location and \ self.namespace == other.namespace and \ self.source == other.source def __repr__(self): ret = dict() for key in ('source', 'location', 'namespace', 'data_type'): ret[key] = getattr(self, key, None) return str(ret) class BuildManager: """ A class for managing builds of AbstractContainers """ def __init__(self, type_map): self.logger = logging.getLogger('%s.%s' % (self.__class__.__module__, self.__class__.__qualname__)) self.__builders = dict() self.__containers = dict() self.__active_builders = set() self.__type_map = type_map self.__ref_queue = deque() # a queue of the ReferenceBuilders that need to be added @property def namespace_catalog(self): return self.__type_map.namespace_catalog @property def type_map(self): return self.__type_map @docval({"name": "object", "type": (BaseBuilder, AbstractContainer), "doc": "the container or builder to get a proxy for"}, {"name": "source", "type": str, "doc": "the source of container being built i.e. file path", 'default': None}) def get_proxy(self, **kwargs): obj = getargs('object', kwargs) if isinstance(obj, BaseBuilder): return self._get_proxy_builder(obj) elif isinstance(obj, AbstractContainer): return self._get_proxy_container(obj) def _get_proxy_builder(self, builder): dt = self.__type_map.get_builder_dt(builder) ns = self.__type_map.get_builder_ns(builder) stack = list() tmp = builder while tmp is not None: stack.append(tmp.name) tmp = self.__get_parent_dt_builder(tmp) loc = "/".join(reversed(stack)) return Proxy(self, builder.source, loc, ns, dt) def _get_proxy_container(self, container): ns, dt = self.__type_map.get_container_ns_dt(container) stack = list() tmp = container while tmp is not None: if isinstance(tmp, Proxy): stack.append(tmp.location) break else: stack.append(tmp.name) tmp = tmp.parent loc = "/".join(reversed(stack)) return Proxy(self, container.container_source, loc, ns, dt) @docval({"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, {"name": "source", "type": str, "doc": "the source of container being built i.e. 
file path", 'default': None}, {"name": "spec_ext", "type": BaseStorageSpec, "doc": "a spec that further refines the base specification", 'default': None}, {"name": "export", "type": bool, "doc": "whether this build is for exporting", 'default': False}, {"name": "root", "type": bool, "doc": "whether the container is the root of the build process", 'default': False}) def build(self, **kwargs): """ Build the GroupBuilder/DatasetBuilder for the given AbstractContainer""" container, export = getargs('container', 'export', kwargs) source, spec_ext, root = getargs('source', 'spec_ext', 'root', kwargs) result = self.get_builder(container) if root: self.__active_builders.clear() # reset active builders at start of build process if result is None: self.logger.debug("Building new %s '%s' (container_source: %s, source: %s, extended spec: %s, export: %s)" % (container.__class__.__name__, container.name, repr(container.container_source), repr(source), spec_ext is not None, export)) # the container_source is not set or checked when exporting if not export: if container.container_source is None: container.container_source = source elif source is None: source = container.container_source else: if container.container_source != source: raise ValueError("Cannot change container_source once set: '%s' %s.%s" % (container.name, container.__class__.__module__, container.__class__.__name__)) # NOTE: if exporting, then existing cached builder will be ignored and overridden with new build result result = self.__type_map.build(container, self, source=source, spec_ext=spec_ext, export=export) self.prebuilt(container, result) self.__active_prebuilt(result) self.logger.debug("Done building %s '%s'" % (container.__class__.__name__, container.name)) elif not self.__is_active_builder(result) and container.modified: # if builder was built on file read and is then modified (append mode), it needs to be rebuilt self.logger.debug("Rebuilding modified %s '%s' (source: %s, extended spec: %s)" % (container.__class__.__name__, container.name, repr(source), spec_ext is not None)) result = self.__type_map.build(container, self, builder=result, source=source, spec_ext=spec_ext, export=export) self.logger.debug("Done rebuilding %s '%s'" % (container.__class__.__name__, container.name)) else: self.logger.debug("Using prebuilt %s '%s' for %s '%s'" % (result.__class__.__name__, result.name, container.__class__.__name__, container.name)) if root: # create reference builders only after building all other builders self.__add_refs() self.__active_builders.clear() # reset active builders now that build process has completed return result @docval({"name": "container", "type": AbstractContainer, "doc": "the AbstractContainer to save as prebuilt"}, {'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), 'doc': 'the Builder representation of the given container'}) def prebuilt(self, **kwargs): ''' Save the Builder for a given AbstractContainer for future use ''' container, builder = getargs('container', 'builder', kwargs) container_id = self.__conthash__(container) self.__builders[container_id] = builder builder_id = self.__bldrhash__(builder) self.__containers[builder_id] = container def __active_prebuilt(self, builder): """Save the Builder for future use during the active/current build process.""" builder_id = self.__bldrhash__(builder) self.__active_builders.add(builder_id) def __is_active_builder(self, builder): """Return True if the Builder was created during the active/current build process.""" builder_id = self.__bldrhash__(builder) 
return builder_id in self.__active_builders def __conthash__(self, obj): return id(obj) def __bldrhash__(self, obj): return id(obj) def __add_refs(self): ''' Add ReferenceBuilders. References get queued to be added after all other objects are built. This is because the current traversal algorithm (i.e. iterating over specs) does not happen in a guaranteed order. We need to build the targets of the reference builders so that the targets have the proper parent, and then write the reference builders after we write everything else. ''' while len(self.__ref_queue) > 0: call = self.__ref_queue.popleft() self.logger.debug("Adding ReferenceBuilder with call id %d from queue (length %d)" % (id(call), len(self.__ref_queue))) call() def queue_ref(self, func): '''Set aside creating ReferenceBuilders''' # TODO: come up with more intelligent way of # queueing reference resolution, based on reference # dependency self.__ref_queue.append(func) def purge_outdated(self): containers_copy = self.__containers.copy() for container in containers_copy.values(): if container.modified: container_id = self.__conthash__(container) builder = self.__builders.get(container_id) builder_id = self.__bldrhash__(builder) self.logger.debug("Purging %s '%s' for %s '%s' from prebuilt cache" % (builder.__class__.__name__, builder.name, container.__class__.__name__, container.name)) self.__builders.pop(container_id) self.__containers.pop(builder_id) @docval({"name": "container", "type": AbstractContainer, "doc": "the container to get the builder for"}) def get_builder(self, **kwargs): """Return the prebuilt builder for the given container or None if it does not exist.""" container = getargs('container', kwargs) container_id = self.__conthash__(container) result = self.__builders.get(container_id) return result @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), 'doc': 'the builder to construct the AbstractContainer from'}) def construct(self, **kwargs): """ Construct the AbstractContainer represented by the given builder """ builder = getargs('builder', kwargs) if isinstance(builder, LinkBuilder): builder = builder.target builder_id = self.__bldrhash__(builder) result = self.__containers.get(builder_id) if result is None: parent_builder = self.__get_parent_dt_builder(builder) if parent_builder is not None: parent = self._get_proxy_builder(parent_builder) result = self.__type_map.construct(builder, self, parent) else: # we are at the top of the hierarchy, # so it must be time to resolve parents result = self.__type_map.construct(builder, self, None) self.__resolve_parents(result) self.prebuilt(result, builder) result.set_modified(False) return result def __resolve_parents(self, container): stack = [container] while len(stack) > 0: tmp = stack.pop() if isinstance(tmp.parent, Proxy): tmp.parent = tmp.parent.resolve() for child in tmp.children: stack.append(child) def __get_parent_dt_builder(self, builder): ''' Get the next builder above the given builder that has a data_type ''' tmp = builder.parent ret = None while tmp is not None: ret = tmp dt = self.__type_map.get_builder_dt(tmp) if dt is not None: break tmp = tmp.parent return ret # *** The following methods just delegate calls to self.__type_map *** @docval({'name': 'builder', 'type': Builder, 'doc': 'the Builder to get the class object for'}) def get_cls(self, **kwargs): ''' Get the class object for the given Builder ''' builder = getargs('builder', kwargs) return self.__type_map.get_cls(builder) @docval({"name": "container", "type": AbstractContainer, "doc": 
"the container to convert to a Builder"}, returns='The name a Builder should be given when building this container', rtype=str) def get_builder_name(self, **kwargs): ''' Get the name a Builder should be given ''' container = getargs('container', kwargs) return self.__type_map.get_builder_name(container) @docval({'name': 'spec', 'type': (DatasetSpec, GroupSpec), 'doc': 'the parent spec to search'}, {'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), 'doc': 'the builder to get the sub-specification for'}) def get_subspec(self, **kwargs): ''' Get the specification from this spec that corresponds to the given builder ''' spec, builder = getargs('spec', 'builder', kwargs) return self.__type_map.get_subspec(spec, builder) @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), 'doc': 'the builder to get the sub-specification for'}) def get_builder_ns(self, **kwargs): ''' Get the namespace of a builder ''' builder = getargs('builder', kwargs) return self.__type_map.get_builder_ns(builder) @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), 'doc': 'the builder to get the data_type for'}) def get_builder_dt(self, **kwargs): ''' Get the data_type of a builder ''' builder = getargs('builder', kwargs) return self.__type_map.get_builder_dt(builder) @docval({'name': 'builder', 'type': (GroupBuilder, DatasetBuilder, AbstractContainer), 'doc': 'the builder or container to check'}, {'name': 'parent_data_type', 'type': str, 'doc': 'the potential parent data_type that refers to a data_type'}, returns="True if data_type of *builder* is a sub-data_type of *parent_data_type*, False otherwise", rtype=bool) def is_sub_data_type(self, **kwargs): ''' Return whether or not data_type of *builder* is a sub-data_type of *parent_data_type* ''' builder, parent_dt = getargs('builder', 'parent_data_type', kwargs) if isinstance(builder, (GroupBuilder, DatasetBuilder)): ns = self.get_builder_ns(builder) dt = self.get_builder_dt(builder) else: # builder is an AbstractContainer ns, dt = self.type_map.get_container_ns_dt(builder) return self.namespace_catalog.is_sub_data_type(ns, dt, parent_dt) class TypeSource: '''A class to indicate the source of a data_type in a namespace. 
This class should only be used by TypeMap ''' @docval({"name": "namespace", "type": str, "doc": "the namespace the from, which the data_type originated"}, {"name": "data_type", "type": str, "doc": "the name of the type"}) def __init__(self, **kwargs): namespace, data_type = getargs('namespace', 'data_type', kwargs) self.__namespace = namespace self.__data_type = data_type @property def namespace(self): return self.__namespace @property def data_type(self): return self.__data_type class TypeMap: ''' A class to maintain the map between ObjectMappers and AbstractContainer classes ''' @docval({'name': 'namespaces', 'type': NamespaceCatalog, 'doc': 'the NamespaceCatalog to use', 'default': None}, {'name': 'mapper_cls', 'type': type, 'doc': 'the ObjectMapper class to use', 'default': None}) def __init__(self, **kwargs): namespaces, mapper_cls = getargs('namespaces', 'mapper_cls', kwargs) if namespaces is None: namespaces = NamespaceCatalog() if mapper_cls is None: from .objectmapper import ObjectMapper # avoid circular import mapper_cls = ObjectMapper self.__ns_catalog = namespaces self.__mappers = dict() # already constructed ObjectMapper classes self.__mapper_cls = dict() # the ObjectMapper class to use for each container type self.__container_types = OrderedDict() self.__data_types = dict() self.__default_mapper_cls = mapper_cls self.__class_generator = ClassGenerator() self.register_generator(CustomClassGenerator) self.register_generator(MCIClassGenerator) @property def namespace_catalog(self): return self.__ns_catalog @property def container_types(self): return self.__container_types def __copy__(self): ret = TypeMap(copy(self.__ns_catalog), self.__default_mapper_cls) ret.merge(self) return ret def __deepcopy__(self, memo): # XXX: From @nicain: All of a sudden legacy tests started # needing this argument in deepcopy. Doesn't hurt anything, though. 
return self.__copy__() def copy_mappers(self, type_map): for namespace in self.__ns_catalog.namespaces: if namespace not in type_map.__container_types: continue for data_type in self.__ns_catalog.get_namespace(namespace).get_registered_types(): container_cls = type_map.__container_types[namespace].get(data_type) if container_cls is None: continue self.register_container_type(namespace, data_type, container_cls) if container_cls in type_map.__mapper_cls: self.register_map(container_cls, type_map.__mapper_cls[container_cls]) def merge(self, type_map, ns_catalog=False): if ns_catalog: self.namespace_catalog.merge(type_map.namespace_catalog) for namespace in type_map.__container_types: for data_type in type_map.__container_types[namespace]: container_cls = type_map.__container_types[namespace][data_type] self.register_container_type(namespace, data_type, container_cls) for container_cls in type_map.__mapper_cls: self.register_map(container_cls, type_map.__mapper_cls[container_cls]) for custom_generators in reversed(type_map.__class_generator.custom_generators): # iterate in reverse order because generators are stored internally as a stack self.register_generator(custom_generators) @docval({"name": "generator", "type": type, "doc": "the CustomClassGenerator class to register"}) def register_generator(self, **kwargs): """Add a custom class generator.""" generator = getargs('generator', kwargs) self.__class_generator.register_generator(generator) @docval({'name': 'namespace_path', 'type': str, 'doc': 'the path to the file containing the namespaces(s) to load'}, {'name': 'resolve', 'type': bool, 'doc': 'whether or not to include objects from included/parent spec objects', 'default': True}, {'name': 'reader', 'type': SpecReader, 'doc': 'the class to user for reading specifications', 'default': None}, returns="the namespaces loaded from the given file", rtype=dict) def load_namespaces(self, **kwargs): '''Load namespaces from a namespace file. This method will call load_namespaces on the NamespaceCatalog used to construct this TypeMap. Additionally, it will process the return value to keep track of what types were included in the loaded namespaces. Calling load_namespaces here has the advantage of being able to keep track of type dependencies across namespaces. ''' deps = call_docval_func(self.__ns_catalog.load_namespaces, kwargs) for new_ns, ns_deps in deps.items(): for src_ns, types in ns_deps.items(): for dt in types: container_cls = self.get_dt_container_cls(dt, src_ns, autogen=False) if container_cls is None: container_cls = TypeSource(src_ns, dt) self.register_container_type(new_ns, dt, container_cls) return deps @docval({"name": "namespace", "type": str, "doc": "the namespace containing the data_type"}, {"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"}, {"name": "autogen", "type": bool, "doc": "autogenerate class if one does not exist", "default": True}, returns='the class for the given namespace and data_type', rtype=type) def get_container_cls(self, **kwargs): """Get the container class from data type specification. If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically created and returned. 
""" # NOTE: this internally used function get_container_cls will be removed in favor of get_dt_container_cls namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs) return self.get_dt_container_cls(data_type, namespace, autogen) @docval({"name": "data_type", "type": str, "doc": "the data type to create a AbstractContainer class for"}, {"name": "namespace", "type": str, "doc": "the namespace containing the data_type", "default": None}, {"name": "autogen", "type": bool, "doc": "autogenerate class if one does not exist", "default": True}, returns='the class for the given namespace and data_type', rtype=type) def get_dt_container_cls(self, **kwargs): """Get the container class from data type specification. If no class has been associated with the ``data_type`` from ``namespace``, a class will be dynamically created and returned. Replaces get_container_cls but namespace is optional. If namespace is unknown, it will be looked up from all namespaces. """ namespace, data_type, autogen = getargs('namespace', 'data_type', 'autogen', kwargs) # namespace is unknown, so look it up if namespace is None: for ns_key, ns_data_types in self.__container_types.items(): # NOTE that the type_name may appear in multiple namespaces based on how they were resolved # but the same type_name should point to the same class if data_type in ns_data_types: namespace = ns_key break cls = self.__get_container_cls(namespace, data_type) if cls is None and autogen: # dynamically generate a class spec = self.__ns_catalog.get_spec(namespace, data_type) self.__check_dependent_types(spec, namespace) parent_cls = self.__get_parent_cls(namespace, data_type, spec) attr_names = self.__default_mapper_cls.get_attr_names(spec) cls = self.__class_generator.generate_class(data_type, spec, parent_cls, attr_names, self) self.register_container_type(namespace, data_type, cls) return cls def __check_dependent_types(self, spec, namespace): """Ensure that classes for all types used by this type exist in this namespace and generate them if not. 
""" def __check_dependent_types_helper(spec, namespace): if isinstance(spec, (GroupSpec, DatasetSpec)): if spec.data_type_inc is not None: self.get_dt_container_cls(spec.data_type_inc, namespace) # TODO handle recursive definitions if spec.data_type_def is not None: # nested type definition self.get_dt_container_cls(spec.data_type_def, namespace) else: # spec is a LinkSpec self.get_dt_container_cls(spec.target_type, namespace) if isinstance(spec, GroupSpec): for child_spec in (spec.groups + spec.datasets + spec.links): __check_dependent_types_helper(child_spec, namespace) if spec.data_type_inc is not None: self.get_dt_container_cls(spec.data_type_inc, namespace) if isinstance(spec, GroupSpec): for child_spec in (spec.groups + spec.datasets + spec.links): __check_dependent_types_helper(child_spec, namespace) def __get_parent_cls(self, namespace, data_type, spec): dt_hier = self.__ns_catalog.get_hierarchy(namespace, data_type) dt_hier = dt_hier[1:] # remove the current data_type parent_cls = None for t in dt_hier: parent_cls = self.__get_container_cls(namespace, t) if parent_cls is not None: break if parent_cls is None: if isinstance(spec, GroupSpec): parent_cls = Container elif isinstance(spec, DatasetSpec): parent_cls = Data else: raise ValueError("Cannot generate class from %s" % type(spec)) if type(parent_cls) is not ExtenderMeta: raise ValueError("parent class %s is not of type ExtenderMeta - %s" % (parent_cls, type(parent_cls))) return parent_cls def __get_container_cls(self, namespace, data_type): """Get the container class for the namespace, data_type. If the class doesn't exist yet, generate it.""" if namespace not in self.__container_types: return None if data_type not in self.__container_types[namespace]: return None ret = self.__container_types[namespace][data_type] if isinstance(ret, TypeSource): # data_type is a dependency from ret.namespace cls = self.get_dt_container_cls(ret.data_type, ret.namespace) # get class / generate class # register the same class into this namespace (replaces TypeSource) self.register_container_type(namespace, data_type, cls) ret = cls return ret @docval({'name': 'obj', 'type': (GroupBuilder, DatasetBuilder, LinkBuilder, GroupSpec, DatasetSpec), 'doc': 'the object to get the type key for'}) def __type_key(self, obj): """ A wrapper function to simplify the process of getting a type_key for an object. The type_key is used to get the data_type from a Builder's attributes. 
""" if isinstance(obj, LinkBuilder): obj = obj.builder if isinstance(obj, (GroupBuilder, GroupSpec)): return self.__ns_catalog.group_spec_cls.type_key() else: return self.__ns_catalog.dataset_spec_cls.type_key() @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), 'doc': 'the builder to get the data_type for'}) def get_builder_dt(self, **kwargs): ''' Get the data_type of a builder ''' builder = getargs('builder', kwargs) ret = None if isinstance(builder, LinkBuilder): builder = builder.builder if isinstance(builder, GroupBuilder): ret = builder.attributes.get(self.__ns_catalog.group_spec_cls.type_key()) else: ret = builder.attributes.get(self.__ns_catalog.dataset_spec_cls.type_key()) if isinstance(ret, bytes): ret = ret.decode('UTF-8') return ret @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), 'doc': 'the builder to get the sub-specification for'}) def get_builder_ns(self, **kwargs): ''' Get the namespace of a builder ''' builder = getargs('builder', kwargs) if isinstance(builder, LinkBuilder): builder = builder.builder ret = builder.attributes.get('namespace') return ret @docval({'name': 'builder', 'type': Builder, 'doc': 'the Builder object to get the corresponding AbstractContainer class for'}) def get_cls(self, **kwargs): ''' Get the class object for the given Builder ''' builder = getargs('builder', kwargs) data_type = self.get_builder_dt(builder) if data_type is None: raise ValueError("No data_type found for builder %s" % builder.path) namespace = self.get_builder_ns(builder) if namespace is None: raise ValueError("No namespace found for builder %s" % builder.path) return self.get_dt_container_cls(data_type, namespace) @docval({'name': 'spec', 'type': (DatasetSpec, GroupSpec), 'doc': 'the parent spec to search'}, {'name': 'builder', 'type': (DatasetBuilder, GroupBuilder, LinkBuilder), 'doc': 'the builder to get the sub-specification for'}) def get_subspec(self, **kwargs): ''' Get the specification from this spec that corresponds to the given builder ''' spec, builder = getargs('spec', 'builder', kwargs) if isinstance(builder, LinkBuilder): builder_type = type(builder.builder) else: builder_type = type(builder) if issubclass(builder_type, DatasetBuilder): subspec = spec.get_dataset(builder.name) else: subspec = spec.get_group(builder.name) if subspec is None: # builder was generated from something with a data_type and a wildcard name if isinstance(builder, LinkBuilder): dt = self.get_builder_dt(builder.builder) else: dt = self.get_builder_dt(builder) if dt is not None: ns = self.get_builder_ns(builder) hierarchy = self.__ns_catalog.get_hierarchy(ns, dt) for t in hierarchy: subspec = spec.get_data_type(t) if subspec is not None: break return subspec def get_container_ns_dt(self, obj): container_cls = obj.__class__ namespace, data_type = self.get_container_cls_dt(container_cls) return namespace, data_type def get_container_cls_dt(self, cls): def_ret = (None, None) for _cls in cls.__mro__: ret = self.__data_types.get(_cls, def_ret) if ret is not def_ret: return ret return ret @docval({'name': 'namespace', 'type': str, 'doc': 'the namespace to get the container classes for', 'default': None}) def get_container_classes(self, **kwargs): namespace = getargs('namespace', kwargs) ret = self.__data_types.keys() if namespace is not None: ret = filter(lambda x: self.__data_types[x][0] == namespace, ret) return list(ret) @docval({'name': 'obj', 'type': (AbstractContainer, Builder), 'doc': 'the object to get the ObjectMapper for'}, 
returns='the ObjectMapper to use for mapping the given object', rtype='ObjectMapper') def get_map(self, **kwargs): """ Return the ObjectMapper object that should be used for the given container """ obj = getargs('obj', kwargs) # get the container class, and namespace/data_type if isinstance(obj, AbstractContainer): container_cls = obj.__class__ namespace, data_type = self.get_container_cls_dt(container_cls) if namespace is None: raise ValueError("class %s is not mapped to a data_type" % container_cls) else: data_type = self.get_builder_dt(obj) namespace = self.get_builder_ns(obj) container_cls = self.get_cls(obj) # now build the ObjectMapper class mapper = self.__mappers.get(container_cls) if mapper is None: mapper_cls = self.__default_mapper_cls for cls in container_cls.__mro__: tmp_mapper_cls = self.__mapper_cls.get(cls) if tmp_mapper_cls is not None: mapper_cls = tmp_mapper_cls break spec = self.__ns_catalog.get_spec(namespace, data_type) mapper = mapper_cls(spec) self.__mappers[container_cls] = mapper return mapper @docval({"name": "namespace", "type": str, "doc": "the namespace containing the data_type to map the class to"}, {"name": "data_type", "type": str, "doc": "the data_type to map the class to"}, {"name": "container_cls", "type": (TypeSource, type), "doc": "the class to map to the specified data_type"}) def register_container_type(self, **kwargs): ''' Map a container class to a data_type ''' namespace, data_type, container_cls = getargs('namespace', 'data_type', 'container_cls', kwargs) spec = self.__ns_catalog.get_spec(namespace, data_type) # make sure the spec exists self.__container_types.setdefault(namespace, dict()) self.__container_types[namespace][data_type] = container_cls self.__data_types.setdefault(container_cls, (namespace, data_type)) if not isinstance(container_cls, TypeSource): setattr(container_cls, spec.type_key(), data_type) setattr(container_cls, 'namespace', namespace) @docval({"name": "container_cls", "type": type, "doc": "the AbstractContainer class for which the given ObjectMapper class gets used for"}, {"name": "mapper_cls", "type": type, "doc": "the ObjectMapper class to use to map"}) def register_map(self, **kwargs): ''' Map a container class to an ObjectMapper class ''' container_cls, mapper_cls = getargs('container_cls', 'mapper_cls', kwargs) if self.get_container_cls_dt(container_cls) == (None, None): raise ValueError('cannot register map for type %s - no data_type found' % container_cls) self.__mapper_cls[container_cls] = mapper_cls @docval({"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, {"name": "manager", "type": BuildManager, "doc": "the BuildManager to use for managing this build", 'default': None}, {"name": "source", "type": str, "doc": "the source of container being built i.e. 
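# Illustrative sketch (not from the original manager.py): pairing a container class with a
# custom ObjectMapper. MyContainer, MyContainerMap, 'my-namespace', `type_map`, and
# `my_container_instance` are hypothetical placeholders; the data type must already exist in
# the loaded namespace for register_container_type to succeed.
from hdmf.build import ObjectMapper

class MyContainerMap(ObjectMapper):
    """Placeholder mapper that keeps the default mapping behavior."""

type_map.register_container_type('my-namespace', 'MyContainer', MyContainer)
type_map.register_map(MyContainer, MyContainerMap)
mapper = type_map.get_map(my_container_instance)   # resolves to a MyContainerMap for this type
# End of illustrative sketch.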
file path", 'default': None}, {"name": "builder", "type": BaseBuilder, "doc": "the Builder to build on", 'default': None}, {"name": "spec_ext", "type": BaseStorageSpec, "doc": "a spec extension", 'default': None}, {"name": "export", "type": bool, "doc": "whether this build is for exporting", 'default': False}) def build(self, **kwargs): """Build the GroupBuilder/DatasetBuilder for the given AbstractContainer""" container, manager, builder = getargs('container', 'manager', 'builder', kwargs) source, spec_ext, export = getargs('source', 'spec_ext', 'export', kwargs) # get the ObjectMapper to map between Spec objects and AbstractContainer attributes obj_mapper = self.get_map(container) if obj_mapper is None: raise ValueError('No ObjectMapper found for container of type %s' % str(container.__class__.__name__)) # convert the container to a builder using the ObjectMapper if manager is None: manager = BuildManager(self) builder = obj_mapper.build(container, manager, builder=builder, source=source, spec_ext=spec_ext, export=export) # add additional attributes (namespace, data_type, object_id) to builder namespace, data_type = self.get_container_ns_dt(container) builder.set_attribute('namespace', namespace) builder.set_attribute(self.__type_key(obj_mapper.spec), data_type) builder.set_attribute(obj_mapper.spec.id_key(), container.object_id) return builder @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), 'doc': 'the builder to construct the AbstractContainer from'}, {'name': 'build_manager', 'type': BuildManager, 'doc': 'the BuildManager for constructing', 'default': None}, {'name': 'parent', 'type': (Proxy, Container), 'doc': 'the parent Container/Proxy for the Container being built', 'default': None}) def construct(self, **kwargs): """ Construct the AbstractContainer represented by the given builder """ builder, build_manager, parent = getargs('builder', 'build_manager', 'parent', kwargs) if build_manager is None: build_manager = BuildManager(self) obj_mapper = self.get_map(builder) if obj_mapper is None: dt = builder.attributes[self.namespace_catalog.group_spec_cls.type_key()] raise ValueError('No ObjectMapper found for builder of type %s' % dt) else: return obj_mapper.construct(builder, build_manager, parent) @docval({"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, returns='The name a Builder should be given when building this container', rtype=str) def get_builder_name(self, **kwargs): ''' Get the name a Builder should be given ''' container = getargs('container', kwargs) obj_mapper = self.get_map(container) if obj_mapper is None: raise ValueError('No ObjectMapper found for container of type %s' % str(container.__class__.__name__)) else: return obj_mapper.get_builder_name(container) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/build/map.py0000644000655200065520000000061200000000000016513 0ustar00circlecicircleci# this prevents breaking of code that imports these classes directly from map.py from .manager import Proxy, BuildManager, TypeSource, TypeMap # noqa: F401 from .objectmapper import ObjectMapper # noqa: F401 import warnings warnings.warn('Classes in map.py should be imported from hdmf.build. 
Importing from hdmf.build.map will be removed ' 'in HDMF 3.0.', DeprecationWarning) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/build/objectmapper.py0000644000655200065520000017735500000000000020434 0ustar00circlecicircleciimport logging import re import warnings from collections import OrderedDict from copy import copy from datetime import datetime import numpy as np from .builders import DatasetBuilder, GroupBuilder, LinkBuilder, Builder, ReferenceBuilder, RegionBuilder, BaseBuilder from .errors import (BuildError, OrphanContainerBuildError, ReferenceTargetNotBuiltError, ContainerConfigurationError, ConstructError) from .manager import Proxy, BuildManager from .warnings import MissingRequiredBuildWarning, DtypeConversionWarning, IncorrectQuantityBuildWarning from ..container import AbstractContainer, Data, DataRegion from ..data_utils import DataIO, AbstractDataChunkIterator from ..query import ReferenceResolver from ..spec import Spec, AttributeSpec, DatasetSpec, GroupSpec, LinkSpec, NAME_WILDCARD, RefSpec from ..spec.spec import BaseStorageSpec from ..utils import docval, getargs, ExtenderMeta, get_docval _const_arg = '__constructor_arg' @docval({'name': 'name', 'type': str, 'doc': 'the name of the constructor argument'}, is_method=False) def _constructor_arg(**kwargs): '''Decorator to override the default mapping scheme for a given constructor argument. Decorate ObjectMapper methods with this function when extending ObjectMapper to override the default scheme for mapping between AbstractContainer and Builder objects. The decorated method should accept as its first argument the Builder object that is being mapped. The method should return the value to be passed to the target AbstractContainer class constructor argument given by *name*. ''' name = getargs('name', kwargs) def _dec(func): setattr(func, _const_arg, name) return func return _dec _obj_attr = '__object_attr' @docval({'name': 'name', 'type': str, 'doc': 'the name of the constructor argument'}, is_method=False) def _object_attr(**kwargs): '''Decorator to override the default mapping scheme for a given object attribute. Decorate ObjectMapper methods with this function when extending ObjectMapper to override the default scheme for mapping between AbstractContainer and Builder objects. The decorated method should accept as its first argument the AbstractContainer object that is being mapped. The method should return the child Builder object (or scalar if the object attribute corresponds to an AttributeSpec) that represents the attribute given by *name*. 
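    A minimal usage sketch (hedged; ``MyMapper``, ``MyContainer``, and the
    ``firing_rate`` attribute are hypothetical, and the mapper is assumed to be
    registered for ``MyContainer`` via ``register_map``)::

        class MyMapper(ObjectMapper):

            @ObjectMapper.object_attr('firing_rate')
            def firing_rate_attr(self, container, manager):
                # return the value to be written for the 'firing_rate' attribute spec
                return float(container.firing_rate)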
''' name = getargs('name', kwargs) def _dec(func): setattr(func, _obj_attr, name) return func return _dec def _unicode(s): """ A helper function for converting to Unicode """ if isinstance(s, str): return s elif isinstance(s, bytes): return s.decode('utf-8') else: raise ValueError("Expected unicode or ascii string, got %s" % type(s)) def _ascii(s): """ A helper function for converting to ASCII """ if isinstance(s, str): return s.encode('ascii', 'backslashreplace') elif isinstance(s, bytes): return s else: raise ValueError("Expected unicode or ascii string, got %s" % type(s)) class ObjectMapper(metaclass=ExtenderMeta): '''A class for mapping between Spec objects and AbstractContainer attributes ''' # mapping from spec dtypes to numpy dtypes or functions for conversion of values to spec dtypes # make sure keys are consistent between hdmf.spec.spec.DtypeHelper.primary_dtype_synonyms, # hdmf.build.objectmapper.ObjectMapper.__dtypes, hdmf.build.manager.TypeMap._spec_dtype_map, # hdmf.validate.validator.__allowable, and backend dtype maps __dtypes = { "float": np.float32, "float32": np.float32, "double": np.float64, "float64": np.float64, "long": np.int64, "int64": np.int64, "int": np.int32, "int32": np.int32, "short": np.int16, "int16": np.int16, "int8": np.int8, "uint": np.uint32, "uint64": np.uint64, "uint32": np.uint32, "uint16": np.uint16, "uint8": np.uint8, "bool": np.bool_, "text": _unicode, "utf": _unicode, "utf8": _unicode, "utf-8": _unicode, "ascii": _ascii, "bytes": _ascii, "isodatetime": _ascii, "datetime": _ascii, } __no_convert = set() @classmethod def __resolve_numeric_dtype(cls, given, specified): """ Determine the dtype to use from the dtype of the given value and the specified dtype. This amounts to determining the greater precision of the two arguments, but also checks to make sure the same base dtype is being used. A warning is raised if the base type of the specified dtype differs from the base type of the given dtype and a conversion will result (e.g., float32 -> uint32). """ g = np.dtype(given) s = np.dtype(specified) if g == s: return s.type, None if g.itemsize <= s.itemsize: # given type has precision < precision of specified type # note: this allows float32 -> int32, bool -> int8, int16 -> uint16 which may involve buffer overflows, # truncated values, and other unexpected consequences. warning_msg = ('Value with data type %s is being converted to data type %s as specified.' % (g.name, s.name)) return s.type, warning_msg elif g.name[:3] == s.name[:3]: return g.type, None # same base type, use higher-precision given type else: if np.issubdtype(s, np.unsignedinteger): # e.g.: given int64 and spec uint32, return uint64. given float32 and spec uint8, return uint32. ret_type = np.dtype('uint' + str(int(g.itemsize * 8))) warning_msg = ('Value with data type %s is being converted to data type %s (min specification: %s).' % (g.name, ret_type.name, s.name)) return ret_type.type, warning_msg if np.issubdtype(s, np.floating): # e.g.: given int64 and spec float32, return float64. given uint64 and spec float32, return float32. ret_type = np.dtype('float' + str(max(int(g.itemsize * 8), 32))) warning_msg = ('Value with data type %s is being converted to data type %s (min specification: %s).' % (g.name, ret_type.name, s.name)) return ret_type.type, warning_msg if np.issubdtype(s, np.integer): # e.g.: given float64 and spec int8, return int64. given uint32 and spec int8, return int32. 
ret_type = np.dtype('int' + str(int(g.itemsize * 8))) warning_msg = ('Value with data type %s is being converted to data type %s (min specification: %s).' % (g.name, ret_type.name, s.name)) return ret_type.type, warning_msg if s.type is np.bool_: msg = "expected %s, received %s - must supply %s" % (s.name, g.name, s.name) raise ValueError(msg) # all numeric types in __dtypes should be caught by the above raise ValueError('Unsupported conversion to specification data type: %s' % s.name) @classmethod def no_convert(cls, obj_type): """ Specify an object type that ObjectMappers should not convert. """ cls.__no_convert.add(obj_type) @classmethod # noqa: C901 def convert_dtype(cls, spec, value, spec_dtype=None): # noqa: C901 """ Convert values to the specified dtype. For example, if a literal int is passed in to a field that is specified as a unsigned integer, this function will convert the Python int to a numpy unsigned int. :param spec: The DatasetSpec or AttributeSpec to which this value is being applied :param value: The value being converted to the spec dtype :param spec_dtype: Optional override of the dtype in spec.dtype. Used to specify the parent dtype when the given extended spec lacks a dtype. :return: The function returns a tuple consisting of 1) the value, and 2) the data type. The value is returned as the function may convert the input value to comply with the dtype specified in the schema. """ if spec_dtype is None: spec_dtype = spec.dtype ret, ret_dtype = cls.__check_edgecases(spec, value, spec_dtype) if ret is not None or ret_dtype is not None: return ret, ret_dtype # spec_dtype is a string, spec_dtype_type is a type or the conversion helper functions _unicode or _ascii spec_dtype_type = cls.__dtypes[spec_dtype] warning_msg = None if isinstance(value, np.ndarray): if spec_dtype_type is _unicode: ret = value.astype('U') ret_dtype = "utf8" elif spec_dtype_type is _ascii: ret = value.astype('S') ret_dtype = "ascii" else: dtype_func, warning_msg = cls.__resolve_numeric_dtype(value.dtype, spec_dtype_type) if value.dtype == dtype_func: ret = value else: ret = value.astype(dtype_func) ret_dtype = ret.dtype.type elif isinstance(value, (tuple, list)): if len(value) == 0: if spec_dtype_type is _unicode: ret_dtype = 'utf8' elif spec_dtype_type is _ascii: ret_dtype = 'ascii' else: ret_dtype = spec_dtype_type return value, ret_dtype ret = list() for elem in value: tmp, tmp_dtype = cls.convert_dtype(spec, elem, spec_dtype) ret.append(tmp) ret = type(value)(ret) ret_dtype = tmp_dtype elif isinstance(value, AbstractDataChunkIterator): ret = value if spec_dtype_type is _unicode: ret_dtype = "utf8" elif spec_dtype_type is _ascii: ret_dtype = "ascii" else: ret_dtype, warning_msg = cls.__resolve_numeric_dtype(value.dtype, spec_dtype_type) else: if spec_dtype_type in (_unicode, _ascii): ret_dtype = 'ascii' if spec_dtype_type is _unicode: ret_dtype = 'utf8' ret = spec_dtype_type(value) else: dtype_func, warning_msg = cls.__resolve_numeric_dtype(type(value), spec_dtype_type) ret = dtype_func(value) ret_dtype = type(ret) if warning_msg: full_warning_msg = "Spec '%s': %s" % (spec.path, warning_msg) warnings.warn(full_warning_msg, DtypeConversionWarning) return ret, ret_dtype @classmethod def __check_convert_numeric(cls, value_type): # dtype 'numeric' allows only ints, floats, and uints value_dtype = np.dtype(value_type) if not (np.issubdtype(value_dtype, np.unsignedinteger) or np.issubdtype(value_dtype, np.floating) or np.issubdtype(value_dtype, np.integer)): raise ValueError("Cannot convert from %s to 
'numeric' specification dtype." % value_type) @classmethod # noqa: C901 def __check_edgecases(cls, spec, value, spec_dtype): # noqa: C901 """ Check edge cases in converting data to a dtype """ if value is None: dt = spec_dtype if isinstance(dt, RefSpec): dt = dt.reftype return None, dt if isinstance(spec_dtype, list): # compound dtype - Since the I/O layer needs to determine how to handle these, # return the list of DtypeSpecs return value, spec_dtype if isinstance(value, DataIO): return value, cls.convert_dtype(spec, value.data, spec_dtype)[1] if spec_dtype is None or spec_dtype == 'numeric' or type(value) in cls.__no_convert: # infer type from value if hasattr(value, 'dtype'): # covers numpy types, AbstractDataChunkIterator if spec_dtype == 'numeric': cls.__check_convert_numeric(value.dtype.type) if np.issubdtype(value.dtype, np.str_): ret_dtype = 'utf8' elif np.issubdtype(value.dtype, np.string_): ret_dtype = 'ascii' else: ret_dtype = value.dtype.type return value, ret_dtype if isinstance(value, (list, tuple)): if len(value) == 0: msg = "Cannot infer dtype of empty list or tuple. Please use numpy array with specified dtype." raise ValueError(msg) return value, cls.__check_edgecases(spec, value[0], spec_dtype)[1] # infer dtype from first element ret_dtype = type(value) if spec_dtype == 'numeric': cls.__check_convert_numeric(ret_dtype) if ret_dtype is str: ret_dtype = 'utf8' elif ret_dtype is bytes: ret_dtype = 'ascii' return value, ret_dtype if isinstance(spec_dtype, RefSpec): if not isinstance(value, ReferenceBuilder): msg = "got RefSpec for value of type %s" % type(value) raise ValueError(msg) return value, spec_dtype if spec_dtype is not None and spec_dtype not in cls.__dtypes: # pragma: no cover msg = "unrecognized dtype: %s -- cannot convert value" % spec_dtype raise ValueError(msg) return None, None _const_arg = '__constructor_arg' @staticmethod @docval({'name': 'name', 'type': str, 'doc': 'the name of the constructor argument'}, is_method=False) def constructor_arg(**kwargs): '''Decorator to override the default mapping scheme for a given constructor argument. Decorate ObjectMapper methods with this function when extending ObjectMapper to override the default scheme for mapping between AbstractContainer and Builder objects. The decorated method should accept as its first argument the Builder object that is being mapped. The method should return the value to be passed to the target AbstractContainer class constructor argument given by *name*. ''' name = getargs('name', kwargs) return _constructor_arg(name) _obj_attr = '__object_attr' @staticmethod @docval({'name': 'name', 'type': str, 'doc': 'the name of the constructor argument'}, is_method=False) def object_attr(**kwargs): '''Decorator to override the default mapping scheme for a given object attribute. Decorate ObjectMapper methods with this function when extending ObjectMapper to override the default scheme for mapping between AbstractContainer and Builder objects. The decorated method should accept as its first argument the AbstractContainer object that is being mapped. The method should return the child Builder object (or scalar if the object attribute corresponds to an AttributeSpec) that represents the attribute given by *name*. 
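    The read-direction counterpart uses ``constructor_arg`` in the same way
    (hedged sketch; ``MyMapper`` and the ``comments`` constructor argument are
    hypothetical)::

        class MyMapper(ObjectMapper):

            @ObjectMapper.constructor_arg('comments')
            def carg_comments(self, builder, manager):
                # value passed to the container constructor's 'comments' argument
                # when constructing from a builder; fall back to a default if the
                # attribute was never written
                return builder.attributes.get('comments', 'no comments')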
''' name = getargs('name', kwargs) return _object_attr(name) @staticmethod def __is_attr(attr_val): return hasattr(attr_val, _obj_attr) @staticmethod def __get_obj_attr(attr_val): return getattr(attr_val, _obj_attr) @staticmethod def __is_constructor_arg(attr_val): return hasattr(attr_val, _const_arg) @staticmethod def __get_cargname(attr_val): return getattr(attr_val, _const_arg) @ExtenderMeta.post_init def __gather_procedures(cls, name, bases, classdict): if hasattr(cls, 'constructor_args'): cls.constructor_args = copy(cls.constructor_args) else: cls.constructor_args = dict() if hasattr(cls, 'obj_attrs'): cls.obj_attrs = copy(cls.obj_attrs) else: cls.obj_attrs = dict() for name, func in cls.__dict__.items(): if cls.__is_constructor_arg(func): cls.constructor_args[cls.__get_cargname(func)] = getattr(cls, name) elif cls.__is_attr(func): cls.obj_attrs[cls.__get_obj_attr(func)] = getattr(cls, name) @docval({'name': 'spec', 'type': (DatasetSpec, GroupSpec), 'doc': 'The specification for mapping objects to builders'}) def __init__(self, **kwargs): """ Create a map from AbstractContainer attributes to specifications """ self.logger = logging.getLogger('%s.%s' % (self.__class__.__module__, self.__class__.__qualname__)) spec = getargs('spec', kwargs) self.__spec = spec self.__data_type_key = spec.type_key() self.__spec2attr = dict() self.__attr2spec = dict() self.__spec2carg = dict() self.__carg2spec = dict() self.__map_spec(spec) @property def spec(self): ''' the Spec used in this ObjectMapper ''' return self.__spec @_constructor_arg('name') def get_container_name(self, *args): builder = args[0] return builder.name @classmethod @docval({'name': 'spec', 'type': Spec, 'doc': 'the specification to get the name for'}) def convert_dt_name(cls, **kwargs): '''Construct the attribute name corresponding to a specification''' spec = getargs('spec', kwargs) name = cls.__get_data_type(spec) s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() if name[-1] != 's' and spec.is_many(): name += 's' return name @classmethod def __get_fields(cls, name_stack, all_names, spec): name = spec.name if spec.name is None: name = cls.convert_dt_name(spec) name_stack.append(name) name = '__'.join(name_stack) # TODO address potential name clashes, e.g., quantity '*' subgroups and links of same data_type_inc will # have the same name all_names[name] = spec if isinstance(spec, BaseStorageSpec): if not (spec.data_type_def is None and spec.data_type_inc is None): # don't get names for components in data_types name_stack.pop() return for subspec in spec.attributes: cls.__get_fields(name_stack, all_names, subspec) if isinstance(spec, GroupSpec): for subspec in spec.datasets: cls.__get_fields(name_stack, all_names, subspec) for subspec in spec.groups: cls.__get_fields(name_stack, all_names, subspec) for subspec in spec.links: cls.__get_fields(name_stack, all_names, subspec) name_stack.pop() @classmethod @docval({'name': 'spec', 'type': Spec, 'doc': 'the specification to get the object attribute names for'}) def get_attr_names(cls, **kwargs): '''Get the attribute names for each subspecification in a Spec''' spec = getargs('spec', kwargs) names = OrderedDict() for subspec in spec.attributes: cls.__get_fields(list(), names, subspec) if isinstance(spec, GroupSpec): for subspec in spec.groups: cls.__get_fields(list(), names, subspec) for subspec in spec.datasets: cls.__get_fields(list(), names, subspec) for subspec in spec.links: cls.__get_fields(list(), names, subspec) return names def 
__map_spec(self, spec): attr_names = self.get_attr_names(spec) for k, v in attr_names.items(): self.map_spec(k, v) @docval({"name": "attr_name", "type": str, "doc": "the name of the object to map"}, {"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}) def map_attr(self, **kwargs): """ Map an attribute to spec. Use this to override default behavior """ attr_name, spec = getargs('attr_name', 'spec', kwargs) self.__spec2attr[spec] = attr_name self.__attr2spec[attr_name] = spec @docval({"name": "attr_name", "type": str, "doc": "the name of the attribute"}) def get_attr_spec(self, **kwargs): """ Return the Spec for a given attribute """ attr_name = getargs('attr_name', kwargs) return self.__attr2spec.get(attr_name) @docval({"name": "carg_name", "type": str, "doc": "the name of the constructor argument"}) def get_carg_spec(self, **kwargs): """ Return the Spec for a given constructor argument """ carg_name = getargs('carg_name', kwargs) return self.__carg2spec.get(carg_name) @docval({"name": "const_arg", "type": str, "doc": "the name of the constructor argument to map"}, {"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}) def map_const_arg(self, **kwargs): """ Map an attribute to spec. Use this to override default behavior """ const_arg, spec = getargs('const_arg', 'spec', kwargs) self.__spec2carg[spec] = const_arg self.__carg2spec[const_arg] = spec @docval({"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}) def unmap(self, **kwargs): """ Removing any mapping for a specification. Use this to override default mapping """ spec = getargs('spec', kwargs) self.__spec2attr.pop(spec, None) self.__spec2carg.pop(spec, None) @docval({"name": "attr_carg", "type": str, "doc": "the constructor argument/object attribute to map this spec to"}, {"name": "spec", "type": Spec, "doc": "the spec to map the attribute to"}) def map_spec(self, **kwargs): """ Map the given specification to the construct argument and object attribute """ spec, attr_carg = getargs('spec', 'attr_carg', kwargs) self.map_const_arg(attr_carg, spec) self.map_attr(attr_carg, spec) def __get_override_carg(self, *args): name = args[0] remaining_args = tuple(args[1:]) if name in self.constructor_args: self.logger.debug(" Calling override function for constructor argument '%s'" % name) func = self.constructor_args[name] return func(self, *remaining_args) return None def __get_override_attr(self, name, container, manager): if name in self.obj_attrs: self.logger.debug(" Calling override function for attribute '%s'" % name) func = self.obj_attrs[name] return func(self, container, manager) return None @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute for"}, returns='the attribute name', rtype=str) def get_attribute(self, **kwargs): ''' Get the object attribute name for the given Spec ''' spec = getargs('spec', kwargs) val = self.__spec2attr.get(spec, None) return val @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, {"name": "container", "type": AbstractContainer, "doc": "the container to get the attribute value from"}, {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, returns='the value of the attribute') def get_attr_value(self, **kwargs): ''' Get the value of the attribute corresponding to this spec from the given container ''' spec, container, manager = getargs('spec', 'container', 'manager', kwargs) attr_name = self.get_attribute(spec) if attr_name is None: return 
None attr_val = self.__get_override_attr(attr_name, container, manager) if attr_val is None: try: attr_val = getattr(container, attr_name) except AttributeError: msg = ("%s '%s' does not have attribute '%s' for mapping to spec: %s" % (container.__class__.__name__, container.name, attr_name, spec)) raise ContainerConfigurationError(msg) if attr_val is not None: attr_val = self.__convert_string(attr_val, spec) spec_dt = self.__get_data_type(spec) if spec_dt is not None: try: attr_val = self.__filter_by_spec_dt(attr_val, spec_dt, manager) except ValueError as e: msg = ("%s '%s' attribute '%s' has unexpected type." % (container.__class__.__name__, container.name, attr_name)) raise ContainerConfigurationError(msg) from e # else: attr_val is an attribute on the Container and its value is None # attr_val can be None, an AbstractContainer, or a list of AbstractContainers return attr_val @classmethod def __get_data_type(cls, spec): ret = None if isinstance(spec, LinkSpec): ret = spec.target_type elif isinstance(spec, BaseStorageSpec): if spec.data_type_def is not None: ret = spec.data_type_def elif spec.data_type_inc is not None: ret = spec.data_type_inc # else, untyped group/dataset spec # else, attribute spec return ret def __convert_string(self, value, spec): """Convert string types to the specified dtype.""" ret = value if isinstance(spec, AttributeSpec): if 'text' in spec.dtype: if spec.shape is not None or spec.dims is not None: ret = list(map(str, value)) else: ret = str(value) elif isinstance(spec, DatasetSpec): # TODO: make sure we can handle specs with data_type_inc set if spec.data_type_inc is None and spec.dtype is not None: string_type = None if 'text' in spec.dtype: string_type = str elif 'ascii' in spec.dtype: string_type = bytes elif 'isodatetime' in spec.dtype: string_type = datetime.isoformat if string_type is not None: if spec.shape is not None or spec.dims is not None: ret = list(map(string_type, value)) else: ret = string_type(value) # copy over any I/O parameters if they were specified if isinstance(value, DataIO): params = value.get_io_params() params['data'] = ret ret = value.__class__(**params) return ret def __filter_by_spec_dt(self, attr_value, spec_dt, build_manager): """Return a list of containers that match the spec data type. If attr_value is a container that does not match the spec data type, then None is returned. If attr_value is a collection, then a list of only the containers in the collection that match the spec data type are returned. Otherwise, attr_value is returned unchanged. spec_dt is a string representing a spec data type. Return None, an AbstractContainer, or a list of AbstractContainers """ if isinstance(attr_value, AbstractContainer): if build_manager.is_sub_data_type(attr_value, spec_dt): return attr_value else: return None ret = attr_value if isinstance(attr_value, (list, tuple, set, dict)): if isinstance(attr_value, dict): attr_values = attr_value.values() else: attr_values = attr_value ret = [] # NOTE: this will test collections of non-containers element-wise (e.g. lists of lists of ints) for c in attr_values: if self.__filter_by_spec_dt(c, spec_dt, build_manager) is not None: ret.append(c) if len(ret) == 0: ret = None else: raise ValueError("Unexpected type for attr_value: %s. Only AbstractContainer, list, tuple, set, dict, are " "allowed." 
% type(attr_value)) return ret def __check_quantity(self, attr_value, spec, container): if attr_value is None and spec.required: attr_name = self.get_attribute(spec) msg = ("%s '%s' is missing required value for attribute '%s'." % (container.__class__.__name__, container.name, attr_name)) warnings.warn(msg, MissingRequiredBuildWarning) self.logger.debug('MissingRequiredBuildWarning: ' + msg) elif attr_value is not None and self.__get_data_type(spec) is not None: # quantity is valid only for specs with a data type or target type if isinstance(attr_value, AbstractContainer): attr_value = [attr_value] n = len(attr_value) if (n and isinstance(attr_value[0], AbstractContainer) and ((n > 1 and not spec.is_many()) or (isinstance(spec.quantity, int) and n != spec.quantity))): attr_name = self.get_attribute(spec) msg = ("%s '%s' has %d values for attribute '%s' but spec allows %s." % (container.__class__.__name__, container.name, n, attr_name, repr(spec.quantity))) warnings.warn(msg, IncorrectQuantityBuildWarning) self.logger.debug('IncorrectQuantityBuildWarning: ' + msg) @docval({"name": "spec", "type": Spec, "doc": "the spec to get the constructor argument for"}, returns="the name of the constructor argument", rtype=str) def get_const_arg(self, **kwargs): ''' Get the constructor argument for the given Spec ''' spec = getargs('spec', kwargs) return self.__spec2carg.get(spec, None) @docval({"name": "container", "type": AbstractContainer, "doc": "the container to convert to a Builder"}, {"name": "manager", "type": BuildManager, "doc": "the BuildManager to use for managing this build"}, {"name": "parent", "type": GroupBuilder, "doc": "the parent of the resulting Builder", 'default': None}, {"name": "source", "type": str, "doc": "the source of container being built i.e. file path", 'default': None}, {"name": "builder", "type": BaseBuilder, "doc": "the Builder to build on", 'default': None}, {"name": "spec_ext", "type": BaseStorageSpec, "doc": "a spec extension", 'default': None}, {"name": "export", "type": bool, "doc": "whether this build is for exporting", 'default': False}, returns="the Builder representing the given AbstractContainer", rtype=Builder) def build(self, **kwargs): '''Convert an AbstractContainer to a Builder representation. References are not added but are queued to be added in the BuildManager. ''' container, manager, parent, source = getargs('container', 'manager', 'parent', 'source', kwargs) builder, spec_ext, export = getargs('builder', 'spec_ext', 'export', kwargs) name = manager.get_builder_name(container) if isinstance(self.__spec, GroupSpec): self.logger.debug("Building %s '%s' as a group (source: %s)" % (container.__class__.__name__, container.name, repr(source))) if builder is None: builder = GroupBuilder(name, parent=parent, source=source) self.__add_datasets(builder, self.__spec.datasets, container, manager, source, export) self.__add_groups(builder, self.__spec.groups, container, manager, source, export) self.__add_links(builder, self.__spec.links, container, manager, source, export) else: if builder is None: if not isinstance(container, Data): msg = "'container' must be of type Data with DatasetSpec" raise ValueError(msg) spec_dtype, spec_shape, spec = self.__check_dset_spec(self.spec, spec_ext) if isinstance(spec_dtype, RefSpec): self.logger.debug("Building %s '%s' as a dataset of references (source: %s)" % (container.__class__.__name__, container.name, repr(source))) # create dataset builder with data=None as a placeholder. 
fill in with refs later builder = DatasetBuilder(name, data=None, parent=parent, source=source, dtype=spec_dtype.reftype) manager.queue_ref(self.__set_dataset_to_refs(builder, spec_dtype, spec_shape, container, manager)) elif isinstance(spec_dtype, list): # a compound dataset self.logger.debug("Building %s '%s' as a dataset of compound dtypes (source: %s)" % (container.__class__.__name__, container.name, repr(source))) # create dataset builder with data=None, dtype=None as a placeholder. fill in with refs later builder = DatasetBuilder(name, data=None, parent=parent, source=source, dtype=spec_dtype) manager.queue_ref(self.__set_compound_dataset_to_refs(builder, spec, spec_dtype, container, manager)) else: # a regular dtype if spec_dtype is None and self.__is_reftype(container.data): self.logger.debug("Building %s '%s' containing references as a dataset of unspecified dtype " "(source: %s)" % (container.__class__.__name__, container.name, repr(source))) # an unspecified dtype and we were given references # create dataset builder with data=None as a placeholder. fill in with refs later builder = DatasetBuilder(name, data=None, parent=parent, source=source, dtype='object') manager.queue_ref(self.__set_untyped_dataset_to_refs(builder, container, manager)) else: # a dataset that has no references, pass the conversion off to the convert_dtype method self.logger.debug("Building %s '%s' as a dataset (source: %s)" % (container.__class__.__name__, container.name, repr(source))) try: # use spec_dtype from self.spec when spec_ext does not specify dtype bldr_data, dtype = self.convert_dtype(spec, container.data, spec_dtype=spec_dtype) except Exception as ex: msg = 'could not resolve dtype for %s \'%s\'' % (type(container).__name__, container.name) raise Exception(msg) from ex builder = DatasetBuilder(name, bldr_data, parent=parent, source=source, dtype=dtype) # Add attributes from the specification extension to the list of attributes all_attrs = self.__spec.attributes + getattr(spec_ext, 'attributes', tuple()) # If the spec_ext refines an existing attribute it will now appear twice in the list. The # refinement should only be relevant for validation (not for write). To avoid problems with the # write we here remove duplicates and keep the original spec of the two to make write work. 
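        # A worked illustration of the de-duplication idiom used below (hedged;
        # ``S`` is a namedtuple stand-in for AttributeSpec, not the real class):
        #     >>> from collections import namedtuple
        #     >>> S = namedtuple('S', 'name doc')
        #     >>> original = S('unit', 'from the base spec')
        #     >>> refinement = S('unit', 'from spec_ext')
        #     >>> all_attrs = [original, refinement]
        #     >>> list({a.name: a for a in all_attrs[::-1]}.values()) == [original]
        #     True
        # Iterating the reversed list lets the original spec's attribute overwrite the
        # same-named refinement in the dict, so the original definition is what gets written.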
# TODO: We should add validation in the AttributeSpec to make sure refinements are valid # TODO: Check the BuildManager as refinements should probably be resolved rather than be passed in via spec_ext all_attrs = list({a.name: a for a in all_attrs[::-1]}.values()) self.__add_attributes(builder, all_attrs, container, manager, source, export) return builder def __check_dset_spec(self, orig, ext): """ Check a dataset spec against a refining spec to see which dtype and shape should be used """ dtype = orig.dtype shape = orig.shape spec = orig if ext is not None: if ext.dtype is not None: dtype = ext.dtype if ext.shape is not None: shape = ext.shape spec = ext return dtype, shape, spec def __is_reftype(self, data): if (isinstance(data, AbstractDataChunkIterator) or (isinstance(data, DataIO) and isinstance(data.data, AbstractDataChunkIterator))): return False tmp = data while hasattr(tmp, '__len__') and not isinstance(tmp, (AbstractContainer, str, bytes)): tmptmp = None for t in tmp: # In case of a numeric array stop the iteration at the first element to avoid long-running loop if isinstance(t, (int, float, complex, bool)): break if hasattr(t, '__len__') and len(t) > 0 and not isinstance(t, (AbstractContainer, str, bytes)): tmptmp = tmp[0] break if tmptmp is not None: break else: if len(tmp) == 0: tmp = None else: tmp = tmp[0] if isinstance(tmp, AbstractContainer): return True else: return False def __set_dataset_to_refs(self, builder, dtype, shape, container, build_manager): self.logger.debug("Queueing set dataset of references %s '%s' to reference builder(s)" % (builder.__class__.__name__, builder.name)) def _filler(): builder.data = self.__get_ref_builder(builder, dtype, shape, container, build_manager) return _filler def __set_compound_dataset_to_refs(self, builder, spec, spec_dtype, container, build_manager): self.logger.debug("Queueing convert compound dataset %s '%s' and set any references to reference builders" % (builder.__class__.__name__, builder.name)) def _filler(): self.logger.debug("Converting compound dataset %s '%s' and setting any references to reference builders" % (builder.__class__.__name__, builder.name)) # convert the reference part(s) of a compound dataset to ReferenceBuilders, row by row refs = [(i, subt) for i, subt in enumerate(spec_dtype) if isinstance(subt.dtype, RefSpec)] bldr_data = list() for i, row in enumerate(container.data): tmp = list(row) for j, subt in refs: tmp[j] = self.__get_ref_builder(builder, subt.dtype, None, row[j], build_manager) bldr_data.append(tuple(tmp)) builder.data = bldr_data return _filler def __set_untyped_dataset_to_refs(self, builder, container, build_manager): self.logger.debug("Queueing set untyped dataset %s '%s' to reference builders" % (builder.__class__.__name__, builder.name)) def _filler(): self.logger.debug("Setting untyped dataset %s '%s' to list of reference builders" % (builder.__class__.__name__, builder.name)) bldr_data = list() for d in container.data: if d is None: bldr_data.append(None) else: target_builder = self.__get_target_builder(d, build_manager, builder) bldr_data.append(ReferenceBuilder(target_builder)) builder.data = bldr_data return _filler def __get_ref_builder(self, builder, dtype, shape, container, build_manager): bldr_data = None if dtype.is_region(): if shape is None: if not isinstance(container, DataRegion): msg = "'container' must be of type DataRegion if spec represents region reference" raise ValueError(msg) self.logger.debug("Setting %s '%s' data to region reference builder" % 
(builder.__class__.__name__, builder.name)) target_builder = self.__get_target_builder(container.data, build_manager, builder) bldr_data = RegionBuilder(container.region, target_builder) else: self.logger.debug("Setting %s '%s' data to list of region reference builders" % (builder.__class__.__name__, builder.name)) bldr_data = list() for d in container.data: target_builder = self.__get_target_builder(d.target, build_manager, builder) bldr_data.append(RegionBuilder(d.slice, target_builder)) else: self.logger.debug("Setting object reference dataset on %s '%s' data" % (builder.__class__.__name__, builder.name)) if isinstance(container, Data): self.logger.debug("Setting %s '%s' data to list of reference builders" % (builder.__class__.__name__, builder.name)) bldr_data = list() for d in container.data: target_builder = self.__get_target_builder(d, build_manager, builder) bldr_data.append(ReferenceBuilder(target_builder)) else: self.logger.debug("Setting %s '%s' data to reference builder" % (builder.__class__.__name__, builder.name)) target_builder = self.__get_target_builder(container, build_manager, builder) bldr_data = ReferenceBuilder(target_builder) return bldr_data def __get_target_builder(self, container, build_manager, builder): target_builder = build_manager.get_builder(container) if target_builder is None: raise ReferenceTargetNotBuiltError(builder, container) return target_builder def __add_attributes(self, builder, attributes, container, build_manager, source, export): if attributes: self.logger.debug("Adding attributes from %s '%s' to %s '%s'" % (container.__class__.__name__, container.name, builder.__class__.__name__, builder.name)) for spec in attributes: self.logger.debug(" Adding attribute for spec name: %s (dtype: %s)" % (repr(spec.name), spec.dtype.__class__.__name__)) if spec.value is not None: attr_value = spec.value else: attr_value = self.get_attr_value(spec, container, build_manager) if attr_value is None: attr_value = spec.default_value attr_value = self.__check_ref_resolver(attr_value) self.__check_quantity(attr_value, spec, container) if attr_value is None: self.logger.debug(" Skipping empty attribute") continue if isinstance(spec.dtype, RefSpec): if not self.__is_reftype(attr_value): msg = ("invalid type for reference '%s' (%s) - must be AbstractContainer" % (spec.name, type(attr_value))) raise ValueError(msg) build_manager.queue_ref(self.__set_attr_to_ref(builder, attr_value, build_manager, spec)) continue else: try: attr_value, attr_dtype = self.convert_dtype(spec, attr_value) except Exception as ex: msg = 'could not convert %s for %s %s' % (spec.name, type(container).__name__, container.name) raise BuildError(builder, msg) from ex # do not write empty or null valued objects self.__check_quantity(attr_value, spec, container) if attr_value is None: self.logger.debug(" Skipping empty attribute") continue builder.set_attribute(spec.name, attr_value) def __set_attr_to_ref(self, builder, attr_value, build_manager, spec): self.logger.debug("Queueing set reference attribute on %s '%s' attribute '%s' to %s" % (builder.__class__.__name__, builder.name, spec.name, attr_value.__class__.__name__)) def _filler(): self.logger.debug("Setting reference attribute on %s '%s' attribute '%s' to %s" % (builder.__class__.__name__, builder.name, spec.name, attr_value.__class__.__name__)) target_builder = self.__get_target_builder(attr_value, build_manager, builder) ref_attr_value = ReferenceBuilder(target_builder) builder.set_attribute(spec.name, ref_attr_value) return _filler def 
__add_links(self, builder, links, container, build_manager, source, export): if links: self.logger.debug("Adding links from %s '%s' to %s '%s'" % (container.__class__.__name__, container.name, builder.__class__.__name__, builder.name)) for spec in links: self.logger.debug(" Adding link for spec name: %s, target_type: %s" % (repr(spec.name), repr(spec.target_type))) attr_value = self.get_attr_value(spec, container, build_manager) self.__check_quantity(attr_value, spec, container) if attr_value is None: self.logger.debug(" Skipping link - no attribute value") continue self.__add_containers(builder, spec, attr_value, build_manager, source, container, export) def __add_datasets(self, builder, datasets, container, build_manager, source, export): if datasets: self.logger.debug("Adding datasets from %s '%s' to %s '%s'" % (container.__class__.__name__, container.name, builder.__class__.__name__, builder.name)) for spec in datasets: self.logger.debug(" Adding dataset for spec name: %s (dtype: %s)" % (repr(spec.name), spec.dtype.__class__.__name__)) attr_value = self.get_attr_value(spec, container, build_manager) self.__check_quantity(attr_value, spec, container) if attr_value is None: self.logger.debug(" Skipping dataset - no attribute value") continue attr_value = self.__check_ref_resolver(attr_value) if isinstance(attr_value, DataIO) and attr_value.data is None: self.logger.debug(" Skipping dataset - attribute is dataio or has no data") continue if isinstance(attr_value, LinkBuilder): self.logger.debug(" Adding %s '%s' for spec name: %s, %s: %s, %s: %s" % (attr_value.name, attr_value.__class__.__name__, repr(spec.name), spec.def_key(), repr(spec.data_type_def), spec.inc_key(), repr(spec.data_type_inc))) builder.set_link(attr_value) # add the existing builder elif spec.data_type_def is None and spec.data_type_inc is None: # untyped, named dataset if spec.name in builder.datasets: sub_builder = builder.datasets[spec.name] self.logger.debug(" Retrieving existing DatasetBuilder '%s' for spec name %s and adding " "attributes" % (sub_builder.name, repr(spec.name))) else: self.logger.debug(" Converting untyped dataset for spec name %s to spec dtype %s" % (repr(spec.name), repr(spec.dtype))) try: data, dtype = self.convert_dtype(spec, attr_value) except Exception as ex: msg = 'could not convert \'%s\' for %s \'%s\'' msg = msg % (spec.name, type(container).__name__, container.name) raise BuildError(builder, msg) from ex self.logger.debug(" Adding untyped dataset for spec name %s and adding attributes" % repr(spec.name)) sub_builder = DatasetBuilder(spec.name, data, parent=builder, source=source, dtype=dtype) builder.set_dataset(sub_builder) self.__add_attributes(sub_builder, spec.attributes, container, build_manager, source, export) else: self.logger.debug(" Adding typed dataset for spec name: %s, %s: %s, %s: %s" % (repr(spec.name), spec.def_key(), repr(spec.data_type_def), spec.inc_key(), repr(spec.data_type_inc))) self.__add_containers(builder, spec, attr_value, build_manager, source, container, export) def __add_groups(self, builder, groups, container, build_manager, source, export): if groups: self.logger.debug("Adding groups from %s '%s' to %s '%s'" % (container.__class__.__name__, container.name, builder.__class__.__name__, builder.name)) for spec in groups: if spec.data_type_def is None and spec.data_type_inc is None: self.logger.debug(" Adding untyped group for spec name: %s" % repr(spec.name)) # we don't need to get attr_name since any named group does not have the concept of value sub_builder = 
builder.groups.get(spec.name) if sub_builder is None: sub_builder = GroupBuilder(spec.name, source=source) self.__add_attributes(sub_builder, spec.attributes, container, build_manager, source, export) self.__add_datasets(sub_builder, spec.datasets, container, build_manager, source, export) self.__add_links(sub_builder, spec.links, container, build_manager, source, export) self.__add_groups(sub_builder, spec.groups, container, build_manager, source, export) empty = sub_builder.is_empty() if not empty or (empty and spec.required): if sub_builder.name not in builder.groups: builder.set_group(sub_builder) else: self.logger.debug(" Adding group for spec name: %s, %s: %s, %s: %s" % (repr(spec.name), spec.def_key(), repr(spec.data_type_def), spec.inc_key(), repr(spec.data_type_inc))) attr_value = self.get_attr_value(spec, container, build_manager) self.__check_quantity(attr_value, spec, container) if attr_value is not None: self.__add_containers(builder, spec, attr_value, build_manager, source, container, export) def __add_containers(self, builder, spec, value, build_manager, source, parent_container, export): if isinstance(value, AbstractContainer): self.logger.debug(" Adding container %s '%s' with parent %s '%s' to %s '%s'" % (value.__class__.__name__, value.name, parent_container.__class__.__name__, parent_container.name, builder.__class__.__name__, builder.name)) if value.parent is None: if (value.container_source == parent_container.container_source or build_manager.get_builder(value) is None): # value was removed (or parent not set) and there is a link to it in same file # or value was read from an external link raise OrphanContainerBuildError(builder, value) if value.modified or export: # writing a newly instantiated container (modified is False only after read) or as if it is newly # instantianted (export=True) self.logger.debug(" Building newly instantiated %s '%s'" % (value.__class__.__name__, value.name)) if isinstance(spec, BaseStorageSpec): new_builder = build_manager.build(value, source=source, spec_ext=spec, export=export) else: new_builder = build_manager.build(value, source=source, export=export) # use spec to determine what kind of HDF5 object this AbstractContainer corresponds to if isinstance(spec, LinkSpec) or value.parent is not parent_container: self.logger.debug(" Adding link to %s '%s' in %s '%s'" % (new_builder.__class__.__name__, new_builder.name, builder.__class__.__name__, builder.name)) builder.set_link(LinkBuilder(new_builder, name=spec.name, parent=builder)) elif isinstance(spec, DatasetSpec): self.logger.debug(" Adding dataset %s '%s' to %s '%s'" % (new_builder.__class__.__name__, new_builder.name, builder.__class__.__name__, builder.name)) builder.set_dataset(new_builder) else: self.logger.debug(" Adding subgroup %s '%s' to %s '%s'" % (new_builder.__class__.__name__, new_builder.name, builder.__class__.__name__, builder.name)) builder.set_group(new_builder) elif value.container_source: # make a link to an existing container if (value.container_source != parent_container.container_source or value.parent is not parent_container): self.logger.debug(" Building %s '%s' (container source: %s) and adding a link to it" % (value.__class__.__name__, value.name, value.container_source)) if isinstance(spec, BaseStorageSpec): new_builder = build_manager.build(value, source=source, spec_ext=spec, export=export) else: new_builder = build_manager.build(value, source=source, export=export) builder.set_link(LinkBuilder(new_builder, name=spec.name, parent=builder)) else: 
self.logger.debug(" Skipping build for %s '%s' because both it and its parents were read " "from the same source." % (value.__class__.__name__, value.name)) else: raise ValueError("Found unmodified AbstractContainer with no source - '%s' with parent '%s'" % (value.name, parent_container.name)) elif isinstance(value, list): for container in value: self.__add_containers(builder, spec, container, build_manager, source, parent_container, export) else: # pragma: no cover msg = ("Received %s, expected AbstractContainer or a list of AbstractContainers." % value.__class__.__name__) raise ValueError(msg) def __get_subspec_values(self, builder, spec, manager): ret = dict() # First get attributes attributes = builder.attributes for attr_spec in spec.attributes: attr_val = attributes.get(attr_spec.name) if attr_val is None: continue if isinstance(attr_val, (GroupBuilder, DatasetBuilder)): ret[attr_spec] = manager.construct(attr_val) elif isinstance(attr_val, RegionBuilder): # pragma: no cover raise ValueError("RegionReferences as attributes is not yet supported") elif isinstance(attr_val, ReferenceBuilder): ret[attr_spec] = manager.construct(attr_val.builder) else: ret[attr_spec] = attr_val if isinstance(spec, GroupSpec): if not isinstance(builder, GroupBuilder): # pragma: no cover raise ValueError("__get_subspec_values - must pass GroupBuilder with GroupSpec") # first aggregate links by data type and separate them # by group and dataset groups = dict(builder.groups) # make a copy so we can separate links datasets = dict(builder.datasets) # make a copy so we can separate links links = builder.links link_dt = dict() for link_builder in links.values(): target = link_builder.builder if isinstance(target, DatasetBuilder): datasets[link_builder.name] = target else: groups[link_builder.name] = target dt = manager.get_builder_dt(target) if dt is not None: link_dt.setdefault(dt, list()).append(target) # now assign links to their respective specification for subspec in spec.links: if subspec.name is not None and subspec.name in links: ret[subspec] = manager.construct(links[subspec.name].builder) else: sub_builder = link_dt.get(subspec.target_type) if sub_builder is not None: ret[subspec] = self.__flatten(sub_builder, subspec, manager) # now process groups and datasets self.__get_sub_builders(groups, spec.groups, manager, ret) self.__get_sub_builders(datasets, spec.datasets, manager, ret) elif isinstance(spec, DatasetSpec): if not isinstance(builder, DatasetBuilder): # pragma: no cover raise ValueError("__get_subspec_values - must pass DatasetBuilder with DatasetSpec") if (spec.shape is None and getattr(builder.data, 'shape', None) == (1,) and type(builder.data[0]) != np.void): # if a scalar dataset is expected and a 1-element non-compound dataset is given, then read the dataset builder['data'] = builder.data[0] # use dictionary reference instead of .data to bypass error ret[spec] = self.__check_ref_resolver(builder.data) return ret @staticmethod def __check_ref_resolver(data): """ Check if this dataset is a reference resolver, and invert it if so. 
""" if isinstance(data, ReferenceResolver): return data.invert() return data def __get_sub_builders(self, sub_builders, subspecs, manager, ret): # index builders by data_type builder_dt = dict() for g in sub_builders.values(): dt = manager.get_builder_dt(g) ns = manager.get_builder_ns(g) if dt is None or ns is None: continue for parent_dt in manager.namespace_catalog.get_hierarchy(ns, dt): builder_dt.setdefault(parent_dt, list()).append(g) for subspec in subspecs: # first get data type for the spec if subspec.data_type_def is not None: dt = subspec.data_type_def elif subspec.data_type_inc is not None: dt = subspec.data_type_inc else: dt = None # use name if we can, otherwise use data_data if subspec.name is None: sub_builder = builder_dt.get(dt) if sub_builder is not None: sub_builder = self.__flatten(sub_builder, subspec, manager) ret[subspec] = sub_builder else: sub_builder = sub_builders.get(subspec.name) if sub_builder is None: continue if dt is None: # recurse ret.update(self.__get_subspec_values(sub_builder, subspec, manager)) else: ret[subspec] = manager.construct(sub_builder) def __flatten(self, sub_builder, subspec, manager): tmp = [manager.construct(b) for b in sub_builder] if len(tmp) == 1 and not subspec.is_many(): tmp = tmp[0] return tmp @docval({'name': 'builder', 'type': (DatasetBuilder, GroupBuilder), 'doc': 'the builder to construct the AbstractContainer from'}, {'name': 'manager', 'type': BuildManager, 'doc': 'the BuildManager for this build'}, {'name': 'parent', 'type': (Proxy, AbstractContainer), 'doc': 'the parent AbstractContainer/Proxy for the AbstractContainer being built', 'default': None}) def construct(self, **kwargs): ''' Construct an AbstractContainer from the given Builder ''' builder, manager, parent = getargs('builder', 'manager', 'parent', kwargs) cls = manager.get_cls(builder) # gather all subspecs subspecs = self.__get_subspec_values(builder, self.spec, manager) # get the constructor argument that each specification corresponds to const_args = dict() # For Data container classes, we need to populate the data constructor argument since # there is no sub-specification that maps to that argument under the default logic if issubclass(cls, Data): if not isinstance(builder, DatasetBuilder): # pragma: no cover raise ValueError('Can only construct a Data object from a DatasetBuilder - got %s' % type(builder)) const_args['data'] = self.__check_ref_resolver(builder.data) for subspec, value in subspecs.items(): const_arg = self.get_const_arg(subspec) if const_arg is not None: if isinstance(subspec, BaseStorageSpec) and subspec.is_many(): existing_value = const_args.get(const_arg) if isinstance(existing_value, list): value = existing_value + value const_args[const_arg] = value # build kwargs for the constructor kwargs = dict() for const_arg in get_docval(cls.__init__): argname = const_arg['name'] override = self.__get_override_carg(argname, builder, manager) if override is not None: val = override elif argname in const_args: val = const_args[argname] else: continue kwargs[argname] = val try: obj = self.__new_container__(cls, builder.source, parent, builder.attributes.get(self.__spec.id_key()), **kwargs) except Exception as ex: msg = 'Could not construct %s object due to: %s' % (cls.__name__, ex) raise ConstructError(builder, msg) from ex return obj def __new_container__(self, cls, container_source, parent, object_id, **kwargs): """A wrapper function for ensuring a container gets everything set appropriately""" obj = cls.__new__(cls, container_source=container_source, 
parent=parent, object_id=object_id) obj.__init__(**kwargs) return obj @docval({'name': 'container', 'type': AbstractContainer, 'doc': 'the AbstractContainer to get the Builder name for'}) def get_builder_name(self, **kwargs): '''Get the name of a Builder that represents a AbstractContainer''' container = getargs('container', kwargs) if self.__spec.name not in (NAME_WILDCARD, None): ret = self.__spec.name else: ret = container.name return ret ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/build/warnings.py0000644000655200065520000000155100000000000017571 0ustar00circlecicircleciclass BuildWarning(UserWarning): """ Base class for warnings that are raised during the building of a container. """ pass class IncorrectQuantityBuildWarning(BuildWarning): """ Raised when a container field contains a number of groups/datasets/links that is not allowed by the spec. """ pass class MissingRequiredBuildWarning(BuildWarning): """ Raised when a required field is missing. """ pass class MissingRequiredWarning(MissingRequiredBuildWarning): """ Raised when a required field is missing. """ pass class OrphanContainerWarning(BuildWarning): """ Raised when a container is built without a parent. """ pass class DtypeConversionWarning(UserWarning): """ Raised when a value is converted to a different data type in order to match the specification. """ pass ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1766272 hdmf-3.1.1/src/hdmf/common/0000755000655200065520000000000000000000000015556 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/__init__.py0000644000655200065520000002127400000000000017675 0ustar00circlecicircleci'''This package will contain functions, classes, and objects for reading and writing data in according to the HDMF-common specification ''' import os.path from copy import deepcopy CORE_NAMESPACE = 'hdmf-common' EXP_NAMESPACE = 'hdmf-experimental' from ..spec import NamespaceCatalog # noqa: E402 from ..utils import docval, getargs, call_docval_func, get_docval, fmt_docval_args # noqa: E402 from ..backends.io import HDMFIO # noqa: E402 from ..backends.hdf5 import HDF5IO # noqa: E402 from ..validate import ValidatorMap # noqa: E402 from ..build import BuildManager, TypeMap # noqa: E402 from ..container import _set_exp # noqa: E402 # a global type map global __TYPE_MAP # a function to register a container classes with the global map @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to get the spec for'}, {'name': 'namespace', 'type': str, 'doc': 'the name of the namespace', 'default': CORE_NAMESPACE}, {"name": "container_cls", "type": type, "doc": "the class to map to the specified data_type", 'default': None}, is_method=False) def register_class(**kwargs): """Register an Container class to use for reading and writing a data_type from a specification If container_cls is not specified, returns a decorator for registering an Container subclass as the class for data_type in namespace. 
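    Example of the decorator form (a hedged sketch; ``MyExtContainer`` and the
    ``my-extension`` namespace are hypothetical, and the namespace file must define
    the ``MyExtContainer`` data type)::

        from hdmf.common import register_class, load_namespaces
        from hdmf.container import Container

        load_namespaces('my-extension.namespace.yaml')  # hypothetical namespace file

        @register_class('MyExtContainer', 'my-extension')
        class MyExtContainer(Container):
            pass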
""" data_type, namespace, container_cls = getargs('data_type', 'namespace', 'container_cls', kwargs) if namespace == EXP_NAMESPACE: def _dec(cls): _set_exp(cls) __TYPE_MAP.register_container_type(namespace, data_type, cls) return cls else: def _dec(cls): __TYPE_MAP.register_container_type(namespace, data_type, cls) return cls if container_cls is None: return _dec else: _dec(container_cls) # a function to register an object mapper for a container class @docval({"name": "container_cls", "type": type, "doc": "the Container class for which the given ObjectMapper class gets used for"}, {"name": "mapper_cls", "type": type, "doc": "the ObjectMapper class to use to map", 'default': None}, is_method=False) def register_map(**kwargs): """Register an ObjectMapper to use for a Container class type If mapper_cls is not specified, returns a decorator for registering an ObjectMapper class as the mapper for container_cls. If mapper_cls specified, register the class as the mapper for container_cls """ container_cls, mapper_cls = getargs('container_cls', 'mapper_cls', kwargs) def _dec(cls): __TYPE_MAP.register_map(container_cls, cls) return cls if mapper_cls is None: return _dec else: _dec(mapper_cls) def __get_resources(): from pkg_resources import resource_filename from os.path import join __core_ns_file_name = 'namespace.yaml' ret = dict() ret['namespace_path'] = join(resource_filename(__name__, 'hdmf-common-schema/common'), __core_ns_file_name) return ret def _get_resources(): # LEGACY: Needed to support legacy implementation. return __get_resources() @docval({'name': 'namespace_path', 'type': str, 'doc': 'the path to the YAML with the namespace definition'}, returns="the namespaces loaded from the given file", rtype=tuple, is_method=False) def load_namespaces(**kwargs): ''' Load namespaces from file ''' namespace_path = getargs('namespace_path', kwargs) return __TYPE_MAP.load_namespaces(namespace_path) def available_namespaces(): return __TYPE_MAP.namespace_catalog.namespaces # a function to get the container class for a give type @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to get the Container class for'}, {'name': 'namespace', 'type': str, 'doc': 'the namespace the data_type is defined in'}, is_method=False) def get_class(**kwargs): """Get the class object of the Container subclass corresponding to a given neurdata_type. """ data_type, namespace = getargs('data_type', 'namespace', kwargs) return __TYPE_MAP.get_dt_container_cls(data_type, namespace) # load the hdmf-common namespace __resources = __get_resources() if os.path.exists(__resources['namespace_path']): __TYPE_MAP = TypeMap(NamespaceCatalog()) load_namespaces(__resources['namespace_path']) # import these so the TypeMap gets populated from . import io as __io # noqa: F401,E402 from . import table # noqa: F401,E402 from . import alignedtable # noqa: F401,E402 from . import sparse # noqa: F401,E402 from . import resources # noqa: F401,E402 from . import multi # noqa: F401,E402 # register custom class generators from .io.table import DynamicTableGenerator __TYPE_MAP.register_generator(DynamicTableGenerator) from .. 
import Data, Container __TYPE_MAP.register_container_type(CORE_NAMESPACE, 'Container', Container) __TYPE_MAP.register_container_type(CORE_NAMESPACE, 'Data', Data) else: raise RuntimeError("Unable to load a TypeMap - no namespace file found") DynamicTable = get_class('DynamicTable', CORE_NAMESPACE) VectorData = get_class('VectorData', CORE_NAMESPACE) VectorIndex = get_class('VectorIndex', CORE_NAMESPACE) ElementIdentifiers = get_class('ElementIdentifiers', CORE_NAMESPACE) DynamicTableRegion = get_class('DynamicTableRegion', CORE_NAMESPACE) EnumData = get_class('EnumData', EXP_NAMESPACE) CSRMatrix = get_class('CSRMatrix', CORE_NAMESPACE) ExternalResources = get_class('ExternalResources', EXP_NAMESPACE) SimpleMultiContainer = get_class('SimpleMultiContainer', CORE_NAMESPACE) AlignedDynamicTable = get_class('AlignedDynamicTable', CORE_NAMESPACE) @docval({'name': 'extensions', 'type': (str, TypeMap, list), 'doc': 'a path to a namespace, a TypeMap, or a list consisting paths to namespaces and TypeMaps', 'default': None}, returns="the namespaces loaded from the given file", rtype=tuple, is_method=False) def get_type_map(**kwargs): ''' Get a BuildManager to use for I/O using the given extensions. If no extensions are provided, return a BuildManager that uses the core namespace ''' extensions = getargs('extensions', kwargs) type_map = None if extensions is None: type_map = deepcopy(__TYPE_MAP) else: if isinstance(extensions, TypeMap): type_map = extensions else: type_map = deepcopy(__TYPE_MAP) if isinstance(extensions, list): for ext in extensions: if isinstance(ext, str): type_map.load_namespaces(ext) elif isinstance(ext, TypeMap): type_map.merge(ext) else: msg = 'extensions must be a list of paths to namespace specs or a TypeMaps' raise ValueError(msg) elif isinstance(extensions, str): type_map.load_namespaces(extensions) elif isinstance(extensions, TypeMap): type_map.merge(extensions) return type_map @docval({'name': 'extensions', 'type': (str, TypeMap, list), 'doc': 'a path to a namespace, a TypeMap, or a list consisting paths to namespaces and TypeMaps', 'default': None}, returns="the namespaces loaded from the given file", rtype=tuple, is_method=False) def get_manager(**kwargs): ''' Get a BuildManager to use for I/O using the given extensions. 
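    For example (hedged; ``my-extension.namespace.yaml`` is a hypothetical namespace
    file for an extension)::

        from hdmf.common import get_type_map, get_manager

        type_map = get_type_map(extensions='my-extension.namespace.yaml')
        manager = get_manager(extensions=type_map)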
If no extensions are provided, return a BuildManager that uses the core namespace ''' type_map = call_docval_func(get_type_map, kwargs) return BuildManager(type_map) @docval({'name': 'io', 'type': HDMFIO, 'doc': 'the HDMFIO object to read from'}, {'name': 'namespace', 'type': str, 'doc': 'the namespace to validate against', 'default': CORE_NAMESPACE}, {'name': 'experimental', 'type': bool, 'doc': 'data type is an experimental data type', 'default': False}, returns="errors in the file", rtype=list, is_method=False) def validate(**kwargs): """Validate an file against a namespace""" io, namespace, experimental = getargs('io', 'namespace', 'experimental', kwargs) if experimental: namespace = EXP_NAMESPACE builder = io.read_builder() validator = ValidatorMap(io.manager.namespace_catalog.get_namespace(name=namespace)) return validator.validate(builder) @docval(*get_docval(HDF5IO.__init__), is_method=False) def get_hdf5io(**kwargs): """ A convenience method for getting an HDF5IO object """ manager = getargs('manager', kwargs) if manager is None: kwargs['manager'] = get_manager() cargs, ckwargs = fmt_docval_args(HDF5IO.__init__, kwargs) return HDF5IO(*cargs, **ckwargs) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/alignedtable.py0000644000655200065520000005463300000000000020556 0ustar00circlecicircleci""" Collection of Container classes for interacting with aligned and hierarchical dynamic tables """ from collections import OrderedDict import numpy as np import pandas as pd from . import register_class from .table import DynamicTable from ..utils import docval, getargs, call_docval_func, popargs, get_docval @register_class('AlignedDynamicTable') class AlignedDynamicTable(DynamicTable): """ DynamicTable container that supports storing a collection of subtables. Each sub-table is a DynamicTable itself that is aligned with the main table by row index. I.e., all DynamicTables stored in this group MUST have the same number of rows. This type effectively defines a 2-level table in which the main data is stored in the main table implemented by this type and additional columns of the table are grouped into categories, with each category being' represented by a separate DynamicTable stored within the group. NOTE: To remain compatible with DynamicTable, the attribute colnames represents only the columns of the main table (not including the category tables). To get the full list of column names, use the get_colnames() function instead. """ __fields__ = ({'name': 'category_tables', 'child': True}, ) @docval(*get_docval(DynamicTable.__init__), {'name': 'category_tables', 'type': list, 'doc': 'List of DynamicTables to be added to the container. NOTE: Only regular ' 'DynamicTables are allowed. 
Using AlignedDynamicTable as a category for ' 'AlignedDynamicTable is currently not supported.', 'default': None}, {'name': 'categories', 'type': 'array_data', 'doc': 'List of names with the ordering of category tables', 'default': None}) def __init__(self, **kwargs): # noqa: C901 in_category_tables = popargs('category_tables', kwargs) in_categories = popargs('categories', kwargs) if in_category_tables is not None: # Error check to make sure that all category_table are regular DynamicTable for i, v in enumerate(in_category_tables): if not isinstance(v, DynamicTable): raise ValueError("Category table with index %i is not a DynamicTable" % i) if isinstance(v, AlignedDynamicTable): raise ValueError("Category table with index %i is an AlignedDynamicTable. " "Nesting of AlignedDynamicTable is currently not supported." % i) # set in_categories from the in_category_tables if it is empy if in_categories is None and in_category_tables is not None: in_categories = [tab.name for tab in in_category_tables] # check that if categories is given that we also have category_tables if in_categories is not None and in_category_tables is None: raise ValueError("Categories provided but no category_tables given") # at this point both in_categories and in_category_tables should either both be None or both be a list if in_categories is not None: if len(in_categories) != len(in_category_tables): raise ValueError("%s category_tables given but %s categories specified" % (len(in_category_tables), len(in_categories))) # Initialize the main dynamic table call_docval_func(super().__init__, kwargs) # Create and set all sub-categories dts = OrderedDict() # Add the custom categories given as inputs if in_category_tables is not None: # We may need to resize our main table when adding categories as the user may not have set ids if len(in_category_tables) > 0: # We have categories to process if len(self.id) == 0: # The user did not initialize our main table id's nor set columns for our main table for i in range(len(in_category_tables[0])): self.id.append(i) # Add the user-provided categories in the correct order as described by the categories # This is necessary, because we do not store the categories explicitly but we maintain them # as the order of our self.category_tables. In this makes sure look-ups are consistent. lookup_index = OrderedDict([(k, -1) for k in in_categories]) for i, v in enumerate(in_category_tables): # Error check that the name of the table is in our categories list if v.name not in lookup_index: raise ValueError("DynamicTable %s does not appear in categories %s" % (v.name, str(in_categories))) # Error check to make sure no two tables with the same name are given if lookup_index[v.name] >= 0: raise ValueError("Duplicate table name %s found in input dynamic_tables" % v.name) lookup_index[v.name] = i for table_name, tabel_index in lookup_index.items(): # This error case should not be able to occur since the length of the in_categories and # in_category_tables must match and we made sure that each DynamicTable we added had its # name in the in_categories list. 
We, therefore, exclude this check from coverage testing # but we leave it in just as a backup trigger in case something unexpected happens if tabel_index < 0: # pragma: no cover raise ValueError("DynamicTable %s listed in categories but does not appear in category_tables" % table_name) # pragma: no cover # Test that all category tables have the correct number of rows category = in_category_tables[tabel_index] if len(category) != len(self): raise ValueError('Category DynamicTable %s does not align, it has %i rows expected %i' % (category.name, len(category), len(self))) # Add the category table to our category_tables. dts[category.name] = category # Set the self.category_tables attribute, which will set the parent/child relationships for the category_tables self.category_tables = dts def __contains__(self, val): """ Check if the given value (i.e., column) exists in this table :param val: If val is a string then check if the given category exists. If val is a tuple of two strings (category, colname) then check for the given category if the given colname exists. """ if isinstance(val, str): return val in self.category_tables or val in self.colnames elif isinstance(val, tuple): if len(val) != 2: raise ValueError("Expected tuple of strings of length 2 got tuple of length %i" % len(val)) return val[1] in self.get_category(val[0]) else: return False @property def categories(self): """ Get the list of names the categories Short-hand for list(self.category_tables.keys()) :raises: KeyError if the given name is not in self.category_tables """ return list(self.category_tables.keys()) @docval({'name': 'category', 'type': DynamicTable, 'doc': 'Add a new DynamicTable category'},) def add_category(self, **kwargs): """ Add a new DynamicTable to the AlignedDynamicTable to create a new category in the table. NOTE: The table must align with (i.e, have the same number of rows as) the main data table (and other category tables). I.e., if the AlignedDynamicTable is already populated with data then we have to populate the new category with the corresponding data before adding it. :raises: ValueError is raised if the input table does not have the same number of rows as the main table. ValueError is raised if the table is an AlignedDynamicTable instead of regular DynamicTable. """ category = getargs('category', kwargs) if len(category) != len(self): raise ValueError('New category DynamicTable does not align, it has %i rows expected %i' % (len(category), len(self))) if category.name in self.category_tables: raise ValueError("Category %s already in the table" % category.name) if isinstance(category, AlignedDynamicTable): raise ValueError("Category is an AlignedDynamicTable. 
Nesting of AlignedDynamicTable " "is currently not supported.") self.category_tables[category.name] = category category.parent = self @docval({'name': 'name', 'type': str, 'doc': 'Name of the category we want to retrieve', 'default': None}) def get_category(self, **kwargs): name = popargs('name', kwargs) if name is None or (name not in self.category_tables and name == self.name): return self else: return self.category_tables[name] @docval(*get_docval(DynamicTable.add_column), {'name': 'category', 'type': str, 'doc': 'The category the column should be added to', 'default': None}) def add_column(self, **kwargs): """ Add a column to the table :raises: KeyError if the category does not exist """ category_name = popargs('category', kwargs) if category_name is None: # Add the column to our main table call_docval_func(super().add_column, kwargs) else: # Add the column to a sub-category table try: category = self.get_category(category_name) except KeyError: raise KeyError("Category %s not in table" % category_name) category.add_column(**kwargs) @docval({'name': 'data', 'type': dict, 'doc': 'the data to put in this row', 'default': None}, {'name': 'id', 'type': int, 'doc': 'the ID for the row', 'default': None}, {'name': 'enforce_unique_id', 'type': bool, 'doc': 'enforce that the id in the table must be unique', 'default': False}, allow_extra=True) def add_row(self, **kwargs): """ We can either provide the row data as a single dict or by specifying a dict for each category """ data, row_id, enforce_unique_id = popargs('data', 'id', 'enforce_unique_id', kwargs) data = data if data is not None else kwargs # extract the category data category_data = {k: data.pop(k) for k in self.categories if k in data} # Check that we have the approbriate categories provided missing_categories = set(self.categories) - set(list(category_data.keys())) if missing_categories: raise KeyError( '\n'.join([ 'row data keys do not match available categories', 'missing {} category keys: {}'.format(len(missing_categories), missing_categories) ]) ) # Add the data to our main dynamic table data['id'] = row_id data['enforce_unique_id'] = enforce_unique_id call_docval_func(super().add_row, data) # Add the data to all out dynamic table categories for category, values in category_data.items(): self.category_tables[category].add_row(**values) @docval({'name': 'include_category_tables', 'type': bool, 'doc': "Ignore sub-category tables and just look at the main table", 'default': False}, {'name': 'ignore_category_ids', 'type': bool, 'doc': "Ignore id columns of sub-category tables", 'default': False}) def get_colnames(self, **kwargs): """Get the full list of names of columns for this table :returns: List of tuples (str, str) where the first string is the name of the DynamicTable that contains the column and the second string is the name of the column. If include_category_tables is False, then a list of column names is returned. 
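For example, for a hypothetical AlignedDynamicTable named ``'demo'`` with main-table columns ``'c1'`` and ``'c2'`` and a single category table ``'cat1'`` holding a column ``'x'`` (a sketch; these names are illustrative only)::

    adt.get_colnames()                              # ('c1', 'c2') -- main table only
    adt.get_colnames(include_category_tables=True)  # [('demo', 'c1'), ('demo', 'c2'), ('cat1', 'id'), ('cat1', 'x')]
    adt.get_colnames(include_category_tables=True,
                     ignore_category_ids=True)      # [('demo', 'c1'), ('demo', 'c2'), ('cat1', 'x')]
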
""" if not getargs('include_category_tables', kwargs): return self.colnames else: ignore_category_ids = getargs('ignore_category_ids', kwargs) columns = [(self.name, c) for c in self.colnames] for category in self.category_tables.values(): if not ignore_category_ids: columns += [(category.name, 'id'), ] columns += [(category.name, c) for c in category.colnames] return columns @docval({'name': 'ignore_category_ids', 'type': bool, 'doc': "Ignore id columns of sub-category tables", 'default': False}) def to_dataframe(self, **kwargs): """Convert the collection of tables to a single pandas DataFrame""" dfs = [super().to_dataframe().reset_index(), ] if getargs('ignore_category_ids', kwargs): dfs += [category.to_dataframe() for category in self.category_tables.values()] else: dfs += [category.to_dataframe().reset_index() for category in self.category_tables.values()] names = [self.name, ] + list(self.category_tables.keys()) res = pd.concat(dfs, axis=1, keys=names) res.set_index((self.name, 'id'), drop=True, inplace=True) return res def __getitem__(self, item): """ Called to implement standard array slicing syntax. Same as ``self.get(item)``. See :py:meth:`~hdmf.common.alignedtable.AlignedDynamicTable.get` for details. """ return self.get(item) def get(self, item, **kwargs): """ Access elements (rows, columns, category tables etc.) from the table. Instead of calling this function directly, the class also implements standard array slicing syntax via :py:meth:`~hdmf.common.alignedtable.AlignedDynamicTable.__getitem__` (which calls this function). For example, instead of calling ``self.get(item=slice(2,5))`` we may use the often more convenient form of ``self[2:5]`` instead. :param item: Selection defining the items of interest. This may be either a: * **int, list, array, slice** : Return one or multiple row of the table as a pandas.DataFrame. For example: * ``self[0]`` : Select the first row of the table * ``self[[0,3]]`` : Select the first and fourth row of the table * ``self[1:4]`` : Select the rows with index 1,2,3 from the table * **string** : Return a column from the main table or a category table. For example: * ``self['column']`` : Return the column from the main table. * ``self['my_category']`` : Returns a DataFrame of the ``my_category`` category table. This is a shorthand for ``self.get_category('my_category').to_dataframe()``. * **tuple**: Get a column, row, or cell from a particular category table. The tuple is expected to consist of the following elements: * ``category``: string with the name of the category. To select from the main table use ``self.name`` or ``None``. * ``column``: string with the name of the column, and * ``row``: integer index of the row. The tuple itself then may take the following forms: * Select a single column from a table via: * ``self[category, column]`` * Select a single full row of a given category table via: * ``self[row, category]`` (recommended, for consistency with DynamicTable) * ``self[category, row]`` * Select a single cell via: * ``self[row, (category, column)]`` (recommended, for consistency with DynamicTable) * ``self[row, category, column]`` * ``self[category, column, row]`` :returns: Depending on the type of selection the function returns a: * **pandas.DataFrame**: when retrieving a row or category table * **array** : when retrieving a single column * **single value** : when retrieving a single cell. The data type and shape will depend on the data type and shape of the cell/column. 
""" if isinstance(item, (int, list, np.ndarray, slice)): # get a single full row from all tables dfs = ([super().get(item, **kwargs).reset_index(), ] + [category[item].reset_index() for category in self.category_tables.values()]) names = [self.name, ] + list(self.category_tables.keys()) res = pd.concat(dfs, axis=1, keys=names) res.set_index((self.name, 'id'), drop=True, inplace=True) return res elif isinstance(item, str) or item is None: if item in self.colnames: # get a specific column return super().get(item, **kwargs) else: # get a single category return self.get_category(item).to_dataframe() elif isinstance(item, tuple): if len(item) == 2: # DynamicTable allows selection of cells via the syntax [int, str], i.e,. [row_index, columnname] # We support this syntax here as well with the additional caveat that in AlignedDynamicTable # columns are identified by tuples of strings. As such [int, str] refers not to a cell but # a single row in a particular category table (i.e., [row_index, category]). To select a cell # the second part of the item then is a tuple of strings, i.e., [row_index, (category, column)] if isinstance(item[0], (int, np.integer)): # Select a single cell or row of a sub-table based on row-index(item[0]) # and the category (if item[1] is a string) or column (if item[1] is a tuple of (category, column) re = self[item[0]][item[1]] # re is a pandas.Series or pandas.Dataframe. If we selected a single cell # (i.e., item[2] was a tuple defining a particular column) then return the value of the cell if re.size == 1: re = re.values[0] # If we selected a single cell from a ragged column then we need to change the list to a tuple if isinstance(re, list): re = tuple(re) # We selected a row of a whole table (i.e., item[2] identified only the category table, # but not a particular column). # Change the result from a pandas.Series to a pandas.DataFrame for consistency with DynamicTable if isinstance(re, pd.Series): re = re.to_frame() return re else: return self.get_category(item[0])[item[1]] elif len(item) == 3: if isinstance(item[0], (int, np.integer)): return self.get_category(item[1])[item[2]][item[0]] else: return self.get_category(item[0])[item[1]][item[2]] else: raise ValueError("Expected tuple of length 2 of the form [category, column], [row, category], " "[row, (category, column)] or a tuple of length 3 of the form " "[category, column, row], [row, category, column]") @docval({'name': 'ignore_category_tables', 'type': bool, 'doc': "Ignore the category tables and only check in the main table columns", 'default': False}, allow_extra=False) def has_foreign_columns(self, **kwargs): """ Does the table contain DynamicTableRegion columns :returns: True if the table or any of the category tables contains a DynamicTableRegion column, else False """ ignore_category_tables = getargs('ignore_category_tables', kwargs) if super().has_foreign_columns(): return True if not ignore_category_tables: for table in self.category_tables.values(): if table.has_foreign_columns(): return True return False @docval({'name': 'ignore_category_tables', 'type': bool, 'doc': "Ignore the category tables and only check in the main table columns", 'default': False}, allow_extra=False) def get_foreign_columns(self, **kwargs): """ Determine the names of all columns that link to another DynamicTable, i.e., find all DynamicTableRegion type columns. Similar to a foreign key in a database, a DynamicTableRegion column references elements in another table. 
:returns: List of tuples (str, str) where the first string is the name of the category table (or None if the column is in the main table) and the second string is the column name. """ ignore_category_tables = getargs('ignore_category_tables', kwargs) col_names = [(None, col_name) for col_name in super().get_foreign_columns()] if not ignore_category_tables: for table in self.category_tables.values(): col_names += [(table.name, col_name) for col_name in table.get_foreign_columns()] return col_names @docval(*get_docval(DynamicTable.get_linked_tables), {'name': 'ignore_category_tables', 'type': bool, 'doc': "Ignore the category tables and only check in the main table columns", 'default': False}, allow_extra=False) def get_linked_tables(self, **kwargs): """ Get a list of the full list of all tables that are being linked to directly or indirectly from this table via foreign DynamicTableColumns included in this table or in any table that can be reached through DynamicTableRegion columns Returns: List of dicts with the following keys: * 'source_table' : The source table containing the DynamicTableRegion column * 'source_column' : The relevant DynamicTableRegion column in the 'source_table' * 'target_table' : The target DynamicTable; same as source_column.table. """ ignore_category_tables = getargs('ignore_category_tables', kwargs) other_tables = None if ignore_category_tables else list(self.category_tables.values()) return super().get_linked_tables(other_tables=other_tables) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1686273 hdmf-3.1.1/src/hdmf/common/hdmf-common-schema/0000755000655200065520000000000000000000000021220 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1627603655.180627 hdmf-3.1.1/src/hdmf/common/hdmf-common-schema/common/0000755000655200065520000000000000000000000022510 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603599.0 hdmf-3.1.1/src/hdmf/common/hdmf-common-schema/common/base.yaml0000644000655200065520000000120400000000000024303 0ustar00circlecicircleci# hdmf-schema-language=2.0.2 datasets: - data_type_def: Data doc: An abstract data type for a dataset. groups: - data_type_def: Container doc: An abstract data type for a group storing collections of data and metadata. Base type for all data and metadata containers. - data_type_def: SimpleMultiContainer data_type_inc: Container doc: A simple Container for holding onto multiple containers. datasets: - data_type_inc: Data quantity: '*' doc: Data objects held within this SimpleMultiContainer. groups: - data_type_inc: Container quantity: '*' doc: Container objects held within this SimpleMultiContainer. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603599.0 hdmf-3.1.1/src/hdmf/common/hdmf-common-schema/common/experimental.yaml0000644000655200065520000000065500000000000026077 0ustar00circlecicirclecigroups: [] datasets: - data_type_def: EnumData data_type_inc: VectorData dtype: uint8 doc: Data that come from a fixed set of values. A data value of i corresponds to the i-th value in the VectorData referenced by the 'elements' attribute. 
attributes: - name: elements dtype: target_type: VectorData reftype: object doc: Reference to the VectorData object that contains the enumerable elements ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603599.0 hdmf-3.1.1/src/hdmf/common/hdmf-common-schema/common/namespace.yaml0000644000655200065520000000230600000000000025331 0ustar00circlecicircleci# hdmf-schema-language=2.0.2 namespaces: - name: hdmf-common doc: Common data structures provided by HDMF author: - Andrew Tritt - Oliver Ruebel - Ryan Ly - Ben Dichter contact: - ajtritt@lbl.gov - oruebel@lbl.gov - rly@lbl.gov - bdichter@lbl.gov full_name: HDMF Common schema: - doc: base data types source: base.yaml title: Base data types - doc: data types for a column-based table source: table.yaml title: Table data types - doc: data types for different types of sparse matrices source: sparse.yaml title: Sparse data types version: 1.5.0 - name: hdmf-experimental doc: Experimental data structures provided by HDMF. These are not guaranteed to be available in the future author: - Andrew Tritt - Oliver Ruebel - Ryan Ly - Ben Dichter contact: - ajtritt@lbl.gov - oruebel@lbl.gov - rly@lbl.gov - bdichter@lbl.gov full_name: HDMF Experimental schema: - namespace: hdmf-common - doc: Experimental data types source: experimental.yaml title: Experimental data types - doc: data types for storing references to web accessible resources source: resources.yaml title: Resource reference data types version: 0.1.0 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603599.0 hdmf-3.1.1/src/hdmf/common/hdmf-common-schema/common/resources.yaml0000644000655200065520000000475300000000000025417 0ustar00circlecicircleci# hdmf-schema-language=2.0.2 groups: - data_type_def: ExternalResources data_type_inc: Container doc: "A set of four tables for tracking external resource references in a file. NOTE: this data type is in beta testing and is subject to change in a later version." datasets: - data_type_inc: Data name: keys doc: A table for storing user terms that are used to refer to external resources. dtype: - name: key dtype: text doc: The user term that maps to one or more resources in the 'resources' table. dims: - num_rows shape: - null - data_type_inc: Data name: entities doc: A table for mapping user terms (i.e., keys) to resource entities. dtype: - name: keys_idx dtype: uint doc: The index to the key in the 'keys' table. - name: resources_idx dtype: uint doc: The index into the 'resources' table - name: entity_id dtype: text doc: The unique identifier entity. - name: entity_uri dtype: text doc: The URI for the entity this reference applies to. This can be an empty string. dims: - num_rows shape: - null - data_type_inc: Data name: resources doc: A table for mapping user terms (i.e., keys) to resource entities. dtype: - name: resource dtype: text doc: The name of the resource. - name: resource_uri dtype: text doc: The URI for the resource. This can be an empty string. dims: - num_rows shape: - null - data_type_inc: Data name: objects doc: A table for identifying which objects in a file contain references to external resources. dtype: - name: object_id dtype: text doc: The UUID for the object. - name: field dtype: text doc: The field of the object. This can be an empty string if the object is a dataset and the field is the dataset values. dims: - num_rows shape: - null - data_type_inc: Data name: object_keys doc: A table for identifying which objects use which keys. 
dtype: - name: objects_idx dtype: uint doc: The index to the 'objects' table for the object that holds the key. - name: keys_idx dtype: uint doc: The index to the 'keys' table for the key. dims: - num_rows shape: - null ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603599.0 hdmf-3.1.1/src/hdmf/common/hdmf-common-schema/common/sparse.yaml0000644000655200065520000000163100000000000024672 0ustar00circlecicircleci# hdmf-schema-language=2.0.2 groups: - data_type_def: CSRMatrix data_type_inc: Container doc: 'A compressed sparse row matrix. Data are stored in the standard CSR format, where column indices for row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored in data[indptr[i]:indptr[i+1]].' attributes: - name: shape dtype: uint dims: - number of rows, number of columns shape: - 2 doc: The shape (number of rows, number of columns) of this sparse matrix. datasets: - name: indices dtype: uint dims: - number of non-zero values shape: - null doc: The column indices. - name: indptr dtype: uint dims: - number of rows in the matrix + 1 shape: - null doc: The row index pointer. - name: data dims: - number of non-zero values shape: - null doc: The non-zero values in the matrix. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603599.0 hdmf-3.1.1/src/hdmf/common/hdmf-common-schema/common/table.yaml0000644000655200065520000001477400000000000024500 0ustar00circlecicircleci# hdmf-schema-language=2.0.2 datasets: - data_type_def: VectorData data_type_inc: Data doc: An n-dimensional dataset representing a column of a DynamicTable. If used without an accompanying VectorIndex, first dimension is along the rows of the DynamicTable and each step along the first dimension is a cell of the larger table. VectorData can also be used to represent a ragged array if paired with a VectorIndex. This allows for storing arrays of varying length in a single cell of the DynamicTable by indexing into this VectorData. The first vector is at VectorData[0:VectorIndex[0]]. The second vector is at VectorData[VectorIndex[0]:VectorIndex[1]], and so on. dims: - - dim0 - - dim0 - dim1 - - dim0 - dim1 - dim2 - - dim0 - dim1 - dim2 - dim3 shape: - - null - - null - null - - null - null - null - - null - null - null - null attributes: - name: description dtype: text doc: Description of what these vectors represent. - data_type_def: VectorIndex data_type_inc: VectorData dtype: uint8 doc: Used with VectorData to encode a ragged array. An array of indices into the first dimension of the target VectorData, and forming a map between the rows of a DynamicTable and the indices of the VectorData. The name of the VectorIndex is expected to be the name of the target VectorData object followed by "_index". dims: - num_rows shape: - null attributes: - name: target dtype: target_type: VectorData reftype: object doc: Reference to the target dataset that this index applies to. - data_type_def: ElementIdentifiers data_type_inc: Data default_name: element_id dtype: int dims: - num_elements shape: - null doc: A list of unique identifiers for values within a dataset, e.g. rows of a DynamicTable. - data_type_def: DynamicTableRegion data_type_inc: VectorData dtype: int doc: DynamicTableRegion provides a link from one table to an index or region of another. The `table` attribute is a link to another `DynamicTable`, indicating which table is referenced, and the data is int(s) indicating the row(s) (0-indexed) of the target array. 
`DynamicTableRegion`s can be used to associate rows with repeated meta-data without data duplication. They can also be used to create hierarchical relationships between multiple `DynamicTable`s. `DynamicTableRegion` objects may be paired with a `VectorIndex` object to create ragged references, so a single cell of a `DynamicTable` can reference many rows of another `DynamicTable`. dims: - num_rows shape: - null attributes: - name: table dtype: target_type: DynamicTable reftype: object doc: Reference to the DynamicTable object that this region applies to. - name: description dtype: text doc: Description of what this table region points to. groups: - data_type_def: DynamicTable data_type_inc: Container doc: A group containing multiple datasets that are aligned on the first dimension (Currently, this requirement if left up to APIs to check and enforce). These datasets represent different columns in the table. Apart from a column that contains unique identifiers for each row, there are no other required datasets. Users are free to add any number of custom VectorData objects (columns) here. DynamicTable also supports ragged array columns, where each element can be of a different size. To add a ragged array column, use a VectorIndex type to index the corresponding VectorData type. See documentation for VectorData and VectorIndex for more details. Unlike a compound data type, which is analogous to storing an array-of-structs, a DynamicTable can be thought of as a struct-of-arrays. This provides an alternative structure to choose from when optimizing storage for anticipated access patterns. Additionally, this type provides a way of creating a table without having to define a compound type up front. Although this convenience may be attractive, users should think carefully about how data will be accessed. DynamicTable is more appropriate for column-centric access, whereas a dataset with a compound type would be more appropriate for row-centric access. Finally, data size should also be taken into account. For small tables, performance loss may be an acceptable trade-off for the flexibility of a DynamicTable. attributes: - name: colnames dtype: text dims: - num_columns shape: - null doc: The names of the columns in this table. This should be used to specify an order to the columns. - name: description dtype: text doc: Description of what is in this dynamic table. datasets: - name: id data_type_inc: ElementIdentifiers dtype: int dims: - num_rows shape: - null doc: Array of unique identifiers for the rows of this dynamic table. - data_type_inc: VectorData doc: Vector columns, including index columns, of this dynamic table. quantity: '*' - data_type_def: AlignedDynamicTable data_type_inc: DynamicTable doc: DynamicTable container that supports storing a collection of sub-tables. Each sub-table is a DynamicTable itself that is aligned with the main table by row index. I.e., all DynamicTables stored in this group MUST have the same number of rows. This type effectively defines a 2-level table in which the main data is stored in the main table implemented by this type and additional columns of the table are grouped into categories, with each category being represented by a separate DynamicTable stored within the group. attributes: - name: categories dtype: text dims: - num_categories shape: - null doc: The names of the categories in this AlignedDynamicTable. Each category is represented by one DynamicTable stored in the parent group. 
This attribute should be used to specify an order of categories and the category names must match the names of the corresponding DynamicTable in the group. groups: - data_type_inc: DynamicTable doc: A DynamicTable representing a particular category for columns in the AlignedDynamicTable parent container. The table MUST be aligned with (i.e., have the same number of rows) as all other DynamicTables stored in the AlignedDynamicTable parent container. The name of the category is given by the name of the DynamicTable and its description by the description attribute of the DynamicTable. quantity: '*' ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/hierarchicaltable.py0000644000655200065520000003225600000000000021566 0ustar00circlecicircleci""" Module providing additional functionality for dealing with hierarchically nested tables, i.e., tables containing DynamicTableRegion references. """ import pandas as pd import numpy as np from hdmf.common.table import DynamicTable, DynamicTableRegion, VectorIndex from hdmf.common.alignedtable import AlignedDynamicTable from hdmf.utils import docval, getargs @docval({'name': 'dynamic_table', 'type': DynamicTable, 'doc': 'DynamicTable object to be converted to a hierarchical pandas.Dataframe'}, returns="Hierarchical pandas.DataFrame with usually a pandas.MultiIndex on both the index and columns.", rtype='pandas.DataFrame', is_method=False) def to_hierarchical_dataframe(dynamic_table): """ Create a hierarchical pandas.DataFrame that represents all data from a collection of linked DynamicTables. **LIMITATIONS:** Currently this function only supports DynamicTables with a single DynamicTableRegion column. If a table has more than one DynamicTableRegion column then the function will expand only the first DynamicTableRegion column found for each table. Any additional DynamicTableRegion columns will remain nested. **NOTE:** Some useful functions for further processing of the generated DataFrame include: * pandas.DataFrame.reset_index to turn the data from the pandas.MultiIndex into columns * :py:meth:`~hdmf.common.hierarchicaltable.drop_id_columns` to remove all 'id' columns * :py:meth:`~hdmf.common.hierarchicaltable.flatten_column_index` to flatten the column index """ # TODO: Need to deal with the case where we have more than one DynamicTableRegion column in a given table # Get the references column foreign_columns = dynamic_table.get_foreign_columns() # if table does not contain any DynamicTableRegion columns then we can just convert it to a dataframe if len(foreign_columns) == 0: return dynamic_table.to_dataframe() hcol_name = foreign_columns[0] # We only denormalize the first foreign column for now hcol = dynamic_table[hcol_name] # Either a VectorIndex pointing to a DynamicTableRegion or a DynamicTableRegion # Get the target DynamicTable that hcol is pointing to. If hcol is a VectorIndex then we first need # to get the target of it before we look up the table. hcol_target = hcol.table if isinstance(hcol, DynamicTableRegion) else hcol.target.table # Create the data variables we need to collect the data for our output dataframe and associated index index = [] data = [] columns = None index_names = None # First we here get a list of DataFrames, one for each row of the column we need to process. # If hcol is a VectorIndex (i.e., our column is a ragged array of row indices), then simply loading # the data from the VectorIndex will do the trick. 
If we have a regular DynamicTableRegion column, # then we need to load the elements ourselves (using slice syntax to make sure we get DataFrames) # one-row-at-a-time if isinstance(hcol, VectorIndex): rows = hcol.get(slice(None), index=False, df=True) else: rows = [hcol[i:(i+1)] for i in range(len(hcol))] # Retrieve the columns we need to iterate over from our input table. For AlignedDynamicTable we need to # use the get_colnames function instead of the colnames property to ensure we get all columns not just # the columns from the main table dynamic_table_colnames = (dynamic_table.get_colnames(include_category_tables=True, ignore_category_ids=False) if isinstance(dynamic_table, AlignedDynamicTable) else dynamic_table.colnames) # Case 1: Our DynamicTableRegion column points to a DynamicTable that itself does not contain # any DynamicTableRegion references (i.e., we have reached the end of our table hierarchy). # If this is the case than we need to de-normalize the data and flatten the hierarchy if not hcol_target.has_foreign_columns(): # Iterate over all rows, where each row is described by a DataFrame with one-or-more rows for row_index, row_df in enumerate(rows): # Since each row contains a pandas.DataFrame (with possible multiple rows), we # next need to iterate over all rows in that table to denormalize our data for row in row_df.itertuples(index=True): # Determine the column data for our row. Each selected row from our target table # becomes a row in our flattened table data.append(row) # Determine the multi-index tuple for our row, consisting of: i) id of the row in this # table, ii) all columns (except the hierarchical column we are flattening), and # iii) the index (i.e., id) from our target row index_data = ([dynamic_table.id[row_index], ] + [dynamic_table[row_index, colname] for colname in dynamic_table_colnames if colname != hcol_name]) index.append(tuple(index_data)) # Determine the names for our index and columns of our output table # We need to do this even if our table was empty (i.e. even is len(rows)==0) # NOTE: While for a regular DynamicTable the "colnames" property will give us the full list of column names, # for AlignedDynamicTable we need to use the get_colnames() function instead to make sure we include # the category table columns as well. 
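# Illustration with hypothetical names: for a source table 'trials' with columns ('start', 'stop', 'events'),
# where 'events' is the DynamicTableRegion column pointing at a target table named 'event_info', the
# index_names computed below become [('trials', 'id'), ('trials', 'start'), ('trials', 'stop')] and the
# columns become the MultiIndex [('event_info', 'id'), ('event_info', <each event_info column>)].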
index_names = ([(dynamic_table.name, 'id')] + [(dynamic_table.name, colname) for colname in dynamic_table_colnames if colname != hcol_name]) # Determine the name of our columns hcol_iter_columns = (hcol_target.get_colnames(include_category_tables=True, ignore_category_ids=False) if isinstance(hcol_target, AlignedDynamicTable) else hcol_target.colnames) columns = pd.MultiIndex.from_tuples([(hcol_target.name, 'id'), ] + [(hcol_target.name, c) for c in hcol_iter_columns], names=('source_table', 'label')) # Case 2: Our DynamicTableRegion columns points to another table with a DynamicTableRegion, i.e., # we need to recursively resolve more levels of the table hieararchy else: # First we need to recursively flatten the hierarchy by calling 'to_hierarchical_dataframe()' # (i.e., this function) on the target of our hierarchical column hcol_hdf = to_hierarchical_dataframe(hcol_target) # Iterate over all rows, where each row is described by a DataFrame with one-or-more rows for row_index, row_df_level1 in enumerate(rows): # Since each row contains a pandas.DataFrame (with possible multiple rows), we # next need to iterate over all rows in that table to denormalize our data for row_df_level2 in row_df_level1.itertuples(index=True): # Since our target is itself a a DynamicTable with a DynamicTableRegion columns, # each target row itself may expand into multiple rows in the flattened hcol_hdf. # So we now need to look up the rows in hcol_hdf that correspond to the rows in # row_df_level2. # NOTE: In this look-up we assume that the ids (and hence the index) of # each row in the table are in fact unique. for row_tuple_level3 in hcol_hdf.loc[[row_df_level2[0]]].itertuples(index=True): # Determine the column data for our row. data.append(row_tuple_level3[1:]) # Determine the multi-index tuple for our row, index_data = ([dynamic_table.id[row_index], ] + [dynamic_table[row_index, colname] for colname in dynamic_table_colnames if colname != hcol_name] + list(row_tuple_level3[0])) index.append(tuple(index_data)) # Determine the names for our index and columns of our output table # We need to do this even if our table was empty (i.e. even is len(rows)==0) index_names = ([(dynamic_table.name, "id")] + [(dynamic_table.name, colname) for colname in dynamic_table_colnames if colname != hcol_name] + hcol_hdf.index.names) columns = hcol_hdf.columns # Construct the pandas dataframe with the hierarchical multi-index multi_index = pd.MultiIndex.from_tuples(index, names=index_names) out_df = pd.DataFrame(data=data, index=multi_index, columns=columns) return out_df def __get_col_name(col): """ Internal helper function to get the actual name of a pandas DataFrame column from a column name that may consists of an arbitrary sequence of tuples. The function will return the last value of the innermost tuple. 
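For example (column names here are purely illustrative)::

    __get_col_name('id')                 # -> 'id'
    __get_col_name(('trials', 'start'))  # -> 'start'
    __get_col_name(('a', ('b', 'c')))    # -> 'c'
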
""" curr_val = col while isinstance(curr_val, tuple): curr_val = curr_val[-1] return curr_val def __flatten_column_name(col): """ Internal helper function used to iteratively flatten a nested tuple :param col: Column name to flatten :type col: Tuple or String :returns: If col is a tuple then the result is a flat tuple otherwise col is returned as is """ if isinstance(col, tuple): re = col while np.any([isinstance(v, tuple) for v in re]): temp = [] for v in re: if isinstance(v, tuple): temp += list(v) else: temp += [v, ] re = temp return tuple(re) else: return col @docval({'name': 'dataframe', 'type': pd.DataFrame, 'doc': 'Pandas dataframe to update (usually generated by the to_hierarchical_dataframe function)'}, {'name': 'inplace', 'type': 'bool', 'doc': 'Update the dataframe inplace or return a modified copy', 'default': False}, returns="pandas.DataFrame with the id columns removed", rtype='pandas.DataFrame', is_method=False) def drop_id_columns(**kwargs): """ Drop all columns named 'id' from the table. In case a column name is a tuple the function will drop any column for which the inner-most name is 'id'. The 'id' columns of DynamicTable is in many cases not necessary for analysis or display. This function allow us to easily filter all those columns. :raises TypeError: In case that dataframe parameter is not a pandas.Dataframe. """ dataframe, inplace = getargs('dataframe', 'inplace', kwargs) col_name = 'id' drop_labels = [] for col in dataframe.columns: if __get_col_name(col) == col_name: drop_labels.append(col) re = dataframe.drop(labels=drop_labels, axis=1, inplace=inplace) return dataframe if inplace else re @docval({'name': 'dataframe', 'type': pd.DataFrame, 'doc': 'Pandas dataframe to update (usually generated by the to_hierarchical_dataframe function)'}, {'name': 'max_levels', 'type': (int, np.integer), 'doc': 'Maximum number of levels to use in the resulting column Index. NOTE: When ' 'limiting the number of levels the function simply removes levels from the ' 'beginning. As such, removing levels may result in columns with duplicate names.' 'Value must be >0.', 'default': None}, {'name': 'inplace', 'type': 'bool', 'doc': 'Update the dataframe inplace or return a modified copy', 'default': False}, returns="pandas.DataFrame with a regular pandas.Index columns rather and a pandas.MultiIndex", rtype='pandas.DataFrame', is_method=False) def flatten_column_index(**kwargs): """ Flatten the column index of a pandas DataFrame. The functions changes the dataframe.columns from a pandas.MultiIndex to a normal Index, with each column usually being identified by a tuple of strings. This function is typically used in conjunction with DataFrames generated by :py:meth:`~hdmf.common.hierarchicaltable.to_hierarchical_dataframe` :raises ValueError: In case the num_levels is not >0 :raises TypeError: In case that dataframe parameter is not a pandas.Dataframe. """ dataframe, max_levels, inplace = getargs('dataframe', 'max_levels', 'inplace', kwargs) if max_levels is not None and max_levels <= 0: raise ValueError('max_levels must be greater than 0') # Compute the new column names col_names = [__flatten_column_name(col) for col in dataframe.columns.values] # Apply the max_levels filter. 
Make sure to do this only for columns that are actually tuples # in order not to accidentally shorten the actual string name of columns if max_levels is None: select_levels = slice(None) elif max_levels == 1: select_levels = -1 else: # max_levels > 1 select_levels = slice(-max_levels, None) col_names = [col[select_levels] if isinstance(col, tuple) else col for col in col_names] re = dataframe if inplace else dataframe.copy() re.columns = col_names return re ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1627603655.180627 hdmf-3.1.1/src/hdmf/common/io/0000755000655200065520000000000000000000000016165 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/io/__init__.py0000644000655200065520000000022300000000000020273 0ustar00circlecicirclecifrom . import multi # noqa: F401 from . import resources # noqa: F401 from . import table # noqa: F401 from . import alignedtable # noqa: F401 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/io/alignedtable.py0000644000655200065520000000113200000000000021147 0ustar00circlecicirclecifrom .. import register_map from ..alignedtable import AlignedDynamicTable from .table import DynamicTableMap @register_map(AlignedDynamicTable) class AlignedDynamicTableMap(DynamicTableMap): """ Customize the mapping for AlignedDynamicTable """ def __init__(self, spec): super().__init__(spec) # By default the DynamicTables contained as sub-categories in the AlignedDynamicTable are mapped to # the 'dynamic_tables' class attribute. This renames the attribute to 'category_tables' self.map_spec('category_tables', spec.get_data_type('DynamicTable')) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/io/multi.py0000644000655200065520000000166500000000000017701 0ustar00circlecicirclecifrom .. import register_map from ..multi import SimpleMultiContainer from ...build import ObjectMapper from ...container import Container, Data @register_map(SimpleMultiContainer) class SimpleMultiContainerMap(ObjectMapper): @ObjectMapper.object_attr('containers') def containers_attr(self, container, manager): return [c for c in container.containers.values() if isinstance(c, Container)] @ObjectMapper.constructor_arg('containers') def containers_carg(self, builder, manager): return [manager.construct(sub) for sub in builder.datasets.values() if manager.is_sub_data_type(sub, 'Data')] + \ [manager.construct(sub) for sub in builder.groups.values() if manager.is_sub_data_type(sub, 'Container')] @ObjectMapper.object_attr('datas') def datas_attr(self, container, manager): return [c for c in container.containers.values() if isinstance(c, Data)] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/io/resources.py0000644000655200065520000000334400000000000020555 0ustar00circlecicirclecifrom .. import register_map from ..resources import ExternalResources, KeyTable, ResourceTable, ObjectTable, ObjectKeyTable, EntityTable from ...build import ObjectMapper @register_map(ExternalResources) class ExternalResourcesMap(ObjectMapper): def construct_helper(self, name, parent_builder, table_cls, manager): """Create a new instance of table_cls with data from parent_builder[name]. 
The DatasetBuilder for name is associated with data_type Data and container class Data, but users should use the more specific table_cls for these datasets. """ parent = manager._get_proxy_builder(parent_builder) builder = parent_builder[name] src = builder.source oid = builder.attributes.get(self.spec.id_key()) kwargs = dict(name=builder.name, data=builder.data) return self.__new_container__(table_cls, src, parent, oid, **kwargs) @ObjectMapper.constructor_arg('keys') def keys(self, builder, manager): return self.construct_helper('keys', builder, KeyTable, manager) @ObjectMapper.constructor_arg('resources') def resources(self, builder, manager): return self.construct_helper('resources', builder, ResourceTable, manager) @ObjectMapper.constructor_arg('entities') def entities(self, builder, manager): return self.construct_helper('entities', builder, EntityTable, manager) @ObjectMapper.constructor_arg('objects') def objects(self, builder, manager): return self.construct_helper('objects', builder, ObjectTable, manager) @ObjectMapper.constructor_arg('object_keys') def object_keys(self, builder, manager): return self.construct_helper('object_keys', builder, ObjectKeyTable, manager) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/io/table.py0000644000655200065520000001676600000000000017646 0ustar00circlecicirclecifrom .. import register_map from ..table import DynamicTable, VectorData, VectorIndex, DynamicTableRegion from ...build import ObjectMapper, BuildManager, CustomClassGenerator from ...spec import Spec from ...utils import docval, getargs @register_map(DynamicTable) class DynamicTableMap(ObjectMapper): def __init__(self, spec): super().__init__(spec) vector_data_spec = spec.get_data_type('VectorData') self.map_spec('columns', vector_data_spec) @ObjectMapper.object_attr('colnames') def attr_columns(self, container, manager): if all(not col for col in container.columns): return tuple() return container.colnames @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, {"name": "container", "type": DynamicTable, "doc": "the container to get the attribute value from"}, {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, returns='the value of the attribute') def get_attr_value(self, **kwargs): ''' Get the value of the attribute corresponding to this spec from the given container ''' spec, container, manager = getargs('spec', 'container', 'manager', kwargs) attr_value = super().get_attr_value(spec, container, manager) if attr_value is None and spec.name in container: if spec.data_type_inc == 'VectorData': attr_value = container[spec.name] if isinstance(attr_value, VectorIndex): attr_value = attr_value.target elif spec.data_type_inc == 'DynamicTableRegion': attr_value = container[spec.name] if isinstance(attr_value, VectorIndex): attr_value = attr_value.target if attr_value.table is None: msg = "empty or missing table for DynamicTableRegion '%s' in DynamicTable '%s'" % \ (attr_value.name, container.name) raise ValueError(msg) elif spec.data_type_inc == 'VectorIndex': attr_value = container[spec.name] return attr_value class DynamicTableGenerator(CustomClassGenerator): @classmethod def apply_generator_to_field(cls, field_spec, bases, type_map): """Return True if this is a DynamicTable and the field spec is a column.""" for b in bases: if issubclass(b, DynamicTable): break else: # return False if no base is a subclass of DynamicTable return False 
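# At this point at least one base class is a DynamicTable subclass; the field spec is treated as a
# column only if its resolved dtype (looked up below) is a VectorData subclass.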
dtype = cls._get_type(field_spec, type_map) return isinstance(dtype, type) and issubclass(dtype, VectorData) @classmethod def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): """Add __columns__ to the classdict and update the docval args for the field spec with the given attribute name. :param classdict: The dict to update with __columns__. :param docval_args: The list of docval arguments. :param parent_cls: The parent class. :param attr_name: The attribute name of the field spec for the container class to generate. :param not_inherited_fields: Dictionary of fields not inherited from the parent class. :param type_map: The type map to use. :param spec: The spec for the container class to generate. """ if attr_name.endswith('_index'): # do not add index columns to __columns__ return field_spec = not_inherited_fields[attr_name] column_conf = dict( name=attr_name, description=field_spec['doc'], required=field_spec.required ) dtype = cls._get_type(field_spec, type_map) if issubclass(dtype, DynamicTableRegion): # the spec does not know which table this DTR points to # the user must specify the table attribute on the DTR after it is generated column_conf['table'] = True index_counter = 0 index_name = attr_name while '{}_index'.format(index_name) in not_inherited_fields: # an index column exists for this column index_name = '{}_index'.format(index_name) index_counter += 1 if index_counter == 1: column_conf['index'] = True elif index_counter > 1: column_conf['index'] = index_counter classdict.setdefault('__columns__', list()).append(column_conf) # do not add DynamicTable columns to init docval # add a specialized docval arg for __init__ for specifying targets for DTRs target_tables_dvarg = dict( name='target_tables', doc=('dict mapping DynamicTableRegion column name to the table that the DTR points to. The column is ' 'added to the table if it is not already present (i.e., when it is optional).'), type=dict, default=None ) cls._add_to_docval_args(docval_args, target_tables_dvarg) @classmethod def post_process(cls, classdict, bases, docval_args, spec): """Convert classdict['__columns__'] to tuple. :param classdict: The class dictionary. :param bases: The list of base classes. :param docval_args: The dict of docval arguments. :param spec: The spec for the container class to generate. """ # convert classdict['__columns__'] from list to tuple if present columns = classdict.get('__columns__') if columns is not None: classdict['__columns__'] = tuple(columns) @classmethod def set_init(cls, classdict, bases, docval_args, not_inherited_fields, name): if '__columns__' not in classdict: return base_init = classdict.get('__init__') if base_init is None: # pragma: no cover raise ValueError("Generated class dictionary is missing base __init__ method.") @docval(*docval_args) def __init__(self, **kwargs): base_init(self, **kwargs) # set target attribute on DTR target_tables = kwargs.get('target_tables') if target_tables: for colname, table in target_tables.items(): if colname not in self: # column has not yet been added (it is optional) column_conf = None for conf in self.__columns__: if conf['name'] == colname: column_conf = conf break if column_conf is None: raise ValueError("'%s' is not the name of a predefined column of table %s." % (colname, self)) if not column_conf.get('table', False): raise ValueError("Column '%s' must be a DynamicTableRegion to have a target table." 
% colname) self.add_column(name=column_conf['name'], description=column_conf['description'], index=column_conf.get('index', False), table=True) if isinstance(self[colname], VectorIndex): col = self[colname].target else: col = self[colname] col.table = table classdict['__init__'] = __init__ ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/multi.py0000644000655200065520000000143300000000000017263 0ustar00circlecicirclecifrom . import register_class from ..container import Container, Data, MultiContainerInterface from ..utils import docval, call_docval_func, popargs @register_class('SimpleMultiContainer') class SimpleMultiContainer(MultiContainerInterface): __clsconf__ = { 'attr': 'containers', 'type': (Container, Data), 'add': 'add_container', 'get': 'get_container', } @docval({'name': 'name', 'type': str, 'doc': 'the name of this container'}, {'name': 'containers', 'type': (list, tuple), 'default': None, 'doc': 'the Container or Data objects in this file'}) def __init__(self, **kwargs): containers = popargs('containers', kwargs) call_docval_func(super().__init__, kwargs) self.containers = containers ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/resources.py0000644000655200065520000004467400000000000020161 0ustar00circlecicircleciimport pandas as pd from . import register_class, EXP_NAMESPACE from ..container import Table, Row, Container, AbstractContainer from ..utils import docval, popargs class KeyTable(Table): """ A table for storing keys used to reference external resources """ __defaultname__ = 'keys' __columns__ = ( {'name': 'key', 'type': str, 'doc': 'The user key that maps to the resource term / registry symbol.'}, ) class Key(Row): """ A Row class for representing rows in the KeyTable """ __table__ = KeyTable class ResourceTable(Table): """ A table for storing names of ontology sources and their uri """ __defaultname__ = 'resources' __columns__ = ( {'name': 'resource', 'type': str, 'doc': 'The resource/registry that the term/symbol comes from.'}, {'name': 'resource_uri', 'type': str, 'doc': 'The URI for the resource term / registry symbol.'}, ) class Resource(Row): """ A Row class for representing rows in the ResourceTable """ __table__ = ResourceTable class EntityTable(Table): """ A table for storing the external resources a key refers to """ __defaultname__ = 'entities' __columns__ = ( {'name': 'keys_idx', 'type': (int, Key), 'doc': ('The index into the keys table for the user key that ' 'maps to the resource term / registry symbol.')}, {'name': 'resources_idx', 'type': (int, Resource), 'doc': 'The index into the ResourceTable.'}, {'name': 'entity_id', 'type': str, 'doc': 'The unique ID for the resource term / registry symbol.'}, {'name': 'entity_uri', 'type': str, 'doc': 'The URI for the resource term / registry symbol.'}, ) class Entity(Row): """ A Row class for representing rows in the EntityTable """ __table__ = EntityTable class ObjectTable(Table): """ A table for storing objects (i.e. 
Containers) that contain keys that refer to external resources """ __defaultname__ = 'objects' __columns__ = ( {'name': 'object_id', 'type': str, 'doc': 'The object ID for the Container/Data'}, {'name': 'field', 'type': str, 'doc': 'The field on the Container/Data that uses an external resource reference key'}, ) class Object(Row): """ A Row class for representing rows in the ObjectTable """ __table__ = ObjectTable class ObjectKeyTable(Table): """ A table for identifying which keys are used by which objects for referring to external resources """ __defaultname__ = 'object_keys' __columns__ = ( {'name': 'objects_idx', 'type': (int, Object), 'doc': 'the index into the objects table for the object that uses the key'}, {'name': 'keys_idx', 'type': (int, Key), 'doc': 'the index into the key table that is used to make an external resource reference'} ) class ObjectKey(Row): """ A Row class for representing rows in the ObjectKeyTable """ __table__ = ObjectKeyTable @register_class('ExternalResources', EXP_NAMESPACE) class ExternalResources(Container): """A table for mapping user terms (i.e. keys) to resource entities.""" __fields__ = ( {'name': 'keys', 'child': True}, {'name': 'resources', 'child': True}, {'name': 'objects', 'child': True}, {'name': 'object_keys', 'child': True}, {'name': 'entities', 'child': True}, ) @docval({'name': 'name', 'type': str, 'doc': 'the name of this ExternalResources container'}, {'name': 'keys', 'type': KeyTable, 'default': None, 'doc': 'the table storing user keys for referencing resources'}, {'name': 'resources', 'type': ResourceTable, 'default': None, 'doc': 'the table for storing names of resources and their uri'}, {'name': 'entities', 'type': EntityTable, 'default': None, 'doc': 'the table storing entity information'}, {'name': 'objects', 'type': ObjectTable, 'default': None, 'doc': 'the table storing object information'}, {'name': 'object_keys', 'type': ObjectKeyTable, 'default': None, 'doc': 'the table storing object-resource relationships'}) def __init__(self, **kwargs): name = popargs('name', kwargs) super().__init__(name) self.keys = kwargs['keys'] or KeyTable() self.resources = kwargs['resources'] or ResourceTable() self.entities = kwargs['entities'] or EntityTable() self.objects = kwargs['objects'] or ObjectTable() self.object_keys = kwargs['object_keys'] or ObjectKeyTable() @docval({'name': 'key_name', 'type': str, 'doc': 'the name of the key to be added'}) def _add_key(self, **kwargs): """ Add a key to be used for making references to external resources It is possible to use the same *key_name* to refer to different resources so long as the *key_name* is not used within the same object and field. To do so, this method must be called for the two different resources. The returned Key objects must be managed by the caller so as to be appropriately passed to subsequent calls to methods for storing information about the different resources. 
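For example (a minimal sketch; ``er`` is assumed to be an existing ExternalResources container and the key name is illustrative)::

    key1 = er._add_key('Mus musculus')  # Key object for one (object, field) pairing
    key2 = er._add_key('Mus musculus')  # a distinct Key object for a different pairing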
""" key = kwargs['key_name'] return Key(key, table=self.keys) @docval({'name': 'key', 'type': (str, Key), 'doc': 'the key to associate the entity with'}, {'name': 'resources_idx', 'type': (int, Resource), 'doc': 'the id of the resource'}, {'name': 'entity_id', 'type': str, 'doc': 'unique entity id'}, {'name': 'entity_uri', 'type': str, 'doc': 'the URI for the entity'}) def _add_entity(self, **kwargs): """ Add an entity that will be referenced to using the given key """ key = kwargs['key'] resources_idx = kwargs['resources_idx'] entity_id = kwargs['entity_id'] entity_uri = kwargs['entity_uri'] if not isinstance(key, Key): key = self._add_key(key) resource_entity = Entity(key, resources_idx, entity_id, entity_uri, table=self.entities) return resource_entity @docval({'name': 'resource', 'type': str, 'doc': 'the name of the ontology resource'}, {'name': 'uri', 'type': str, 'doc': 'uri associated with ontology resource'}) def _add_resource(self, **kwargs): """ Add resource name and uri to ResourceTable that will be referenced by the ResourceTable idx. """ resource_name = kwargs['resource'] uri = kwargs['uri'] resource = Resource(resource_name, uri, table=self.resources) return resource @docval({'name': 'container', 'type': (str, AbstractContainer), 'doc': 'the Container/Data object to add or the object_id for the Container/Data object to add'}, {'name': 'field', 'type': str, 'doc': 'the field on the Container to add'}) def _add_object(self, **kwargs): """ Add an object that references an external resource """ container, field = popargs('container', 'field', kwargs) if isinstance(container, AbstractContainer): container = container.object_id obj = Object(container, field, table=self.objects) return obj @docval({'name': 'obj', 'type': (int, Object), 'doc': 'the Object to that uses the Key'}, {'name': 'key', 'type': (int, Key), 'doc': 'the Key that the Object uses'}) def _add_object_key(self, **kwargs): """ Specify that an object (i.e. container and field) uses a key to reference an external resource """ obj, key = popargs('obj', 'key', kwargs) return ObjectKey(obj, key, table=self.object_keys) def _check_object_field(self, container, field): """ A helper function for checking if a container and field have been added. The container can be either an object_id string or a AbstractContainer. If the container and field have not been added, add the pair and return the corresponding Object. Otherwise, just return the Object. """ if isinstance(container, str): objecttable_idx = self.objects.which(object_id=container) else: objecttable_idx = self.objects.which(object_id=container.object_id) if len(objecttable_idx) > 0: field_idx = self.objects.which(field=field) objecttable_idx = list(set(objecttable_idx) & set(field_idx)) if len(objecttable_idx) == 1: return self.objects.row[objecttable_idx[0]] elif len(objecttable_idx) == 0: return self._add_object(container, field) else: raise ValueError("Found multiple instances of the same object_id and field in object table") @docval({'name': 'key_name', 'type': str, 'doc': 'the name of the key to get'}, {'name': 'container', 'type': (str, AbstractContainer), 'default': None, 'doc': ('the Container/Data object that uses the key or ' 'the object_id for the Container/Data object that uses the key')}, {'name': 'field', 'type': str, 'doc': 'the field of the Container that uses the key', 'default': None}) def get_key(self, **kwargs): """ Return a Key or a list of Key objects that correspond to the given key. 
If container and field are provided, the Key that corresponds to the given name of the key for the given container and field is returned. """ key_name, container, field = popargs('key_name', 'container', 'field', kwargs) key_idx_matches = self.keys.which(key=key_name) if container is not None and field is not None: # if same key is used multiple times, determine # which instance based on the Container object_field = self._check_object_field(container, field) for row_idx in self.object_keys.which(objects_idx=object_field.idx): key_idx = self.object_keys['keys_idx', row_idx] if key_idx in key_idx_matches: return self.keys.row[key_idx] raise ValueError("No key with name '%s' for container '%s' and field '%s'" % (key_name, container, field)) else: if len(key_idx_matches) == 0: # the key has never been used before raise ValueError("key '%s' does not exist" % key_name) elif len(key_idx_matches) > 1: return [self.keys.row[i] for i in key_idx_matches] else: return self.keys.row[key_idx_matches[0]] @docval({'name': 'resource_name', 'type': str, 'default': None}) def get_resource(self, **kwargs): """ Retrieve resource object with the given resource_name. """ resource_table_idx = self.resources.which(resource=kwargs['resource_name']) if len(resource_table_idx) == 0: # Resource hasn't been created msg = "No resource '%s' exists. Use _add_resource to create a new resource" % kwargs['resource_name'] raise ValueError(msg) else: return self.resources.row[resource_table_idx[0]] @docval({'name': 'container', 'type': (str, AbstractContainer), 'default': None, 'doc': ('the Container/Data object that uses the key or ' 'the object_id for the Container/Data object that uses the key')}, {'name': 'field', 'type': str, 'doc': 'the field of the Container/Data that uses the key', 'default': None}, {'name': 'key', 'type': (str, Key), 'default': None, 'doc': 'the name of the key or the Row object from the KeyTable for the key to add a resource for'}, {'name': 'resources_idx', 'type': Resource, 'doc': 'the resourcetable id', 'default': None}, {'name': 'resource_name', 'type': str, 'doc': 'the name of the resource to be created', 'default': None}, {'name': 'resource_uri', 'type': str, 'doc': 'the uri of the resource to be created', 'default': None}, {'name': 'entity_id', 'type': str, 'doc': 'the identifier for the entity at the resource', 'default': None}, {'name': 'entity_uri', 'type': str, 'doc': 'the URI for the identifier at the resource', 'default': None}) def add_ref(self, **kwargs): """ Add information about an external reference used in this file. It is possible to use the same name of the key to refer to different resources so long as the name of the key is not used within the same object and field. This method does not support such functionality by default. The different keys must be added separately using _add_key and passed to the *key* argument in separate calls of this method. If a resource with the same name already exists, then it will be used and the resource_uri will be ignored. 
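For example (a minimal sketch; ``er`` is assumed to be an ExternalResources container, and the table, key, identifiers, and URIs below are illustrative only)::

    er.add_ref(container=genotype_table, field='genotype', key='Rorb',
               resource_name='MGI Database', resource_uri='http://www.informatics.jax.org/',
               entity_id='MGI:1346434',
               entity_uri='http://www.informatics.jax.org/marker/MGI:1346434')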
""" container = kwargs['container'] field = kwargs['field'] key = kwargs['key'] entity_id = kwargs['entity_id'] entity_uri = kwargs['entity_uri'] add_entity = False object_field = self._check_object_field(container, field) if not isinstance(key, Key): key_idx_matches = self.keys.which(key=key) # if same key is used multiple times, determine # which instance based on the Container for row_idx in self.object_keys.which(objects_idx=object_field.idx): key_idx = self.object_keys['keys_idx', row_idx] if key_idx in key_idx_matches: msg = "Use Key Object when referencing an existing (container, field, key)" raise ValueError(msg) if not isinstance(key, Key): key = self._add_key(key) self._add_object_key(object_field, key) if kwargs['resources_idx'] is not None and kwargs['resource_name'] is None and kwargs['resource_uri'] is None: resource_table_idx = kwargs['resources_idx'] elif ( kwargs['resources_idx'] is not None and (kwargs['resource_name'] is not None or kwargs['resource_uri'] is not None)): msg = "Can't have resource_idx with resource_name or resource_uri." raise ValueError(msg) elif len(self.resources.which(resource=kwargs['resource_name'])) == 0: resource_name = kwargs['resource_name'] resource_uri = kwargs['resource_uri'] resource_table_idx = self._add_resource(resource_name, resource_uri) else: idx = self.resources.which(resource=kwargs['resource_name']) resource_table_idx = self.resources.row[idx[0]] if (resource_table_idx is not None and entity_id is not None and entity_uri is not None): add_entity = True elif not (resource_table_idx is None and entity_id is None and resource_uri is None): msg = ("Specify resource, entity_id, and entity_uri arguments." "All three are required to create a reference") raise ValueError(msg) if add_entity: entity = self._add_entity(key, resource_table_idx, entity_id, entity_uri) return key, resource_table_idx, entity @docval({'name': 'container', 'type': (str, AbstractContainer), 'doc': 'the Container/data object that is linked to resources/entities', 'default': None}, {'name': 'field', 'type': str, 'doc': 'the field of the Container', 'default': None}) def get_object_resources(self, **kwargs): """ Get all entities/resources associated with an object """ container = kwargs['container'] field = kwargs['field'] keys = [] entities = [] if container is not None and field is not None: object_field = self._check_object_field(container, field) # Find all keys associated with the object for row_idx in self.object_keys.which(objects_idx=object_field.idx): keys.append(self.object_keys['keys_idx', row_idx]) # Find all the entities/resources for each key. for key_idx in keys: entity_idx = self.entities.which(keys_idx=key_idx) entities.append(self.entities.__getitem__(entity_idx[0])) df = pd.DataFrame(entities, columns=['keys_idx', 'resource_idx', 'entity_id', 'entity_uri']) return df @docval({'name': 'keys', 'type': (list, Key), 'default': None, 'doc': 'the Key(s) to get external resource data for'}, rtype=pd.DataFrame, returns='a DataFrame with keys and external resource data') def get_keys(self, **kwargs): """ Return a DataFrame with information about keys used to make references to external resources. 
The DataFrame will contain the following columns: - *key_name*: the key that will be used for referencing an external resource - *resources_idx*: the index for the resourcetable - *entity_id*: the index for the entity at the external resource - *entity_uri*: the URI for the entity at the external resource It is possible to use the same *key_name* to refer to different resources so long as the *key_name* is not used within the same object and field. This method does not support such functionality by default. To select specific keys, use the *keys* argument to pass in the Key object(s) representing the desired keys. Note, if the same *key_name* is used more than once, multiple calls to this method with different Key objects will be required to keep the different instances separate. If a single call is made, it is left up to the caller to distinguish the different instances. """ keys = popargs('keys', kwargs) if keys is None: keys = [self.keys.row[i] for i in range(len(self.keys))] else: if not isinstance(keys, list): keys = [keys] data = list() for key in keys: rsc_ids = self.entities.which(keys_idx=key.idx) for rsc_id in rsc_ids: rsc_row = self.entities.row[rsc_id].todict() rsc_row.pop('keys_idx') rsc_row['key_name'] = key.key data.append(rsc_row) return pd.DataFrame(data=data, columns=['key_name', 'resources_idx', 'entity_id', 'entity_uri']) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/sparse.py0000644000655200065520000000540300000000000017427 0ustar00circlecicircleciimport h5py import numpy as np import scipy.sparse as sps from . import register_class from ..container import Container from ..utils import docval, getargs, call_docval_func, to_uint_array @register_class('CSRMatrix') class CSRMatrix(Container): @docval({'name': 'data', 'type': (sps.csr_matrix, np.ndarray, h5py.Dataset), 'doc': 'the data to use for this CSRMatrix or CSR data array.' 'If passing CSR data array, *indices*, *indptr*, and *shape* must also be provided'}, {'name': 'indices', 'type': (np.ndarray, h5py.Dataset), 'doc': 'CSR index array', 'default': None}, {'name': 'indptr', 'type': (np.ndarray, h5py.Dataset), 'doc': 'CSR index pointer array', 'default': None}, {'name': 'shape', 'type': (list, tuple, np.ndarray), 'doc': 'the shape of the matrix', 'default': None}, {'name': 'name', 'type': str, 'doc': 'the name to use for this when storing', 'default': 'csr_matrix'}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) data = getargs('data', kwargs) if isinstance(data, (np.ndarray, h5py.Dataset)): if data.ndim == 2: data = sps.csr_matrix(data) elif data.ndim < 2: indptr, indices, shape = getargs('indptr', 'indices', 'shape', kwargs) if any(_ is None for _ in (indptr, indices, shape)): raise ValueError("Must specify 'indptr', 'indices', and 'shape' arguments when passing data array.") indptr = self.__check_arr(indptr, 'indptr') indices = self.__check_arr(indices, 'indices') shape = self.__check_arr(shape, 'shape') if len(shape) != 2: raise ValueError("'shape' argument must specify two and only two dimensions.") data = sps.csr_matrix((data, indices, indptr), shape=shape) else: raise ValueError("'data' argument cannot be ndarray of dimensionality > 2.") self.__data = data @staticmethod def __check_arr(ar, arg): try: ar = to_uint_array(ar) except ValueError as ve: raise ValueError("Cannot convert '%s' to an array of unsigned integers." 
% arg) from ve if ar.ndim != 1: raise ValueError("'%s' must be a 1D array of unsigned integers." % arg) return ar def __getattr__(self, val): # NOTE: this provides access to self.data, self.indices, self.indptr, self.shape attr = getattr(self.__data, val) if val in ('indices', 'indptr', 'shape'): # needed because sps.csr_matrix may contain int arrays for these attr = to_uint_array(attr) return attr def to_spmat(self): return self.__data ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/common/table.py0000644000655200065520000021013600000000000017222 0ustar00circlecicircleci""" Collection of Container classes for interacting with data types related to the storage and use of dynamic data tables as part of the hdmf-common schema """ import re from collections import OrderedDict from typing import NamedTuple, Union from warnings import warn import numpy as np import pandas as pd from . import register_class, EXP_NAMESPACE from ..container import Container, Data from ..data_utils import DataIO, AbstractDataChunkIterator from ..utils import docval, getargs, ExtenderMeta, call_docval_func, popargs, pystr @register_class('VectorData') class VectorData(Data): """ A n-dimensional dataset representing a column of a DynamicTable. If used without an accompanying VectorIndex, first dimension is along the rows of the DynamicTable and each step along the first dimension is a cell of the larger table. VectorData can also be used to represent a ragged array if paired with a VectorIndex. This allows for storing arrays of varying length in a single cell of the DynamicTable by indexing into this VectorData. The first vector is at VectorData[0:VectorIndex(0)+1]. The second vector is at VectorData[VectorIndex(0)+1:VectorIndex(1)+1], and so on. """ __fields__ = ("description",) @docval({'name': 'name', 'type': str, 'doc': 'the name of this VectorData'}, {'name': 'description', 'type': str, 'doc': 'a description for this column'}, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a dataset where the first dimension is a concatenation of multiple vectors', 'default': list()}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) self.description = getargs('description', kwargs) @docval({'name': 'val', 'type': None, 'doc': 'the value to add to this column'}) def add_row(self, **kwargs): """Append a data value to this VectorData column""" val = getargs('val', kwargs) self.append(val) def get(self, key, **kwargs): """ Retrieve elements from this VectorData :param key: Selection of the elements :param **kwargs: Ignored """ return super().get(key) def extend(self, ar, **kwargs): """Add all elements of the iterable arg to the end of this VectorData. Each subclass of VectorData should have its own extend method to ensure functionality and efficiency. :param arg: The iterable to add to the end of this VectorData """ ################################################################################# # Each subclass of VectorData should have its own extend method to ensure # functionality AND efficiency of the extend operation. However, because currently # they do not all have one of these methods, the only way to ensure functionality # is with calls to add_row. Because that is inefficient for basic VectorData, # this check is added to ensure we always call extend on a basic VectorData. 
if self.__class__.__mro__[0] == VectorData: super().extend(ar) else: for i in ar: self.add_row(i, **kwargs) @register_class('VectorIndex') class VectorIndex(VectorData): """ When paired with a VectorData, this allows for storing arrays of varying length in a single cell of the DynamicTable by indexing into this VectorData. The first vector is at VectorData[0:VectorIndex(0)+1]. The second vector is at VectorData[VectorIndex(0)+1:VectorIndex(1)+1], and so on. """ __fields__ = ("target",) @docval({'name': 'name', 'type': str, 'doc': 'the name of this VectorIndex'}, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a 1D dataset containing indexes that apply to VectorData object'}, {'name': 'target', 'type': VectorData, 'doc': 'the target dataset that this index applies to'}) def __init__(self, **kwargs): target = getargs('target', kwargs) kwargs['description'] = "Index for VectorData '%s'" % target.name call_docval_func(super().__init__, kwargs) self.target = target self.__uint = np.uint8 self.__maxval = 255 if isinstance(self.data, (list, np.ndarray)): if len(self.data) > 0: self.__check_precision(len(self.target)) # adjust precision for types that we can adjust precision for self.__adjust_precision(self.__uint) def add_vector(self, arg, **kwargs): """ Add the given data value to the target VectorData and append the corresponding index to this VectorIndex :param arg: The data value to be added to self.target """ if isinstance(self.target, VectorIndex): for a in arg: self.target.add_vector(a) else: self.target.extend(arg, **kwargs) self.append(self.__check_precision(len(self.target))) def __check_precision(self, idx): """ Check precision of current dataset and, if necessary, adjust precision to accommodate new value. Returns: unsigned integer encoding of idx """ if idx > self.__maxval: while idx > self.__maxval: nbits = (np.log2(self.__maxval + 1) * 2) # 8->16, 16->32, 32->64 if nbits == 128: # pragma: no cover msg = ('Cannot store more than 18446744073709551615 elements in a VectorData. Largest dtype ' 'allowed for VectorIndex is uint64.') raise ValueError(msg) self.__maxval = 2 ** nbits - 1 self.__uint = np.dtype('uint%d' % nbits).type self.__adjust_precision(self.__uint) return self.__uint(idx) def __adjust_precision(self, uint): """ Adjust precision of data to specificied unsigned integer precision. """ if isinstance(self.data, list): for i in range(len(self.data)): self.data[i] = uint(self.data[i]) elif isinstance(self.data, np.ndarray): # use self._Data__data to work around restriction on resetting self.data self._Data__data = self.data.astype(uint) else: raise ValueError("cannot adjust precision of type %s to %s", (type(self.data), uint)) def add_row(self, arg, **kwargs): """ Convenience function. 
Same as :py:func:`add_vector` """ self.add_vector(arg, **kwargs) def __getitem_helper(self, arg, **kwargs): """ Internal helper function used by __getitem__ to retrieve a data value from self.target :param arg: Integer index into this VectorIndex indicating the element we want to retrieve from the target :param kwargs: any additional arguments to *get* method of the self.target VectorData :return: Scalar or list of values retrieved """ start = 0 if arg == 0 else self.data[arg - 1] end = self.data[arg] return self.target.get(slice(start, end), **kwargs) def __getitem__(self, arg): """ Select elements in this VectorIndex and retrieve the corresponding data from the self.target VectorData :param arg: slice or integer index indicating the elements we want to select in this VectorIndex :return: Scalar or list of values retrieved """ return self.get(arg) def get(self, arg, **kwargs): """ Select elements in this VectorIndex and retrieve the corresponding data from the self.target VectorData :param arg: slice or integer index indicating the elements we want to select in this VectorIndex :param kwargs: any additional arguments to *get* method of the self.target VectorData :return: Scalar or list of values retrieved """ if np.isscalar(arg): return self.__getitem_helper(arg, **kwargs) else: if isinstance(arg, slice): indices = list(range(*arg.indices(len(self.data)))) else: if isinstance(arg[0], bool): arg = np.where(arg)[0] indices = arg ret = list() for i in indices: ret.append(self.__getitem_helper(i, **kwargs)) return ret @register_class('ElementIdentifiers') class ElementIdentifiers(Data): """ Data container with a list of unique identifiers for values within a dataset, e.g. rows of a DynamicTable. """ @docval({'name': 'name', 'type': str, 'doc': 'the name of this ElementIdentifiers'}, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a 1D dataset containing identifiers', 'default': list()}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) @docval({'name': 'other', 'type': (Data, np.ndarray, list, tuple, int), 'doc': 'List of ids to search for in this ElementIdentifer object'}, rtype=np.ndarray, returns='Array with the list of indices where the elements in the list where found.' 'Note, the elements in the returned list are ordered in increasing index' 'of the found elements, rather than in the order in which the elements' 'where given for the search. Also the length of the result may be different from the length' 'of the input array. E.g., if our ids are [1,2,3] and we are search for [3,1,5] the ' 'result would be [0,2] and NOT [2,0,None]') def __eq__(self, other): """ Given a list of ids return the indices in the ElementIdentifiers array where the indices are found. """ # Determine the ids we want to find search_ids = other if not isinstance(other, Data) else other.data if isinstance(search_ids, int): search_ids = [search_ids] # Find all matching locations return np.in1d(self.data, search_ids).nonzero()[0] @register_class('DynamicTable') class DynamicTable(Container): r""" A column-based table. Columns are defined by the argument *columns*. This argument must be a list/tuple of :class:`~hdmf.common.table.VectorData` and :class:`~hdmf.common.table.VectorIndex` objects or a list/tuple of dicts containing the keys ``name`` and ``description`` that provide the name and description of each column in the table. Additionally, the keys ``index``, ``table``, ``enum`` can be used for specifying additional structure to the table columns. 
Setting the key ``index`` to ``True`` can be used to indicate that the :class:`~hdmf.common.table.VectorData` column will store a ragged array (i.e. will be accompanied with a :class:`~hdmf.common.table.VectorIndex`). Setting the key ``table`` to ``True`` can be used to indicate that the column will store regions to another DynamicTable. Setting the key ``enum`` to ``True`` can be used to indicate that the column data will come from a fixed set of values. Columns in DynamicTable subclasses can be statically defined by specifying the class attribute *\_\_columns\_\_*, rather than specifying them at runtime at the instance level. This is useful for defining a table structure that will get reused. The requirements for *\_\_columns\_\_* are the same as the requirements described above for specifying table columns with the *columns* argument to the DynamicTable constructor. """ __fields__ = ( {'name': 'id', 'child': True}, {'name': 'columns', 'child': True}, 'colnames', 'description' ) __columns__ = tuple() @ExtenderMeta.pre_init def __gather_columns(cls, name, bases, classdict): r""" Gather columns from the *\_\_columns\_\_* class attribute and add them to the class. This classmethod will be called during class declaration in the metaclass to automatically include all columns declared in subclasses. """ if not isinstance(cls.__columns__, tuple): msg = "'__columns__' must be of type tuple, found %s" % type(cls.__columns__) raise TypeError(msg) if (len(bases) and 'DynamicTable' in globals() and issubclass(bases[-1], Container) and bases[-1].__columns__ is not cls.__columns__): new_columns = list(cls.__columns__) new_columns[0:0] = bases[-1].__columns__ # prepend superclass columns to new_columns cls.__columns__ = tuple(new_columns) @docval({'name': 'name', 'type': str, 'doc': 'the name of this table'}, # noqa: C901 {'name': 'description', 'type': str, 'doc': 'a description of what is in this table'}, {'name': 'id', 'type': ('array_data', 'data', ElementIdentifiers), 'doc': 'the identifiers for this table', 'default': None}, {'name': 'columns', 'type': (tuple, list), 'doc': 'the columns in this table', 'default': None}, {'name': 'colnames', 'type': 'array_data', 'doc': 'the ordered names of the columns in this table. columns must also be provided.', 'default': None}) def __init__(self, **kwargs): # noqa: C901 id, columns, desc, colnames = popargs('id', 'columns', 'description', 'colnames', kwargs) call_docval_func(super().__init__, kwargs) self.description = desc # hold names of optional columns that are defined in __columns__ that are not yet initialized # map name to column specification self.__uninit_cols = dict() # All tables must have ElementIdentifiers (i.e. 
a primary key column) # Here, we figure out what to do for that if id is not None: if not isinstance(id, ElementIdentifiers): id = ElementIdentifiers('id', data=id) else: id = ElementIdentifiers('id') if columns is not None and len(columns) > 0: # If columns have been passed in, check them over and process accordingly if isinstance(columns[0], dict): columns = self.__build_columns(columns) elif not all(isinstance(c, VectorData) for c in columns): raise ValueError("'columns' must be a list of dict, VectorData, DynamicTableRegion, or VectorIndex") all_names = [c.name for c in columns] if len(all_names) != len(set(all_names)): raise ValueError("'columns' contains columns with duplicate names: %s" % all_names) all_targets = [c.target.name for c in columns if isinstance(c, VectorIndex)] if len(all_targets) != len(set(all_targets)): raise ValueError("'columns' contains index columns with the same target: %s" % all_targets) # TODO: check columns against __columns__ # mismatches should raise an error (e.g., a VectorData cannot be passed in with the same name as a # prespecified table region column) # check column lengths against each other and id length # set ids if non-zero cols are provided and ids is empty colset = {c.name: c for c in columns} for c in columns: # remove all VectorData objects that have an associated VectorIndex from colset if isinstance(c, VectorIndex): if c.target.name in colset: colset.pop(c.target.name) else: raise ValueError("Found VectorIndex '%s' but not its target '%s'" % (c.name, c.target.name)) elif isinstance(c, EnumData): if c.elements.name in colset: colset.pop(c.elements.name) _data = c.data if isinstance(_data, DataIO): _data = _data.data if isinstance(_data, AbstractDataChunkIterator): colset.pop(c.name, None) lens = [len(c) for c in colset.values()] if not all(i == lens[0] for i in lens): raise ValueError("columns must be the same length") if len(lens) > 0 and lens[0] != len(id): # the first part of this conditional is needed in the # event that all columns are AbstractDataChunkIterators if len(id) > 0: raise ValueError("must provide same number of ids as length of columns") else: # set ids to: 0 to length of columns - 1 id.data.extend(range(lens[0])) self.id = id # NOTE: self.colnames and self.columns are always tuples # if kwarg colnames is an h5dataset, self.colnames is still a tuple if colnames is None or len(colnames) == 0: if columns is None: # make placeholder for columns if nothing was given self.colnames = tuple() self.columns = tuple() else: # Figure out column names if columns were given tmp = OrderedDict() skip = set() for col in columns: if col.name in skip: continue if isinstance(col, VectorIndex): continue if isinstance(col, EnumData): skip.add(col.elements.name) tmp.pop(col.elements.name, None) tmp[col.name] = None self.colnames = tuple(tmp) self.columns = tuple(columns) else: # Calculate the order of column names if columns is None: raise ValueError("Must supply 'columns' if specifying 'colnames'") else: # order the columns according to the column names, which does not include indices self.colnames = tuple(pystr(c) for c in colnames) col_dict = {col.name: col for col in columns} # map from vectordata name to list of vectorindex objects where target of last vectorindex is vectordata indices = dict() # determine which columns are indexed by another column for col in columns: if isinstance(col, VectorIndex): # loop through nested indices to get to non-index column tmp_indices = [col] curr_col = col while isinstance(curr_col.target, VectorIndex): 
curr_col = curr_col.target tmp_indices.append(curr_col) # make sure the indices values has the full index chain, so replace existing value if it is # shorter if len(tmp_indices) > len(indices.get(curr_col.target.name, [])): indices[curr_col.target.name] = tmp_indices elif isinstance(col, EnumData): # EnumData is the indexing column, so it should go first if col.name not in indices: indices[col.name] = [col] # EnumData is the indexing object col_dict[col.name] = col.elements # EnumData.elements is the column with values else: if col.name in indices: continue indices[col.name] = [] # put columns in order of colnames, with indices before the target vectordata tmp = [] for name in self.colnames: tmp.extend(indices[name]) tmp.append(col_dict[name]) self.columns = tuple(tmp) # to make generating DataFrames and Series easier col_dict = dict() self.__indices = dict() for col in self.columns: if isinstance(col, VectorIndex): # if index has already been added because it is part of a nested index chain, ignore this column if col.name in self.__indices: continue self.__indices[col.name] = col # loop through nested indices to get to non-index column curr_col = col self.__set_table_attr(curr_col) while isinstance(curr_col.target, VectorIndex): curr_col = curr_col.target # check if index has been added. if not, add it if not hasattr(self, curr_col.name): self.__set_table_attr(curr_col) self.__indices[curr_col.name] = col # use target vectordata name at end of indexing chain as key to get to the top level index col_dict[curr_col.target.name] = col if not hasattr(self, curr_col.target.name): self.__set_table_attr(curr_col.target) else: # this is a regular VectorData or EnumData # if we added this column using its index, ignore this column if col.name in col_dict: continue else: col_dict[col.name] = col self.__set_table_attr(col) self.__df_cols = [self.id] + [col_dict[name] for name in self.colnames] # self.__colids maps the column name to an index starting at 1 self.__colids = {name: i + 1 for i, name in enumerate(self.colnames)} self._init_class_columns() def __set_table_attr(self, col): if hasattr(self, col.name) and col.name not in self.__uninit_cols: msg = ("An attribute '%s' already exists on %s '%s' so this column cannot be accessed as an attribute, " "e.g., table.%s; it can only be accessed using other methods, e.g., table['%s']." % (col.name, self.__class__.__name__, self.name, col.name, col.name)) warn(msg) else: setattr(self, col.name, col) __reserved_colspec_keys = ['name', 'description', 'index', 'table', 'required', 'class'] def _init_class_columns(self): """ Process all predefined columns specified in class variable __columns__. Optional columns are not tracked but not added. 
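Predefined columns are declared on the class, for example (a minimal sketch; the subclass and column names are illustrative)::

    class MyTable(DynamicTable):
        __columns__ = (
            {'name': 'score', 'description': 'a required 1-to-1 column', 'required': True},
            {'name': 'tags', 'description': 'an optional ragged column', 'index': True},
        )

Required columns are added here immediately; optional ones (like ``tags`` above) are only added once a value is supplied to :py:func:`add_row` or :py:func:`add_column` is called for them.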
""" for col in self.__columns__: if col['name'] not in self.__colids: # if column has not been added in __init__ if col.get('required', False): self.add_column(name=col['name'], description=col['description'], index=col.get('index', False), table=col.get('table', False), col_cls=col.get('class', VectorData), # Pass through extra kwargs for add_column that subclasses may have added **{k: col[k] for k in col.keys() if k not in DynamicTable.__reserved_colspec_keys}) else: # track the not yet initialized optional predefined columns self.__uninit_cols[col['name']] = col # set the table attributes for not yet init optional predefined columns setattr(self, col['name'], None) index = col.get('index', False) if index is not False: if index is True: index = 1 if isinstance(index, int): assert index > 0, ValueError("integer index value must be greater than 0") index_name = col['name'] for i in range(index): index_name = index_name + '_index' self.__uninit_cols[index_name] = col setattr(self, index_name, None) if col.get('enum', False): self.__uninit_cols[col['name'] + '_elements'] = col setattr(self, col['name'] + '_elements', None) @staticmethod def __build_columns(columns, df=None): """ Build column objects according to specifications """ tmp = list() for d in columns: name = d['name'] desc = d.get('description', 'no description') col_cls = d.get('class', VectorData) data = None if df is not None: data = list(df[name].values) index = d.get('index', False) if index is not False: if isinstance(index, int) and index > 1: raise ValueError('Creating nested index columns using this method is not yet supported. Use ' 'add_column or define the columns using __columns__ instead.') index_data = None if data is not None: index_data = [len(data[0])] for i in range(1, len(data)): index_data.append(len(data[i]) + index_data[i - 1]) # assume data came in through a DataFrame, so we need # to concatenate it tmp_data = list() for d in data: tmp_data.extend(d) data = tmp_data vdata = col_cls(name, desc, data=data) vindex = VectorIndex("%s_index" % name, index_data, target=vdata) tmp.append(vindex) tmp.append(vdata) elif d.get('enum', False): # EnumData is the indexing column, so it should go first if data is not None: elements, data = np.unique(data, return_inverse=True) tmp.append(EnumData(name, desc, data=data, elements=elements)) else: tmp.append(EnumData(name, desc, data=data)) # EnumData handles constructing the VectorData object that contains EnumData.elements # --> use this functionality (rather than creating here) for consistency and less code/complexity tmp.append(tmp[-1].elements) else: if data is None: data = list() if d.get('table', False): col_cls = DynamicTableRegion tmp.append(col_cls(name, desc, data=data)) return tmp def __len__(self): """Number of rows in the table""" return len(self.id) @docval({'name': 'data', 'type': dict, 'doc': 'the data to put in this row', 'default': None}, {'name': 'id', 'type': int, 'doc': 'the ID for the row', 'default': None}, {'name': 'enforce_unique_id', 'type': bool, 'doc': 'enforce that the id in the table must be unique', 'default': False}, allow_extra=True) def add_row(self, **kwargs): """ Add a row to the table. If *id* is not provided, it will auto-increment. 
""" data, row_id, enforce_unique_id = popargs('data', 'id', 'enforce_unique_id', kwargs) data = data if data is not None else kwargs extra_columns = set(list(data.keys())) - set(list(self.__colids.keys())) missing_columns = set(list(self.__colids.keys())) - set(list(data.keys())) # check to see if any of the extra columns just need to be added if extra_columns: for col in self.__columns__: if col['name'] in extra_columns: if data[col['name']] is not None: self.add_column(col['name'], col['description'], index=col.get('index', False), table=col.get('table', False), enum=col.get('enum', False), col_cls=col.get('class', VectorData), # Pass through extra keyword arguments for add_column that # subclasses may have added **{k: col[k] for k in col.keys() if k not in DynamicTable.__reserved_colspec_keys}) extra_columns.remove(col['name']) if extra_columns or missing_columns: raise ValueError( '\n'.join([ 'row data keys don\'t match available columns', 'you supplied {} extra keys: {}'.format(len(extra_columns), extra_columns), 'and were missing {} keys: {}'.format(len(missing_columns), missing_columns) ]) ) if row_id is None: row_id = data.pop('id', None) if row_id is None: row_id = len(self) if enforce_unique_id: if row_id in self.id: raise ValueError("id %i already in the table" % row_id) self.id.append(row_id) for colname, colnum in self.__colids.items(): if colname not in data: raise ValueError("column '%s' missing" % colname) c = self.__df_cols[colnum] if isinstance(c, VectorIndex): c.add_vector(data[colname]) else: c.add_row(data[colname]) def __eq__(self, other): """Compare if the two DynamicTables contain the same data. First this returns False if the other DynamicTable has a different name or description. Then, this table and the other table are converted to pandas dataframes and the equality of the two tables is returned. :param other: DynamicTable to compare to :return: Bool indicating whether the two DynamicTables contain the same data """ if other is self: return True if not isinstance(other, DynamicTable): return False if self.name != other.name or self.description != other.description: return False return self.to_dataframe().equals(other.to_dataframe()) @docval({'name': 'name', 'type': str, 'doc': 'the name of this VectorData'}, # noqa: C901 {'name': 'description', 'type': str, 'doc': 'a description for this column'}, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a dataset where the first dimension is a concatenation of multiple vectors', 'default': list()}, {'name': 'table', 'type': (bool, 'DynamicTable'), 'doc': 'whether or not this is a table region or the table the region applies to', 'default': False}, {'name': 'index', 'type': (bool, VectorIndex, 'array_data', int), 'doc': 'False (default): do not generate a VectorIndex \n' 'True: generate one empty VectorIndex \n' 'VectorIndex: Use the supplied VectorIndex \n' 'array-like of ints: Create a VectorIndex and use these values as the data \n' 'int: Recursively create `n` VectorIndex objects for a multi-ragged array \n', 'default': False}, {'name': 'enum', 'type': (bool, 'array_data'), 'default': False, 'doc': ('whether or not this column contains data from a fixed set of elements')}, {'name': 'col_cls', 'type': type, 'default': VectorData, 'doc': ('class to use to represent the column data. If table=True, this field is ignored and a ' 'DynamicTableRegion object is used. If enum=True, this field is ignored and a EnumData ' 'object is used.')}, ) def add_column(self, **kwargs): # noqa: C901 """ Add a column to this table. 
If data is provided, it must contain the same number of rows as the current state of the table. :raises ValueError: if the column has already been added to the table """ name, data = getargs('name', 'data', kwargs) index, table, enum, col_cls = popargs('index', 'table', 'enum', 'col_cls', kwargs) if isinstance(index, VectorIndex): warn("Passing a VectorIndex in for index may lead to unexpected behavior. This functionality will be " "deprecated in a future version of HDMF.", FutureWarning) if name in self.__colids: # column has already been added msg = "column '%s' already exists in %s '%s'" % (name, self.__class__.__name__, self.name) raise ValueError(msg) if name in self.__uninit_cols: # column is a predefined optional column from the spec # check the given values against the predefined optional column spec. if they do not match, raise a warning # and ignore the given arguments. users should not be able to override these values table_bool = table or not isinstance(table, bool) spec_table = self.__uninit_cols[name].get('table', False) if table_bool != spec_table: msg = ("Column '%s' is predefined in %s with table=%s which does not match the entered " "table argument. The predefined table spec will be ignored. " "Please ensure the new column complies with the spec. " "This will raise an error in a future version of HDMF." % (name, self.__class__.__name__, spec_table)) warn(msg) index_bool = index or not isinstance(index, bool) spec_index = self.__uninit_cols[name].get('index', False) if index_bool != spec_index: msg = ("Column '%s' is predefined in %s with index=%s which does not match the entered " "index argument. The predefined index spec will be ignored. " "Please ensure the new column complies with the spec. " "This will raise an error in a future version of HDMF." % (name, self.__class__.__name__, spec_index)) warn(msg) spec_col_cls = self.__uninit_cols[name].get('class', VectorData) if col_cls != spec_col_cls: msg = ("Column '%s' is predefined in %s with class=%s which does not match the entered " "col_cls argument. The predefined class spec will be ignored. " "Please ensure the new column complies with the spec. " "This will raise an error in a future version of HDMF." 
% (name, self.__class__.__name__, spec_col_cls)) warn(msg) ckwargs = dict(kwargs) # Add table if it's been specified if table and enum: raise ValueError("column '%s' cannot be both a table region " "and come from an enumerable set of elements" % name) if table is not False: col_cls = DynamicTableRegion if isinstance(table, DynamicTable): ckwargs['table'] = table if enum is not False: col_cls = EnumData if isinstance(enum, (list, tuple, np.ndarray, VectorData)): ckwargs['elements'] = enum col = col_cls(**ckwargs) col.parent = self columns = [col] self.__set_table_attr(col) if col in self.__uninit_cols: self.__uninit_cols.pop(col) if col_cls is EnumData: columns.append(col.elements) col.elements.parent = self # Add index if it's been specified if index is not False: if isinstance(index, VectorIndex): col_index = index self.__add_column_index_helper(col_index) elif isinstance(index, bool): # make empty VectorIndex if len(col) > 0: raise ValueError("cannot pass empty index with non-empty data to index") col_index = VectorIndex(name + "_index", list(), col) self.__add_column_index_helper(col_index) elif isinstance(index, int): assert index > 0, ValueError("integer index value must be greater than 0") assert len(col) == 0, ValueError("cannot pass empty index with non-empty data to index") index_name = name for i in range(index): index_name = index_name + "_index" col_index = VectorIndex(index_name, list(), col) self.__add_column_index_helper(col_index) if i < index - 1: columns.insert(0, col_index) col = col_index else: # make VectorIndex with supplied data if len(col) == 0: raise ValueError("cannot pass non-empty index with empty data to index") col_index = VectorIndex(name + "_index", index, col) self.__add_column_index_helper(col_index) columns.insert(0, col_index) col = col_index if len(col) != len(self.id): raise ValueError("column must have the same number of rows as 'id'") self.__colids[name] = len(self.__df_cols) self.fields['colnames'] = tuple(list(self.colnames) + [name]) self.fields['columns'] = tuple(list(self.columns) + columns) self.__df_cols.append(col) def __add_column_index_helper(self, col_index): if not isinstance(col_index.parent, Container): col_index.parent = self # else, the ObjectMapper will create a link from self (parent) to col_index (child with existing parent) self.__indices[col_index.name] = col_index self.__set_table_attr(col_index) if col_index in self.__uninit_cols: self.__uninit_cols.pop(col_index) @docval({'name': 'name', 'type': str, 'doc': 'the name of the DynamicTableRegion object'}, {'name': 'region', 'type': (slice, list, tuple), 'doc': 'the indices of the table'}, {'name': 'description', 'type': str, 'doc': 'a brief description of what the region is'}) def create_region(self, **kwargs): """ Create a DynamicTableRegion selecting a region (i.e., rows) in this DynamicTable. 
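For example (a minimal sketch; the name and row indices are illustrative)::

    region = table.create_region(name='high_quality', region=[0, 2, 3],
                                 description='rows passing the quality threshold')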
:raises: IndexError if the provided region contains invalid indices """ region = getargs('region', kwargs) if isinstance(region, slice): if (region.start is not None and region.start < 0) or (region.stop is not None and region.stop > len(self)): msg = 'region slice %s is out of range for this DynamicTable of length %d' % (str(region), len(self)) raise IndexError(msg) region = list(range(*region.indices(len(self)))) else: for idx in region: if idx < 0 or idx >= len(self): raise IndexError('The index ' + str(idx) + ' is out of range for this DynamicTable of length ' + str(len(self))) desc = getargs('description', kwargs) name = getargs('name', kwargs) return DynamicTableRegion(name, region, desc, self) def __getitem__(self, key): ret = self.get(key) if ret is None: raise KeyError(key) return ret def get(self, key, default=None, df=True, index=True, **kwargs): """Select a subset from the table. If the table includes a DynamicTableRegion column, then by default, the index/indices of the DynamicTableRegion will be returned. If ``df=True`` and ``index=False``, then the returned pandas DataFrame will contain a nested DataFrame in each row of the DynamicTableRegion column. If ``df=False`` and ``index=True``, then a list of lists will be returned where the list containing the DynamicTableRegion column contains the indices of the DynamicTableRegion. Note that in this case, the DynamicTable referenced by the DynamicTableRegion can be accessed through the ``table`` attribute of the DynamicTableRegion object. ``df=False`` and ``index=False`` is not yet supported. :param key: Key defining which elements of the table to select. This may be one of the following: 1) string with the name of the column to select 2) a tuple consisting of (int, str) where the int selects the row and the string identifies the column to select by name 3) int, list of ints, array, or slice selecting a set of full rows in the table. If an int is used, then scalars are returned for each column that has a single value. If a list, array, or slice is used and df=False, then lists are returned for each column, even if the list, array, or slice resolves to a single row. 
:return: 1) If key is a string, then return the VectorData object representing the column with the string name 2) If key is a tuple of (int, str), then return the scalar value of the selected cell 3) If key is an int, list, np.ndarray, or slice, then return pandas.DataFrame or lists consisting of one or more rows :raises: KeyError """ ret = None if not df and not index: # returning nested lists of lists for DTRs and ragged DTRs is complicated and not yet supported raise ValueError('DynamicTable.get() with df=False and index=False is not yet supported.') if isinstance(key, tuple): # index by row and column --> return specific cell arg1 = key[0] arg2 = key[1] if isinstance(arg2, str): arg2 = self.__colids[arg2] ret = self.__df_cols[arg2][arg1] elif isinstance(key, str): # index by one string --> return column if key == 'id': return self.id elif key in self.__colids: ret = self.__df_cols[self.__colids[key]] elif key in self.__indices: ret = self.__indices[key] else: return default else: # index by int, list, np.ndarray, or slice --> # return pandas Dataframe or lists consisting of one or more rows sel = self.__get_selection_as_dict(key, df, index, **kwargs) if df: # reformat objects to fit into a pandas DataFrame if np.isscalar(key): ret = self.__get_selection_as_df_single_row(sel) else: ret = self.__get_selection_as_df(sel) else: ret = list(sel.values()) return ret def __get_selection_as_dict(self, arg, df, index, exclude=None, **kwargs): """Return a dict mapping column names to values (lists/arrays or dataframes) for the given selection. Uses each column's get() method, passing kwargs as necessary. :param arg: key passed to get() to return one or more rows :type arg: int, list, np.ndarray, or slice """ if not (np.issubdtype(type(arg), np.integer) or isinstance(arg, (slice, list, np.ndarray))): raise KeyError("Key type not supported by DynamicTable %s" % str(type(arg))) if isinstance(arg, np.ndarray) and arg.ndim != 1: raise ValueError("Cannot index DynamicTable with multiple dimensions") if exclude is None: exclude = set([]) ret = OrderedDict() try: # index with a python slice or single int to select one or multiple rows ret['id'] = self.id[arg] for name in self.colnames: if name in exclude: continue col = self.__df_cols[self.__colids[name]] if index and (isinstance(col, DynamicTableRegion) or (isinstance(col, VectorIndex) and isinstance(col.target, DynamicTableRegion))): # return indices (in list, array, etc.) for DTR and ragged DTR ret[name] = col.get(arg, df=False, index=True, **kwargs) else: ret[name] = col.get(arg, df=df, index=index, **kwargs) return ret # if index is out of range, different errors can be generated depending on the dtype of the column # but despite the differences, raise an IndexError from that error except ValueError as ve: # in h5py <2, if the column is an h5py.Dataset, a ValueError was raised # in h5py 3+, this became an IndexError x = re.match(r"^Index \((.*)\) out of range \(.*\)$", str(ve)) if x: msg = ("Row index %s out of range for %s '%s' (length %d)." % (x.groups()[0], self.__class__.__name__, self.name, len(self))) raise IndexError(msg) from ve else: # pragma: no cover raise ve except IndexError as ie: x = re.match(r"^Index \((.*)\) out of range for \(.*\)$", str(ie)) if x: msg = ("Row index %s out of range for %s '%s' (length %d)." % (x.groups()[0], self.__class__.__name__, self.name, len(self))) raise IndexError(msg) elif str(ie) == 'list index out of range': msg = ("Row index out of range for %s '%s' (length %d)." 
% (self.__class__.__name__, self.name, len(self))) raise IndexError(msg) from ie else: # pragma: no cover raise ie def __get_selection_as_df_single_row(self, coldata): """Return a pandas dataframe for the given row and columns with the id column as the index. This is a special case of __get_selection_as_df where a single row was requested. :param coldata: dict mapping column names to values (list/arrays or dataframes) :type coldata: dict """ id_index_orig = coldata.pop('id') id_index = [id_index_orig] df_input = OrderedDict() for k in coldata: # for each column if isinstance(coldata[k], (np.ndarray, list, tuple, pd.DataFrame)): # wrap in a list because coldata[k] may be an array/list/tuple with multiple elements (ragged or # multi-dim column) and pandas needs to have one element per index row (=1 in this case) df_input[k] = [coldata[k]] else: # scalar, don't wrap df_input[k] = coldata[k] ret = pd.DataFrame(df_input, index=pd.Index(name=self.id.name, data=id_index)) ret.name = self.name return ret def __get_selection_as_df(self, coldata): """Return a pandas dataframe for the given rows and columns with the id column as the index. This is used when multiple row indices are selected (or a list/array/slice of a single index is passed to get). __get_selection_as_df_single_row should be used if a single index is passed to get. :param coldata: dict mapping column names to values (list/arrays or dataframes) :type coldata: dict """ id_index = coldata.pop('id') df_input = OrderedDict() for k in coldata: # for each column if isinstance(coldata[k], np.ndarray) and coldata[k].ndim > 1: df_input[k] = list(coldata[k]) # convert multi-dim array to list of inner arrays elif isinstance(coldata[k], pd.DataFrame): # multiple rows were selected and collapsed into a dataframe # split up the rows of the df into a list of dataframes, one per row # TODO make this more efficient df_input[k] = [coldata[k].iloc[[i]] for i in range(len(coldata[k]))] else: df_input[k] = coldata[k] ret = pd.DataFrame(df_input, index=pd.Index(name=self.id.name, data=id_index)) ret.name = self.name return ret def __contains__(self, val): """ Check if the given value (i.e., column) exists in this table """ return val in self.__colids or val in self.__indices def get_foreign_columns(self): """ Determine the names of all columns that link to another DynamicTable, i.e., find all DynamicTableRegion type columns. Similar to a foreign key in a database, a DynamicTableRegion column references elements in another table. :returns: List of strings with the column names """ col_names = [] for col_index, col in enumerate(self.columns): if isinstance(col, DynamicTableRegion): col_names.append(col.name) return col_names def has_foreign_columns(self): """ Does the table contain DynamicTableRegion columns :returns: True if the table contains a DynamicTableRegion column, else False """ for col_index, col in enumerate(self.columns): if isinstance(col, DynamicTableRegion): return True return False @docval({'name': 'other_tables', 'type': (list, tuple, set), 'doc': "List of additional tables to consider in the search. 
Usually this " "parameter is used for internal purposes, e.g., when we need to " "consider AlignedDynamicTable", 'default': None}, allow_extra=False) def get_linked_tables(self, **kwargs): """ Get a list of the full list of all tables that are being linked to directly or indirectly from this table via foreign DynamicTableColumns included in this table or in any table that can be reached through DynamicTableRegion columns Returns: List of NamedTuple objects with: * 'source_table' : The source table containing the DynamicTableRegion column * 'source_column' : The relevant DynamicTableRegion column in the 'source_table' * 'target_table' : The target DynamicTable; same as source_column.table. """ link_type = NamedTuple('DynamicTableLink', [('source_table', DynamicTable), ('source_column', Union[DynamicTableRegion, VectorIndex]), ('target_table', DynamicTable)]) curr_tables = [self, ] # Set of tables other_tables = getargs('other_tables', kwargs) if other_tables is not None: curr_tables += other_tables curr_index = 0 foreign_cols = [] while curr_index < len(curr_tables): for col_index, col in enumerate(curr_tables[curr_index].columns): if isinstance(col, DynamicTableRegion): foreign_cols.append(link_type(source_table=curr_tables[curr_index], source_column=col, target_table=col.table)) curr_table_visited = False for t in curr_tables: if t is col.table: curr_table_visited = True if not curr_table_visited: curr_tables.append(col.table) curr_index += 1 return foreign_cols @docval({'name': 'exclude', 'type': set, 'doc': 'Set of column names to exclude from the dataframe', 'default': None}, {'name': 'index', 'type': bool, 'doc': ('Whether to return indices for a DynamicTableRegion column. If False, nested dataframes will be ' 'returned.'), 'default': False} ) def to_dataframe(self, **kwargs): """ Produce a pandas DataFrame containing this table's data. If this table contains a DynamicTableRegion, by default, If exclude is None, this is equivalent to table.get(slice(None, None, None), index=False). """ arg = slice(None, None, None) # select all rows sel = self.__get_selection_as_dict(arg, df=True, **kwargs) ret = self.__get_selection_as_df(sel) return ret @classmethod @docval( {'name': 'df', 'type': pd.DataFrame, 'doc': 'source DataFrame'}, {'name': 'name', 'type': str, 'doc': 'the name of this table'}, { 'name': 'index_column', 'type': str, 'doc': 'if provided, this column will become the table\'s index', 'default': None }, { 'name': 'table_description', 'type': str, 'doc': 'a description of what is in the resulting table', 'default': '' }, { 'name': 'columns', 'type': (list, tuple), 'doc': 'a list/tuple of dictionaries specifying columns in the table', 'default': None }, allow_extra=True ) def from_dataframe(cls, **kwargs): ''' Construct an instance of DynamicTable (or a subclass) from a pandas DataFrame. The columns of the resulting table are defined by the columns of the dataframe and the index by the dataframe's index (make sure it has a name!) or by a column whose name is supplied to the index_column parameter. We recommend that you supply *columns* - a list/tuple of dictionaries containing the name and description of the column- to help others understand the contents of your table. See :py:class:`~hdmf.common.table.DynamicTable` for more details on *columns*. 
''' columns = kwargs.pop('columns') df = kwargs.pop('df') name = kwargs.pop('name') index_column = kwargs.pop('index_column') table_description = kwargs.pop('table_description') column_descriptions = kwargs.pop('column_descriptions', dict()) supplied_columns = dict() if columns: supplied_columns = {x['name']: x for x in columns} class_cols = {x['name']: x for x in cls.__columns__} required_cols = set(x['name'] for x in cls.__columns__ if 'required' in x and x['required']) df_cols = df.columns if required_cols - set(df_cols): raise ValueError('missing required cols: ' + str(required_cols - set(df_cols))) if set(supplied_columns.keys()) - set(df_cols): raise ValueError('cols specified but not provided: ' + str(set(supplied_columns.keys()) - set(df_cols))) columns = [] for col_name in df_cols: if col_name in class_cols: columns.append(class_cols[col_name]) elif col_name in supplied_columns: columns.append(supplied_columns[col_name]) else: columns.append({'name': col_name, 'description': column_descriptions.get(col_name, 'no description')}) if hasattr(df[col_name].iloc[0], '__len__') and not isinstance(df[col_name].iloc[0], str): lengths = [len(x) for x in df[col_name]] if not lengths[1:] == lengths[:-1]: columns[-1].update(index=True) if index_column is not None: ids = ElementIdentifiers(name=index_column, data=df[index_column].values.tolist()) else: index_name = df.index.name if df.index.name is not None else 'id' ids = ElementIdentifiers(name=index_name, data=df.index.values.tolist()) columns = cls.__build_columns(columns, df=df) return cls(name=name, id=ids, columns=columns, description=table_description, **kwargs) def copy(self): """ Return a copy of this DynamicTable. This is useful for linking. """ kwargs = dict(name=self.name, id=self.id, columns=self.columns, description=self.description, colnames=self.colnames) return self.__class__(**kwargs) @register_class('DynamicTableRegion') class DynamicTableRegion(VectorData): """ DynamicTableRegion provides a link from one table to an index or region of another. The `table` attribute is another `DynamicTable`, indicating which table is referenced. The data is int(s) indicating the row(s) (0-indexed) of the target array. `DynamicTableRegion`s can be used to associate multiple rows with the same meta-data without data duplication. They can also be used to create hierarchical relationships between multiple `DynamicTable`s. `DynamicTableRegion` objects may be paired with a `VectorIndex` object to create ragged references, so a single cell of a `DynamicTable` can reference many rows of another `DynamicTable`. 
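
    A brief usage sketch (the table names, columns, and values are invented for illustration)::

        from hdmf.common import DynamicTable, DynamicTableRegion, VectorData

        electrodes = DynamicTable(
            name='electrodes',
            description='metadata about electrodes',
            columns=[VectorData(name='location', description='brain area', data=['CA1', 'CA3'])],
        )
        # rows 0, 1, 0 of ``electrodes``, e.g., one entry per recorded unit
        region = DynamicTableRegion(
            name='electrodes_region',
            description='the electrode associated with each unit',
            data=[0, 1, 0],
            table=electrodes,
        )
        region[0]               # resolves to row 0 of ``electrodes``
        region.to_dataframe()   # resolves all referenced rows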
""" __fields__ = ( 'table', ) @docval({'name': 'name', 'type': str, 'doc': 'the name of this VectorData'}, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'a dataset where the first dimension is a concatenation of multiple vectors'}, {'name': 'description', 'type': str, 'doc': 'a description of what this region represents'}, {'name': 'table', 'type': DynamicTable, 'doc': 'the DynamicTable this region applies to', 'default': None}) def __init__(self, **kwargs): t = popargs('table', kwargs) call_docval_func(super().__init__, kwargs) self.table = t @property def table(self): """The DynamicTable this DynamicTableRegion is pointing to""" return self.fields.get('table') @table.setter def table(self, val): """ Set the table this DynamicTableRegion should be pointing to :param val: The DynamicTable this DynamicTableRegion should be pointing to :raises: AttributeError if table is already in fields :raises: IndexError if the current indices are out of bounds for the new table given by val """ if val is None: return if 'table' in self.fields: msg = "can't set attribute 'table' -- already set" raise AttributeError(msg) dat = self.data if isinstance(dat, DataIO): dat = dat.data self.fields['table'] = val def __getitem__(self, arg): return self.get(arg) def get(self, arg, index=False, df=True, **kwargs): """ Subset the DynamicTableRegion :param arg: Key defining which elements of the table to select. This may be one of the following: 1) string with the name of the column to select 2) a tuple consisting of (int, str) where the int selects the row and the string identifies the column to select by name 3) int, list of ints, array, or slice selecting a set of full rows in the table. If an int is used, then scalars are returned for each column that has a single value. If a list, array, or slice is used and df=False, then lists are returned for each column, even if the list, array, or slice resolves to a single row. :param index: Boolean indicating whether to return indices of the DTR (default False) :param df: Boolean indicating whether to return the result as a pandas DataFrame (default True) :return: Result from self.table[...] with the appropriate selection based on the rows selected by this DynamicTableRegion """ if not df and not index: # returning nested lists of lists for DTRs and ragged DTRs is complicated and not yet supported raise ValueError('DynamicTableRegion.get() with df=False and index=False is not yet supported.') # treat the list of indices as data that can be indexed. then pass the # result to the table to get the data if isinstance(arg, tuple): arg1 = arg[0] arg2 = arg[1] return self.table[self.data[arg1], arg2] elif isinstance(arg, str): return self.table[arg] elif np.issubdtype(type(arg), np.integer): if arg >= len(self.data): raise IndexError('index {} out of bounds for data of length {}'.format(arg, len(self.data))) ret = self.data[arg] if not index: ret = self.table.get(ret, df=df, index=index, **kwargs) return ret elif isinstance(arg, (list, slice, np.ndarray)): idx = arg # get the data at the specified indices if isinstance(self.data, (tuple, list)) and isinstance(idx, (list, np.ndarray)): ret = [self.data[i] for i in idx] else: ret = self.data[idx] # dereference them if necessary if not index: # These lines are needed because indexing Dataset with a list/ndarray # of ints requires the list to be sorted. # # First get the unique elements, retrieve them from the table, and then # reorder the result according to the original index that the user passed in. 
                #
                # When not returning a DataFrame, we need to recursively sort the subelements
                # of the list we are returning. This is carried out by the recursive method _index_lol
                uniq = np.unique(ret)
                lut = {val: i for i, val in enumerate(uniq)}
                values = self.table.get(uniq, df=df, index=index, **kwargs)
                if df:
                    ret = values.iloc[[lut[i] for i in ret]]
                else:
                    ret = self._index_lol(values, ret, lut)
            return ret
        else:
            raise ValueError("unrecognized argument: '%s'" % arg)

    def _index_lol(self, result, index, lut):
        """
        This is a helper function for indexing a list of lists/ndarrays. When not returning a
        DataFrame, indexing a DynamicTable will return a list of lists and ndarrays. To sort the
        result of a DynamicTable index according to the order of the indices passed in by the user,
        we have to recursively sort the sub-lists/sub-ndarrays.
        """
        ret = list()
        for col in result:
            if isinstance(col, list):
                if isinstance(col[0], list):
                    # list of columns that need to be sorted
                    ret.append(self._index_lol(col, index, lut))
                else:
                    # list of elements, one for each row to return
                    ret.append([col[lut[i]] for i in index])
            elif isinstance(col, np.ndarray):
                ret.append(np.array([col[lut[i]] for i in index], dtype=col.dtype))
            else:
                raise ValueError('unrecognized column type: %s. Expected list or np.ndarray' % type(col))
        return ret

    def to_dataframe(self, **kwargs):
        """
        Convert the whole DynamicTableRegion to a pandas dataframe.

        Keyword arguments are passed through to the to_dataframe method of DynamicTable that
        is being referenced (i.e., self.table). This allows specification of the 'exclude'
        parameter and any other parameters of DynamicTable.to_dataframe.
        """
        return self.table.to_dataframe(**kwargs).iloc[self.data[:]]

    @property
    def shape(self):
        """
        Define the shape, i.e., (num_rows, num_columns) of the selected table region

        :return: Shape tuple with two integers indicating the number of rows and number of columns
        """
        return (len(self.data), len(self.table.columns))

    def __repr__(self):
        """
        :return: Human-readable string representation of the DynamicTableRegion
        """
        cls = self.__class__
        template = "%s %s.%s at 0x%d\n" % (self.name, cls.__module__, cls.__name__, id(self))
        template += " Target table: %s %s.%s at 0x%d\n" % (self.table.name,
                                                           self.table.__class__.__module__,
                                                           self.table.__class__.__name__,
                                                           id(self.table))
        return template


def _uint_precision(elements):
    """Calculate the uint precision needed to encode a set of elements"""
    n_elements = elements
    if hasattr(elements, '__len__'):
        n_elements = len(elements)
    return np.dtype('uint%d' % (8 * max(1, int((2 ** np.ceil((np.ceil(np.log2(n_elements)) - 8) / 8)))))).type


def _map_elements(uint, elements):
    """Map CV terms to their uint index"""
    return {t[1]: uint(t[0]) for t in enumerate(elements)}


@register_class('EnumData', EXP_NAMESPACE)
class EnumData(VectorData):
    """
    An n-dimensional dataset that can contain elements from a fixed set of elements.
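
    A brief usage sketch (the column name and values are invented for illustration; this example
    is not part of the original docstring)::

        from hdmf.common.table import EnumData

        cell_type = EnumData(
            name='cell_type',
            description='cell type of each unit',
            data=[0, 1, 1, 0],                  # indices into ``elements``
            elements=['pyramidal', 'granule'],  # the fixed set of allowed values
        )
        cell_type[0]  # resolves to 'pyramidal'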
""" __fields__ = ('elements', ) @docval({'name': 'name', 'type': str, 'doc': 'the name of this column'}, {'name': 'description', 'type': str, 'doc': 'a description for this column'}, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'integers that index into elements for the value of each row', 'default': list()}, {'name': 'elements', 'type': ('array_data', 'data', VectorData), 'default': list(), 'doc': 'lookup values for each integer in ``data``'}) def __init__(self, **kwargs): elements = popargs('elements', kwargs) super().__init__(**kwargs) if not isinstance(elements, VectorData): elements = VectorData('%s_elements' % self.name, data=elements, description='fixed set of elements referenced by %s' % self.name) self.elements = elements if len(self.elements) > 0: self.__uint = _uint_precision(self.elements.data) self.__revidx = _map_elements(self.__uint, self.elements.data) else: self.__revidx = dict() # a map from term to index self.__uint = None # the precision needed to encode all terms def __add_term(self, term): """ Add a new CV term, and return it's corresponding index Returns: The index of the term """ if term not in self.__revidx: # get minimum uint precision needed for elements self.elements.append(term) uint = _uint_precision(self.elements) if self.__uint is uint: # add the new term to the index-term map self.__revidx[term] = self.__uint(len(self.elements) - 1) else: # remap terms to their uint and bump the precision of existing data self.__uint = uint self.__revidx = _map_elements(self.__uint, self.elements) for i in range(len(self.data)): self.data[i] = self.__uint(self.data[i]) return self.__revidx[term] def __getitem__(self, arg): return self.get(arg, index=False) def _get_helper(self, idx, index=False, join=False, **kwargs): """ A helper function for getting elements elements This helper function contains the post-processing of retrieve indices. By separating this, it allows customizing processing of indices before resolving the elements elements """ if index: return idx if not np.isscalar(idx): idx = np.asarray(idx) ret = np.asarray(self.elements.get(idx.ravel(), **kwargs)).reshape(idx.shape) if join: ret = ''.join(ret.ravel()) else: ret = self.elements.get(idx, **kwargs) return ret def get(self, arg, index=False, join=False, **kwargs): """ Return elements elements for the given argument. Args: index (bool): Return indices, do not return CV elements join (bool): Concatenate elements together into a single string Returns: CV elements if *join* is False or a concatenation of all selected elements if *join* is True. """ idx = self.data[arg] return self._get_helper(idx, index=index, join=join, **kwargs) @docval({'name': 'val', 'type': None, 'doc': 'the value to add to this column'}, {'name': 'index', 'type': bool, 'doc': 'whether or not the value being added is an index', 'default': False}) def add_row(self, **kwargs): """Append a data value to this EnumData column If an element is provided for *val* (i.e. *index* is False), the correct index value will be determined. Otherwise, *val* will be added as provided. 
""" val, index = getargs('val', 'index', kwargs) if not index: val = self.__add_term(val) super().append(val) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/container.py0000644000655200065520000013021400000000000016623 0ustar00circlecicircleciimport types from abc import abstractmethod from collections import OrderedDict from copy import deepcopy from uuid import uuid4 from warnings import warn import h5py import numpy as np import pandas as pd from .data_utils import DataIO, append_data, extend_data from .utils import (docval, get_docval, call_docval_func, getargs, ExtenderMeta, get_data_shape, fmt_docval_args, popargs, LabelledDict) def _set_exp(cls): """Set a class as being experimental""" cls._experimental = True def _exp_warn_msg(cls): """Generate a warning message experimental features""" pfx = cls if isinstance(cls, type): pfx = cls.__name__ msg = ('%s is experimental -- it may be removed in the future and ' 'is not guaranteed to maintain backward compatibility') % pfx return msg class AbstractContainer(metaclass=ExtenderMeta): # The name of the class attribute that subclasses use to autogenerate properties # This parameterization is supplied in case users would like to configure # the class attribute name to something domain-specific _experimental = False _fieldsname = '__fields__' _data_type_attr = 'data_type' # Subclasses use this class attribute to add properties to autogenerate # Autogenerated properties will store values in self.__field_values __fields__ = tuple() # This field is automatically set by __gather_fields before initialization. # It holds all the values in __fields__ for this class and its parent classes. __fieldsconf = tuple() _pconf_allowed_keys = {'name', 'doc', 'settable'} # Override the _setter factor function, so directives that apply to # Container do not get used on Data @classmethod def _setter(cls, field): """ Make a setter function for creating a :py:func:`property` """ name = field['name'] if not field.get('settable', True): return None def setter(self, val): if val is None: return if name in self.fields: msg = "can't set attribute '%s' -- already set" % name raise AttributeError(msg) self.fields[name] = val return setter @classmethod def _getter(cls, field): """ Make a getter function for creating a :py:func:`property` """ doc = field.get('doc') name = field['name'] def getter(self): return self.fields.get(name) setattr(getter, '__doc__', doc) return getter @staticmethod def _check_field_spec(field): """ A helper function for __gather_fields to make sure we are always working with a dict specification and that the specification contains the correct keys """ tmp = field if isinstance(tmp, dict): if 'name' not in tmp: raise ValueError("must specify 'name' if using dict in __fields__") else: tmp = {'name': tmp} return tmp @classmethod def _check_field_spec_keys(cls, field_conf): for k in field_conf: if k not in cls._pconf_allowed_keys: msg = ("Unrecognized key '%s' in %s config '%s' on %s" % (k, cls._fieldsname, field_conf['name'], cls.__name__)) raise ValueError(msg) @classmethod def _get_fields(cls): return getattr(cls, cls._fieldsname) @classmethod def _set_fields(cls, value): return setattr(cls, cls._fieldsname, value) @classmethod def get_fields_conf(cls): return cls.__fieldsconf @ExtenderMeta.pre_init def __gather_fields(cls, name, bases, classdict): ''' This classmethod will be called during class declaration in the metaclass to automatically create setters and getters for fields 
that need to be exported ''' fields = cls._get_fields() if not isinstance(fields, tuple): msg = "'%s' must be of type tuple" % cls._fieldsname raise TypeError(msg) # check field specs and create map from field name to field conf dictionary fields_dict = OrderedDict() for f in fields: pconf = cls._check_field_spec(f) cls._check_field_spec_keys(pconf) fields_dict[pconf['name']] = pconf all_fields_conf = list(fields_dict.values()) # check whether this class overrides __fields__ if len(bases): # find highest base class that is an AbstractContainer (parent is higher than children) base_cls = None for base_cls in reversed(bases): if issubclass(base_cls, AbstractContainer): break base_fields = base_cls._get_fields() # tuple of field names from base class if base_fields is not fields: # check whether new fields spec already exists in base class fields_to_remove_from_base = list() for field_name in fields_dict: if field_name in base_fields: fields_to_remove_from_base.append(field_name) # prepend field specs from base class to fields list of this class # but only field specs that are not redefined in this class base_fields_conf = base_cls.get_fields_conf() # tuple of fields configurations from base class base_fields_conf_to_add = list() for pconf in base_fields_conf: if pconf['name'] not in fields_to_remove_from_base: base_fields_conf_to_add.append(pconf) all_fields_conf[0:0] = base_fields_conf_to_add # create getter and setter if attribute does not already exist # if 'doc' not specified in __fields__, use doc from docval of __init__ docs = {dv['name']: dv['doc'] for dv in get_docval(cls.__init__)} for field_conf in all_fields_conf: pname = field_conf['name'] field_conf.setdefault('doc', docs.get(pname)) if not hasattr(cls, pname): setattr(cls, pname, property(cls._getter(field_conf), cls._setter(field_conf))) cls._set_fields(tuple(field_conf['name'] for field_conf in all_fields_conf)) cls.__fieldsconf = tuple(all_fields_conf) def __new__(cls, *args, **kwargs): inst = super().__new__(cls) if cls._experimental: warn(_exp_warn_msg(cls)) inst.__container_source = kwargs.pop('container_source', None) inst.__parent = None inst.__children = list() inst.__modified = True inst.__object_id = kwargs.pop('object_id', str(uuid4())) inst.parent = kwargs.pop('parent', None) return inst @docval({'name': 'name', 'type': str, 'doc': 'the name of this container'}) def __init__(self, **kwargs): name = getargs('name', kwargs) if '/' in name: raise ValueError("name '" + name + "' cannot contain '/'") self.__name = name self.__field_values = dict() @property def name(self): ''' The name of this Container ''' return self.__name @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to search for', 'default': None}) def get_ancestor(self, **kwargs): """ Traverse parent hierarchy and return first instance of the specified data_type """ data_type = getargs('data_type', kwargs) if data_type is None: return self.parent p = self.parent while p is not None: if getattr(p, p._data_type_attr) == data_type: return p p = p.parent return None @property def fields(self): return self.__field_values @property def object_id(self): if self.__object_id is None: self.__object_id = str(uuid4()) return self.__object_id @docval({'name': 'recurse', 'type': bool, 'doc': "whether or not to change the object ID of this container's children", 'default': True}) def generate_new_id(self, **kwargs): """Changes the object ID of this Container and all of its children to a new UUID string.""" recurse = getargs('recurse', kwargs) self.__object_id 
= str(uuid4()) self.set_modified() if recurse: for c in self.children: c.generate_new_id(**kwargs) @property def modified(self): return self.__modified @docval({'name': 'modified', 'type': bool, 'doc': 'whether or not this Container has been modified', 'default': True}) def set_modified(self, **kwargs): modified = getargs('modified', kwargs) self.__modified = modified if modified and isinstance(self.parent, Container): self.parent.set_modified() @property def children(self): return tuple(self.__children) @docval({'name': 'child', 'type': 'Container', 'doc': 'the child Container for this Container', 'default': None}) def add_child(self, **kwargs): warn(DeprecationWarning('add_child is deprecated. Set the parent attribute instead.')) child = getargs('child', kwargs) if child is not None: # if child.parent is a Container, then the mismatch between child.parent and parent # is used to make a soft/external link from the parent to a child elsewhere # if child.parent is not a Container, it is either None or a Proxy and should be set to self if not isinstance(child.parent, AbstractContainer): # actually add the child to the parent in parent setter child.parent = self else: warn('Cannot add None as child to a container %s' % self.name) @classmethod def type_hierarchy(cls): return cls.__mro__ @property def container_source(self): ''' The source of this Container ''' return self.__container_source @container_source.setter def container_source(self, source): if self.__container_source is not None: raise Exception('cannot reassign container_source') self.__container_source = source @property def parent(self): ''' The parent Container of this Container ''' # do it this way because __parent may not exist yet (not set in constructor) return getattr(self, '_AbstractContainer__parent', None) @parent.setter def parent(self, parent_container): if self.parent is parent_container: return if self.parent is not None: if isinstance(self.parent, AbstractContainer): raise ValueError(('Cannot reassign parent to Container: %s. ' 'Parent is already: %s.' % (repr(self), repr(self.parent)))) else: if parent_container is None: raise ValueError("Got None for parent of '%s' - cannot overwrite Proxy with NoneType" % repr(self)) # NOTE this assumes isinstance(parent_container, Proxy) but we get a circular import # if we try to do that if self.parent.matches(parent_container): self.__parent = parent_container parent_container.__children.append(self) parent_container.set_modified() else: self.__parent.add_candidate(parent_container) else: self.__parent = parent_container if isinstance(parent_container, Container): parent_container.__children.append(self) parent_container.set_modified() def _remove_child(self, child): """Remove a child Container. Intended for use in subclasses that allow dynamic addition of child Containers.""" if not isinstance(child, AbstractContainer): raise ValueError('Cannot remove non-AbstractContainer object from children.') if child not in self.children: raise ValueError("%s '%s' is not a child of %s '%s'." 
% (child.__class__.__name__, child.name, self.__class__.__name__, self.name)) child.__parent = None self.__children.remove(child) child.set_modified() self.set_modified() class Container(AbstractContainer): """A container that can contain other containers and has special functionality for printing.""" _pconf_allowed_keys = {'name', 'child', 'required_name', 'doc', 'settable'} @classmethod def _setter(cls, field): """Returns a list of setter functions for the given field to be added to the class during class declaration.""" super_setter = AbstractContainer._setter(field) ret = [super_setter] # create setter with check for required name if field.get('required_name', None) is not None: name = field['required_name'] idx1 = len(ret) - 1 def container_setter(self, val): if val is not None: if not isinstance(val, AbstractContainer): msg = ("Field '%s' on %s has a required name and must be a subclass of AbstractContainer." % (field['name'], self.__class__.__name__)) raise ValueError(msg) if val.name != name: msg = ("Field '%s' on %s must be named '%s'." % (field['name'], self.__class__.__name__, name)) raise ValueError(msg) ret[idx1](self, val) ret.append(container_setter) # create setter that accepts a value or tuple, list, or dict or values and sets the value's parent to self if field.get('child', False): idx2 = len(ret) - 1 def container_setter(self, val): ret[idx2](self, val) if val is not None: if isinstance(val, (tuple, list)): pass elif isinstance(val, dict): val = val.values() else: val = [val] for v in val: if not isinstance(v.parent, Container): v.parent = self # else, the ObjectMapper will create a link from self (parent) to v (child with existing # parent) ret.append(container_setter) return ret[-1] def __repr__(self): cls = self.__class__ template = "%s %s.%s at 0x%d" % (self.name, cls.__module__, cls.__name__, id(self)) if len(self.fields): template += "\nFields:\n" for k in sorted(self.fields): # sorted to enable tests v = self.fields[k] # if isinstance(v, DataIO) or not hasattr(v, '__len__') or len(v) > 0: if hasattr(v, '__len__'): if isinstance(v, (np.ndarray, list, tuple)): if len(v) > 0: template += " {}: {}\n".format(k, self.__smart_str(v, 1)) elif v: template += " {}: {}\n".format(k, self.__smart_str(v, 1)) else: template += " {}: {}\n".format(k, v) return template @staticmethod def __smart_str(v, num_indent): """ Print compact string representation of data. If v is a list, try to print it using numpy. This will condense the string representation of datasets with many elements. If that doesn't work, just print the list. 
If v is a dictionary, print the name and type of each element If v is a set, print it sorted If v is a neurodata_type, print the name of type Otherwise, use the built-in str() Parameters ---------- v Returns ------- str """ if isinstance(v, list) or isinstance(v, tuple): if len(v) and isinstance(v[0], AbstractContainer): return Container.__smart_str_list(v, num_indent, '(') try: return str(np.asarray(v)) except ValueError: return Container.__smart_str_list(v, num_indent, '(') elif isinstance(v, dict): return Container.__smart_str_dict(v, num_indent) elif isinstance(v, set): return Container.__smart_str_list(sorted(list(v)), num_indent, '{') elif isinstance(v, AbstractContainer): return "{} {}".format(getattr(v, 'name'), type(v)) else: return str(v) @staticmethod def __smart_str_list(str_list, num_indent, left_br): if left_br == '(': right_br = ')' if left_br == '{': right_br = '}' if len(str_list) == 0: return left_br + ' ' + right_br indent = num_indent * 2 * ' ' indent_in = (num_indent + 1) * 2 * ' ' out = left_br for v in str_list[:-1]: out += '\n' + indent_in + Container.__smart_str(v, num_indent + 1) + ',' if str_list: out += '\n' + indent_in + Container.__smart_str(str_list[-1], num_indent + 1) out += '\n' + indent + right_br return out @staticmethod def __smart_str_dict(d, num_indent): left_br = '{' right_br = '}' if len(d) == 0: return left_br + ' ' + right_br indent = num_indent * 2 * ' ' indent_in = (num_indent + 1) * 2 * ' ' out = left_br keys = sorted(list(d.keys())) for k in keys[:-1]: out += '\n' + indent_in + Container.__smart_str(k, num_indent + 1) + ' ' + str(type(d[k])) + ',' if keys: out += '\n' + indent_in + Container.__smart_str(keys[-1], num_indent + 1) + ' ' + str(type(d[keys[-1]])) out += '\n' + indent + right_br return out class Data(AbstractContainer): """ A class for representing dataset containers """ @docval({'name': 'name', 'type': str, 'doc': 'the name of this container'}, {'name': 'data', 'type': ('scalar_data', 'array_data', 'data'), 'doc': 'the source of the data'}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) self.__data = getargs('data', kwargs) @property def data(self): return self.__data @property def shape(self): """ Get the shape of the data represented by this container :return: Shape tuple :rtype: tuple of ints """ return get_data_shape(self.__data) @docval({'name': 'dataio', 'type': DataIO, 'doc': 'the DataIO to apply to the data held by this Data'}) def set_dataio(self, **kwargs): """ Apply DataIO object to the data held by this Data object """ dataio = getargs('dataio', kwargs) dataio.data = self.__data self.__data = dataio @docval({'name': 'func', 'type': types.FunctionType, 'doc': 'a function to transform *data*'}) def transform(self, **kwargs): """ Transform data from the current underlying state. 
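
        For example (a hedged sketch; the function here is invented for illustration),
        ``data_obj.transform(lambda arr: np.asarray(arr))`` would replace the held data
        with the return value of the function.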
This function can be used to permanently load data from disk, or convert to a different representation, such as a torch.Tensor """ func = getargs('func', kwargs) self.__data = func(self.__data) return self def __bool__(self): if self.data is not None: if isinstance(self.data, (np.ndarray, tuple, list)): return len(self.data) != 0 if self.data: return True return False def __len__(self): return len(self.__data) def __getitem__(self, args): return self.get(args) def get(self, args): if isinstance(self.data, (tuple, list)) and isinstance(args, (tuple, list, np.ndarray)): return [self.data[i] for i in args] if isinstance(self.data, h5py.Dataset) and isinstance(args, np.ndarray): # This is needed for h5py 2.9 compatability args = args.tolist() return self.data[args] def append(self, arg): self.__data = append_data(self.__data, arg) def extend(self, arg): """ The extend_data method adds all the elements of the iterable arg to the end of the data of this Data container. :param arg: The iterable to add to the end of this VectorData """ self.__data = extend_data(self.__data, arg) class DataRegion(Data): @property @abstractmethod def data(self): ''' The target data that this region applies to ''' pass @property @abstractmethod def region(self): ''' The region that indexes into data e.g. slice or list of indices ''' pass def _not_parent(arg): return arg['name'] != 'parent' class MultiContainerInterface(Container): """Class that dynamically defines methods to support a Container holding multiple Containers of the same type. To use, extend this class and create a dictionary as a class attribute with any of the following keys: * 'attr' to name the attribute that stores the Container instances * 'type' to provide the Container object type (type or list/tuple of types, type can be a docval macro) * 'add' to name the method for adding Container instances * 'get' to name the method for getting Container instances * 'create' to name the method for creating Container instances (only if a single type is specified) If the attribute does not exist in the class, it will be generated. If it does exist, it should behave like a dict. The keys 'attr', 'type', and 'add' are required. """ def __new__(cls, *args, **kwargs): if cls is MultiContainerInterface: raise TypeError("Can't instantiate class MultiContainerInterface.") if not hasattr(cls, '__clsconf__'): raise TypeError("MultiContainerInterface subclass %s is missing __clsconf__ attribute. Please check that " "the class is properly defined." % cls.__name__) return super().__new__(cls, *args, **kwargs) @staticmethod def __add_article(noun): if isinstance(noun, tuple): noun = noun[0] if isinstance(noun, type): noun = noun.__name__ if noun[0] in ('aeiouAEIOU'): return 'an %s' % noun return 'a %s' % noun @staticmethod def __join(argtype): """Return a grammatical string representation of a list or tuple of classes or text. 
Examples: cls.__join(Container) returns "Container" cls.__join((Container, )) returns "Container" cls.__join((Container, Data)) returns "Container or Data" cls.__join((Container, Data, Subcontainer)) returns "Container, Data, or Subcontainer" """ def tostr(x): return x.__name__ if isinstance(x, type) else x if isinstance(argtype, (list, tuple)): args_str = [tostr(x) for x in argtype] if len(args_str) == 1: return args_str[0] if len(args_str) == 2: return " or ".join(tostr(x) for x in args_str) else: return ", ".join(tostr(x) for x in args_str[:-1]) + ', or ' + args_str[-1] else: return tostr(argtype) @classmethod def __make_get(cls, func_name, attr_name, container_type): doc = "Get %s from this %s" % (cls.__add_article(container_type), cls.__name__) @docval({'name': 'name', 'type': str, 'doc': 'the name of the %s' % cls.__join(container_type), 'default': None}, rtype=container_type, returns='the %s with the given name' % cls.__join(container_type), func_name=func_name, doc=doc) def _func(self, **kwargs): name = getargs('name', kwargs) d = getattr(self, attr_name) ret = None if name is None: if len(d) > 1: msg = ("More than one element in %s of %s '%s' -- must specify a name." % (attr_name, cls.__name__, self.name)) raise ValueError(msg) elif len(d) == 0: msg = "%s of %s '%s' is empty." % (attr_name, cls.__name__, self.name) raise ValueError(msg) else: # only one item in dict for v in d.values(): ret = v else: ret = d.get(name) if ret is None: msg = "'%s' not found in %s of %s '%s'." % (name, attr_name, cls.__name__, self.name) raise KeyError(msg) return ret return _func @classmethod def __make_getitem(cls, attr_name, container_type): doc = "Get %s from this %s" % (cls.__add_article(container_type), cls.__name__) @docval({'name': 'name', 'type': str, 'doc': 'the name of the %s' % cls.__join(container_type), 'default': None}, rtype=container_type, returns='the %s with the given name' % cls.__join(container_type), func_name='__getitem__', doc=doc) def _func(self, **kwargs): # NOTE this is the same code as the getter but with different error messages name = getargs('name', kwargs) d = getattr(self, attr_name) ret = None if name is None: if len(d) > 1: msg = ("More than one %s in %s '%s' -- must specify a name." % (cls.__join(container_type), cls.__name__, self.name)) raise ValueError(msg) elif len(d) == 0: msg = "%s '%s' is empty." % (cls.__name__, self.name) raise ValueError(msg) else: # only one item in dict for v in d.values(): ret = v else: ret = d.get(name) if ret is None: msg = "'%s' not found in %s '%s'." 
% (name, cls.__name__, self.name) raise KeyError(msg) return ret return _func @classmethod def __make_add(cls, func_name, attr_name, container_type): doc = "Add %s to this %s" % (cls.__add_article(container_type), cls.__name__) @docval({'name': attr_name, 'type': (list, tuple, dict, container_type), 'doc': 'the %s to add' % cls.__join(container_type)}, func_name=func_name, doc=doc) def _func(self, **kwargs): container = getargs(attr_name, kwargs) if isinstance(container, container_type): containers = [container] elif isinstance(container, dict): containers = container.values() else: containers = container d = getattr(self, attr_name) for tmp in containers: if not isinstance(tmp.parent, Container): tmp.parent = self # else, the ObjectMapper will create a link from self (parent) to tmp (child with existing parent) if tmp.name in d: msg = "'%s' already exists in %s '%s'" % (tmp.name, cls.__name__, self.name) raise ValueError(msg) d[tmp.name] = tmp return container return _func @classmethod def __make_create(cls, func_name, add_name, container_type): doc = "Create %s and add it to this %s" % (cls.__add_article(container_type), cls.__name__) @docval(*filter(_not_parent, get_docval(container_type.__init__)), func_name=func_name, doc=doc, returns="the %s object that was created" % cls.__join(container_type), rtype=container_type) def _func(self, **kwargs): cargs, ckwargs = fmt_docval_args(container_type.__init__, kwargs) ret = container_type(*cargs, **ckwargs) getattr(self, add_name)(ret) return ret return _func @classmethod def __make_constructor(cls, clsconf): args = list() for conf in clsconf: attr_name = conf['attr'] container_type = conf['type'] args.append({'name': attr_name, 'type': (list, tuple, dict, container_type), 'doc': '%s to store in this interface' % cls.__join(container_type), 'default': dict()}) args.append({'name': 'name', 'type': str, 'doc': 'the name of this container', 'default': cls.__name__}) @docval(*args, func_name='__init__') def _func(self, **kwargs): call_docval_func(super(cls, self).__init__, kwargs) for conf in clsconf: attr_name = conf['attr'] add_name = conf['add'] container = popargs(attr_name, kwargs) add = getattr(self, add_name) add(container) return _func @classmethod def __make_getter(cls, attr): """Make a getter function for creating a :py:func:`property`""" def _func(self): # initialize the field to an empty labeled dict if it has not yet been # do this here to avoid creating default __init__ which may or may not be overridden in # custom classes and dynamically generated classes if attr not in self.fields: def _remove_child(child): if child.parent is self: self._remove_child(child) self.fields[attr] = LabelledDict(attr, remove_callable=_remove_child) return self.fields.get(attr) return _func @classmethod def __make_setter(cls, add_name): """Make a setter function for creating a :py:func:`property`""" @docval({'name': 'val', 'type': (list, tuple, dict), 'doc': 'the sub items to add', 'default': None}) def _func(self, **kwargs): val = getargs('val', kwargs) if val is None: return getattr(self, add_name)(val) return _func @ExtenderMeta.pre_init def __build_class(cls, name, bases, classdict): """Verify __clsconf__ and create methods based on __clsconf__. This method is called prior to __new__ and __init__ during class declaration in the metaclass. 
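
        A hypothetical subclass (the class and method names here are invented for illustration)
        might declare::

            class ContainerHolder(MultiContainerInterface):
                __clsconf__ = {
                    'attr': 'containers',          # name of the dict-like attribute storing the members
                    'type': Container,             # type of the stored members
                    'add': 'add_container',        # name of the generated add method
                    'get': 'get_container',        # name of the generated get method
                    'create': 'create_container',  # name of the generated create method (single type only)
                }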
""" if not hasattr(cls, '__clsconf__'): return multi = False if isinstance(cls.__clsconf__, dict): clsconf = [cls.__clsconf__] elif isinstance(cls.__clsconf__, list): multi = True clsconf = cls.__clsconf__ else: raise TypeError("'__clsconf__' for MultiContainerInterface subclass %s must be a dict or a list of " "dicts." % cls.__name__) for conf_index, conf_dict in enumerate(clsconf): cls.__build_conf_methods(conf_dict, conf_index, multi) # make __getitem__ (square bracket access) only if one conf type is defined if len(clsconf) == 1: attr = clsconf[0].get('attr') container_type = clsconf[0].get('type') setattr(cls, '__getitem__', cls.__make_getitem(attr, container_type)) # create the constructor, only if it has not been overridden # i.e. it is the same method as the parent class constructor if '__init__' not in cls.__dict__: setattr(cls, '__init__', cls.__make_constructor(clsconf)) @classmethod def __build_conf_methods(cls, conf_dict, conf_index, multi): # get add method name add = conf_dict.get('add') if add is None: msg = "MultiContainerInterface subclass %s is missing 'add' key in __clsconf__" % cls.__name__ if multi: msg += " at index %d" % conf_index raise ValueError(msg) # get container attribute name attr = conf_dict.get('attr') if attr is None: msg = "MultiContainerInterface subclass %s is missing 'attr' key in __clsconf__" % cls.__name__ if multi: msg += " at index %d" % conf_index raise ValueError(msg) # get container type container_type = conf_dict.get('type') if container_type is None: msg = "MultiContainerInterface subclass %s is missing 'type' key in __clsconf__" % cls.__name__ if multi: msg += " at index %d" % conf_index raise ValueError(msg) # create property with the name given in 'attr' only if the attribute is not already defined if not hasattr(cls, attr): getter = cls.__make_getter(attr) setter = cls.__make_setter(add) doc = "a dictionary containing the %s in this %s" % (cls.__join(container_type), cls.__name__) setattr(cls, attr, property(getter, setter, None, doc)) # create the add method setattr(cls, add, cls.__make_add(add, attr, container_type)) # create the create method, only if a single container type is specified create = conf_dict.get('create') if create is not None: if isinstance(container_type, type): setattr(cls, create, cls.__make_create(create, add, container_type)) else: msg = ("Cannot specify 'create' key in __clsconf__ for MultiContainerInterface subclass %s " "when 'type' key is not a single type") % cls.__name__ if multi: msg += " at index %d" % conf_index raise ValueError(msg) # create the get method get = conf_dict.get('get') if get is not None: setattr(cls, get, cls.__make_get(get, attr, container_type)) class Row(object, metaclass=ExtenderMeta): """ A class for representing rows from a Table. The Table class can be indicated with the __table__. Doing so will set constructor arguments for the Row class and ensure that Row.idx is set appropriately when a Row is added to the Table. It will also add functionality to the Table class for getting Row objects. Note, the Row class is not needed for working with Table objects. This is merely convenience functionality for working with Tables. 
""" __table__ = None @property def idx(self): """The index of this row in its respective Table""" return self.__idx @idx.setter def idx(self, val): if self.__idx is None: self.__idx = val else: raise ValueError("cannot reset the ID of a row object") @property def table(self): """The Table this Row comes from""" return self.__table @table.setter def table(self, val): if val is not None: self.__table = val if self.idx is None: self.idx = self.__table.add_row(**self.todict()) @ExtenderMeta.pre_init def __build_row_class(cls, name, bases, classdict): table_cls = getattr(cls, '__table__', None) if table_cls is not None: columns = getattr(table_cls, '__columns__') if cls.__init__ == bases[-1].__init__: # check if __init__ is overridden columns = deepcopy(columns) func_args = list() for col in columns: func_args.append(col) func_args.append({'name': 'table', 'type': Table, 'default': None, 'help': 'the table this row is from'}) func_args.append({'name': 'idx', 'type': int, 'default': None, 'help': 'the index for this row'}) @docval(*func_args) def __init__(self, **kwargs): super(cls, self).__init__() table, idx = popargs('table', 'idx', kwargs) self.__keys = list() self.__idx = None self.__table = None for k, v in kwargs.items(): self.__keys.append(k) setattr(self, k, v) self.idx = idx self.table = table setattr(cls, '__init__', __init__) def todict(self): return {k: getattr(self, k) for k in self.__keys} setattr(cls, 'todict', todict) # set this so Table.row gets set when a Table is instantiated table_cls.__rowclass__ = cls else: if bases != (object,): raise ValueError('__table__ must be set if sub-classing Row') def __eq__(self, other): return self.idx == other.idx and self.table is other.table class RowGetter: """ A simple class for providing __getitem__ functionality that returns Row objects to a Table. """ def __init__(self, table): self.table = table self.cache = dict() def __getitem__(self, idx): ret = self.cache.get(idx) if ret is None: row = self.table[idx] ret = self.table.__rowclass__(*row, table=self.table, idx=idx) self.cache[idx] = ret return ret class Table(Data): r''' Subclasses should specify the class attribute \_\_columns\_\_. This should be a list of dictionaries with the following keys: - ``name`` the column name - ``type`` the type of data in this column - ``doc`` a brief description of what gets stored in this column For reference, this list of dictionaries will be used with docval to autogenerate the ``add_row`` method for adding data to this table. If \_\_columns\_\_ is not specified, no custom ``add_row`` method will be added. The class attribute __defaultname__ can also be set to specify a default name for the table class. If \_\_defaultname\_\_ is not specified, then ``name`` will need to be specified when the class is instantiated. A Table class can be paired with a Row class for conveniently working with rows of a Table. This pairing must be indicated in the Row class implementation. See Row for more details. ''' # This class attribute is used to indicate which Row class should be used when # adding RowGetter functionality to the Table. 
__rowclass__ = None @ExtenderMeta.pre_init def __build_table_class(cls, name, bases, classdict): if hasattr(cls, '__columns__'): columns = getattr(cls, '__columns__') idx = dict() for i, col in enumerate(columns): idx[col['name']] = i setattr(cls, '__colidx__', idx) if cls.__init__ == bases[-1].__init__: # check if __init__ is overridden name = {'name': 'name', 'type': str, 'doc': 'the name of this table'} defname = getattr(cls, '__defaultname__', None) if defname is not None: name['default'] = defname @docval(name, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'the data in this table', 'default': list()}) def __init__(self, **kwargs): name, data = getargs('name', 'data', kwargs) colnames = [i['name'] for i in columns] super(cls, self).__init__(colnames, name, data) setattr(cls, '__init__', __init__) if cls.add_row == bases[-1].add_row: # check if add_row is overridden @docval(*columns) def add_row(self, **kwargs): return super(cls, self).add_row(kwargs) setattr(cls, 'add_row', add_row) @docval({'name': 'columns', 'type': (list, tuple), 'doc': 'a list of the columns in this table'}, {'name': 'name', 'type': str, 'doc': 'the name of this container'}, {'name': 'data', 'type': ('array_data', 'data'), 'doc': 'the source of the data', 'default': list()}) def __init__(self, **kwargs): self.__columns = tuple(popargs('columns', kwargs)) self.__col_index = {name: idx for idx, name in enumerate(self.__columns)} if getattr(self, '__rowclass__') is not None: self.row = RowGetter(self) call_docval_func(super(Table, self).__init__, kwargs) @property def columns(self): return self.__columns @docval({'name': 'values', 'type': dict, 'doc': 'the values for each column'}) def add_row(self, **kwargs): values = getargs('values', kwargs) if not isinstance(self.data, list): msg = 'Cannot append row to %s' % type(self.data) raise ValueError(msg) ret = len(self.data) row = [values[col] for col in self.columns] row = [v.idx if isinstance(v, Row) else v for v in row] self.data.append(tuple(row)) return ret def which(self, **kwargs): ''' Query a table ''' if len(kwargs) != 1: raise ValueError("only one column can be queried") colname, value = kwargs.popitem() idx = self.__colidx__.get(colname) if idx is None: msg = "no '%s' column in %s" % (colname, self.__class__.__name__) raise KeyError(msg) ret = list() for i in range(len(self.data)): row = self.data[i] row_val = row[idx] if row_val == value: ret.append(i) return ret def __len__(self): return len(self.data) def __getitem__(self, args): idx = args col = None if isinstance(args, tuple): idx = args[1] if isinstance(args[0], str): col = self.__col_index.get(args[0]) elif isinstance(args[0], int): col = args[0] else: raise KeyError('first argument must be a column name or index') return self.data[idx][col] elif isinstance(args, str): col = self.__col_index.get(args) if col is None: raise KeyError(args) return [row[col] for row in self.data] else: return self.data[idx] def to_dataframe(self): '''Produce a pandas DataFrame containing this table's data. ''' data = {colname: self[colname] for ii, colname in enumerate(self.columns)} return pd.DataFrame(data) @classmethod @docval( {'name': 'df', 'type': pd.DataFrame, 'doc': 'input data'}, {'name': 'name', 'type': str, 'doc': 'the name of this container', 'default': None}, { 'name': 'extra_ok', 'type': bool, 'doc': 'accept (and ignore) unexpected columns on the input dataframe', 'default': False }, ) def from_dataframe(cls, **kwargs): '''Construct an instance of Table (or a subclass) from a pandas DataFrame. 
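
        A hypothetical subclass and call (the class, column, and value names are invented for
        illustration) might look like::

            import pandas as pd

            class MyTable(Table):
                __defaultname__ = 'my_table'
                __columns__ = [
                    {'name': 'x', 'type': int, 'doc': 'an integer value'},
                    {'name': 'y', 'type': str, 'doc': 'a string value'},
                ]

            table = MyTable.from_dataframe(pd.DataFrame({'x': [1, 2], 'y': ['a', 'b']}))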
The columns of the dataframe should match the columns defined on the Table subclass. ''' df, name, extra_ok = getargs('df', 'name', 'extra_ok', kwargs) cls_cols = list([col['name'] for col in getattr(cls, '__columns__')]) df_cols = list(df.columns) missing_columns = set(cls_cols) - set(df_cols) extra_columns = set(df_cols) - set(cls_cols) if extra_columns: raise ValueError( 'unrecognized column(s) {} for table class {} (columns {})'.format( extra_columns, cls.__name__, cls_cols ) ) use_index = False if len(missing_columns) == 1 and list(missing_columns)[0] == df.index.name: use_index = True elif missing_columns: raise ValueError( 'missing column(s) {} for table class {} (columns {}, provided {})'.format( missing_columns, cls.__name__, cls_cols, df_cols ) ) data = [] for index, row in df.iterrows(): if use_index: data.append([ row[colname] if colname != df.index.name else index for colname in cls_cols ]) else: data.append(tuple([row[colname] for colname in cls_cols])) if name is None: return cls(data=data) return cls(name=name, data=data) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/data_utils.py0000644000655200065520000010106100000000000016770 0ustar00circlecicircleciimport copy from abc import ABCMeta, abstractmethod from collections.abc import Iterable from warnings import warn import h5py import numpy as np from .utils import docval, getargs, popargs, docval_macro, get_data_shape def append_data(data, arg): if isinstance(data, (list, DataIO)): data.append(arg) return data elif isinstance(data, np.ndarray): return np.append(data, np.expand_dims(arg, axis=0), axis=0) elif isinstance(data, h5py.Dataset): shape = list(data.shape) shape[0] += 1 data.resize(shape) data[-1] = arg return data else: msg = "Data cannot append to object of type '%s'" % type(data) raise ValueError(msg) def extend_data(data, arg): """Add all the elements of the iterable arg to the end of data. :param data: The array to extend :type data: list, DataIO, np.ndarray, h5py.Dataset """ if isinstance(data, (list, DataIO)): data.extend(arg) return data elif isinstance(data, np.ndarray): return np.vstack((data, arg)) elif isinstance(data, h5py.Dataset): shape = list(data.shape) shape[0] += len(arg) data.resize(shape) data[-len(arg):] = arg return data else: msg = "Data cannot extend object of type '%s'" % type(data) raise ValueError(msg) @docval_macro('array_data') class AbstractDataChunkIterator(metaclass=ABCMeta): """ Abstract iterator class used to iterate over DataChunks. Derived classes must ensure that all abstract methods and abstract properties are implemented, in particular, dtype, maxshape, __iter__, ___next__, recommended_chunk_shape, and recommended_data_shape. Iterating over AbstractContainer objects is not yet supported. """ @abstractmethod def __iter__(self): """Return the iterator object""" raise NotImplementedError("__iter__ not implemented for derived class") @abstractmethod def __next__(self): r""" Return the next data chunk or raise a StopIteration exception if all chunks have been retrieved. HINT: numpy.s\_ provides a convenient way to generate index tuples using standard array slicing. This is often useful to define the DataChunk.selection of the current chunk :returns: DataChunk object with the data and selection of the current chunk :rtype: DataChunk """ raise NotImplementedError("__next__ not implemented for derived class") @abstractmethod def recommended_chunk_shape(self): """ Recommend the chunk shape for the data array. 
:return: NumPy-style shape tuple describing the recommended shape for the chunks of the target array or None. This may or may not be the same as the shape of the chunks returned in the iteration process. """ raise NotImplementedError("recommended_chunk_shape not implemented for derived class") @abstractmethod def recommended_data_shape(self): """ Recommend the initial shape for the data array. This is useful in particular to avoid repeated resized of the target array when reading from this data iterator. This should typically be either the final size of the array or the known minimal shape of the array. :return: NumPy-style shape tuple indicating the recommended initial shape for the target array. This may or may not be the final full shape of the array, i.e., the array is allowed to grow. This should not be None. """ raise NotImplementedError("recommended_data_shape not implemented for derived class") @property @abstractmethod def dtype(self): """ Define the data type of the array :return: NumPy style dtype or otherwise compliant dtype string """ raise NotImplementedError("dtype not implemented for derived class") @property @abstractmethod def maxshape(self): """ Property describing the maximum shape of the data array that is being iterated over :return: NumPy-style shape tuple indicating the maxiumum dimensions up to which the dataset may be resized. Axes with None are unlimited. """ raise NotImplementedError("maxshape not implemented for derived class") class DataChunkIterator(AbstractDataChunkIterator): """ Custom iterator class used to iterate over chunks of data. This default implementation of AbstractDataChunkIterator accepts any iterable and assumes that we iterate over a single dimension of the data array (default: the first dimension). DataChunkIterator supports buffered read, i.e., multiple values from the input iterator can be combined to a single chunk. This is useful for buffered I/O operations, e.g., to improve performance by accumulating data in memory and writing larger blocks at once. """ __docval_init = ( {'name': 'data', 'type': None, 'doc': 'The data object used for iteration', 'default': None}, {'name': 'maxshape', 'type': tuple, 'doc': 'The maximum shape of the full data array. Use None to indicate unlimited dimensions', 'default': None}, {'name': 'dtype', 'type': np.dtype, 'doc': 'The Numpy data type for the array', 'default': None}, {'name': 'buffer_size', 'type': int, 'doc': 'Number of values to be buffered in a chunk', 'default': 1}, {'name': 'iter_axis', 'type': int, 'doc': 'The dimension to iterate over', 'default': 0} ) @docval(*__docval_init) def __init__(self, **kwargs): """Initialize the DataChunkIterator. If 'data' is an iterator and 'dtype' is not specified, then next is called on the iterator in order to determine the dtype of the data. 
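
        A short usage sketch (the generator below is invented for illustration)::

            import numpy as np
            from hdmf.data_utils import DataChunkIterator

            def read_frames():
                for _ in range(10):
                    yield np.random.rand(5)  # produce one 5-element frame at a time

            # buffer two frames per chunk; maxshape and dtype are inferred from the first chunk
            dci = DataChunkIterator(data=read_frames(), buffer_size=2)
            for chunk in dci:
                print(chunk.selection, chunk.data.shape)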
""" # Get the user parameters self.data, self.__maxshape, self.__dtype, self.buffer_size, self.iter_axis = getargs('data', 'maxshape', 'dtype', 'buffer_size', 'iter_axis', kwargs) self.chunk_index = 0 # Create an iterator for the data if possible if isinstance(self.data, Iterable): if self.iter_axis != 0 and isinstance(self.data, (list, tuple)): warn('Iterating over an axis other than the first dimension of list or tuple data ' 'involves converting the data object to a numpy ndarray, which may incur a computational ' 'cost.') self.data = np.asarray(self.data) if isinstance(self.data, np.ndarray): # iterate over the given axis by adding a new view on data (iter only works on the first dim) self.__data_iter = iter(np.moveaxis(self.data, self.iter_axis, 0)) else: self.__data_iter = iter(self.data) else: self.__data_iter = None self.__next_chunk = DataChunk(None, None) self.__next_chunk_start = 0 self.__first_chunk_shape = None # Determine the shape of the data if possible if self.__maxshape is None: # If the self.data object identifies its shape, then use it if hasattr(self.data, "shape"): self.__maxshape = self.data.shape # Avoid the special case of scalar values by making them into a 1D numpy array if len(self.__maxshape) == 0: self.data = np.asarray([self.data, ]) self.__maxshape = self.data.shape self.__data_iter = iter(self.data) # Try to get an accurate idea of __maxshape for other Python data structures if possible. # Don't just call get_data_shape for a generator as that would potentially trigger loading of all the data elif isinstance(self.data, list) or isinstance(self.data, tuple): self.__maxshape = get_data_shape(self.data, strict_no_data_load=True) # If we have a data iterator and do not know the dtype, then read the first chunk if self.__data_iter is not None and self.__dtype is None: self._read_next_chunk() # Determine the type of the data if possible if self.__next_chunk.data is not None: self.__dtype = self.__next_chunk.data.dtype self.__first_chunk_shape = get_data_shape(self.__next_chunk.data) # This should be done as a last resort only if self.__first_chunk_shape is None and self.__maxshape is not None: self.__first_chunk_shape = tuple(1 if i is None else i for i in self.__maxshape) if self.__dtype is None: raise Exception('Data type could not be determined. 
Please specify dtype in DataChunkIterator init.') @classmethod @docval(*__docval_init) def from_iterable(cls, **kwargs): return cls(**kwargs) def __iter__(self): """Return the iterator object""" return self def _read_next_chunk(self): """Read a single chunk from self.__data_iter and store the results in self.__next_chunk :returns: self.__next_chunk, i.e., the DataChunk object describing the next chunk """ from h5py import Dataset as H5Dataset if isinstance(self.data, H5Dataset): start_index = self.chunk_index * self.buffer_size stop_index = start_index + self.buffer_size iter_data_bounds = self.data.shape[self.iter_axis] if start_index >= iter_data_bounds: self.__next_chunk = DataChunk(None, None) else: if stop_index > iter_data_bounds: stop_index = iter_data_bounds selection = [slice(None)] * len(self.maxshape) selection[self.iter_axis] = slice(start_index, stop_index) selection = tuple(selection) self.__next_chunk.data = self.data[selection] self.__next_chunk.selection = selection elif self.__data_iter is not None: # the pieces in the buffer - first dimension consists of individual calls to next iter_pieces = [] # offset of where data begins - shift the selection of where to place this chunk by this much curr_chunk_offset = 0 read_next_empty = False while len(iter_pieces) < self.buffer_size: try: dat = next(self.__data_iter) if dat is None and len(iter_pieces) == 0: # Skip forward in our chunk until we find data curr_chunk_offset += 1 elif dat is None and len(iter_pieces) > 0: # Stop iteration if we hit empty data while constructing our block # Buffer may not be full. read_next_empty = True break else: # Add pieces of data to our buffer iter_pieces.append(np.asarray(dat)) except StopIteration: break if len(iter_pieces) == 0: self.__next_chunk = DataChunk(None, None) # signal end of iteration else: # concatenate all the pieces into the chunk along the iteration axis piece_shape = list(get_data_shape(iter_pieces[0])) piece_shape.insert(self.iter_axis, 1) # insert the missing axis next_chunk_shape = piece_shape.copy() next_chunk_shape[self.iter_axis] *= len(iter_pieces) next_chunk_size = next_chunk_shape[self.iter_axis] # use the piece dtype because the actual dtype may not have been determined yet # NOTE: this could be problematic if a generator returns e.g. floats first and ints later self.__next_chunk.data = np.empty(next_chunk_shape, dtype=iter_pieces[0].dtype) self.__next_chunk.data = np.stack(iter_pieces, axis=self.iter_axis) selection = [slice(None)] * len(self.maxshape) selection[self.iter_axis] = slice(self.__next_chunk_start + curr_chunk_offset, self.__next_chunk_start + curr_chunk_offset + next_chunk_size) self.__next_chunk.selection = tuple(selection) # next chunk should start at self.__next_chunk.selection[self.iter_axis].stop # but if this chunk stopped because of reading empty data, then this should be adjusted by 1 self.__next_chunk_start = self.__next_chunk.selection[self.iter_axis].stop if read_next_empty: self.__next_chunk_start += 1 else: self.__next_chunk = DataChunk(None, None) self.chunk_index += 1 return self.__next_chunk def __next__(self): r"""Return the next data chunk or raise a StopIteration exception if all chunks have been retrieved. HINT: numpy.s\_ provides a convenient way to generate index tuples using standard array slicing. 
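
        (For instance, ``np.s_[0:5, :]`` is simply shorthand for the index tuple
        ``(slice(0, 5, None), slice(None, None, None))``.)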
This is often useful to define the DataChunk.selection of the current chunk. :returns: DataChunk object with the data and selection of the current chunk :rtype: DataChunk """ # If we have not already read the next chunk, then read it now if self.__next_chunk.data is None: self._read_next_chunk() # If we do not have any next chunk if self.__next_chunk.data is None: raise StopIteration # If this is the first time we see a chunk then remember the size of the first chunk if self.__first_chunk_shape is None: self.__first_chunk_shape = self.__next_chunk.data.shape # Keep the next chunk we need to return curr_chunk = DataChunk(self.__next_chunk.data, self.__next_chunk.selection) # Remove the data for the next chunk from our list since we are returning it here. # This is to allow the GarbageCollector to remove the data when it goes out of scope and avoid # having 2 full chunks in memory if not necessary self.__next_chunk.data = None # Return the current next chunk return curr_chunk next = __next__ @docval(returns='Tuple with the recommended chunk shape or None if no particular shape is recommended.') def recommended_chunk_shape(self): """Recommend a chunk shape. To optimize iterative write the chunk should be aligned with the common shape of chunks returned by __next__ or if those chunks are too large, then a well-aligned subset of those chunks. This may also be any other value in case one wants to recommend chunk shapes to optimize read rather than write. The default implementation returns None, indicating no preferential chunking option.""" return None @docval(returns='Recommended initial shape for the full data. This should be the shape of the full dataset ' + 'if known beforehand or alternatively the minimum shape of the dataset. Return None if no ' + 'recommendation is available') def recommended_data_shape(self): """Recommend an initial shape of the data. This is useful when progressively writing data and we want to recommend an initial size for the dataset""" if self.maxshape is not None: if np.all([i is not None for i in self.maxshape]): return self.maxshape return self.__first_chunk_shape @property def maxshape(self): """ Get a shape tuple describing the maximum shape of the array described by this DataChunkIterator. If an iterator is provided and no data has been read yet, then the first chunk will be read (i.e., next will be called on the iterator) in order to determine the maxshape. :return: Shape tuple. None is used for dimensions where the maximum shape is not known or unlimited. """ if self.__maxshape is None: # If no data has been read from the iterator yet, read the first chunk and use it to determine the maxshape if self.__data_iter is not None and self.__next_chunk.data is None: self._read_next_chunk() # Determine maxshape from self.__next_chunk if self.__next_chunk.data is None: return None data_shape = get_data_shape(self.__next_chunk.data) self.__maxshape = list(data_shape) try: # Size of self.__next_chunk.data along self.iter_axis is not accurate for maxshape because it is just a # chunk. So try to set maxshape along the dimension self.iter_axis based on the shape of self.data if # possible.
Otherwise, use None to represent an unlimited size if hasattr(self.data, '__len__') and self.iter_axis == 0: # special case of 1-D array self.__maxshape[0] = len(self.data) else: self.__maxshape[self.iter_axis] = self.data.shape[self.iter_axis] except AttributeError: # from self.data.shape self.__maxshape[self.iter_axis] = None self.__maxshape = tuple(self.__maxshape) return self.__maxshape @property def dtype(self): """ Get the value data type :return: np.dtype object describing the datatype """ return self.__dtype class DataChunk: """ Class used to describe a data chunk. Used in DataChunkIterator. """ @docval({'name': 'data', 'type': np.ndarray, 'doc': 'Numpy array with the data value(s) of the chunk', 'default': None}, {'name': 'selection', 'type': None, 'doc': 'Numpy index tuple describing the location of the chunk', 'default': None}) def __init__(self, **kwargs): self.data, self.selection = getargs('data', 'selection', kwargs) def __len__(self): """Get the number of values in the data chunk""" if self.data is not None: return len(self.data) else: return 0 def __getattr__(self, attr): """Delegate retrieval of attributes to the data in self.data""" return getattr(self.data, attr) def __copy__(self): newobj = DataChunk(data=self.data, selection=self.selection) return newobj def __deepcopy__(self, memo): result = DataChunk(data=copy.deepcopy(self.data), selection=copy.deepcopy(self.selection)) memo[id(self)] = result return result def astype(self, dtype): """Get a new DataChunk with the self.data converted to the given type""" return DataChunk(data=self.data.astype(dtype), selection=self.selection) @property def dtype(self): """ Data type of the values in the chunk :returns: np.dtype of the values in the DataChunk """ return self.data.dtype def assertEqualShape(data1, data2, axes1=None, axes2=None, name1=None, name2=None, ignore_undetermined=True): """ Ensure that the shape of data1 and data2 match along the given dimensions :param data1: The first input array :type data1: List, Tuple, np.ndarray, DataChunkIterator etc. :param data2: The second input array :type data2: List, Tuple, np.ndarray, DataChunkIterator etc. :param name1: Optional string with the name of data1 :param name2: Optional string with the name of data2 :param axes1: The dimensions of data1 that should be matched to the dimensions of data2. Set to None to compare all axes in order. :type axes1: int, Tuple of ints, List of ints, or None :param axes2: The dimensions of data2 that should be matched to the dimensions of data1. Must have the same length as axes1. Set to None to compare all axes in order.
:type axes2: int, Tuple of ints, List of ints, or None :param ignore_undetermined: Boolean indicating whether non-matching unlimited dimensions should be ignored, i.e., if two dimensions don't match because we can't determine the shape of either one, then should we ignore that case or treat it as no match :return: Bool indicating whether the check passed and a string with a message about the matching process """ # Create the base return object response = ShapeValidatorResult() # Determine the shape of the datasets response.shape1 = get_data_shape(data1) response.shape2 = get_data_shape(data2) # Determine the number of dimensions of the datasets num_dims_1 = len(response.shape1) if response.shape1 is not None else None num_dims_2 = len(response.shape2) if response.shape2 is not None else None # Determine the string names of the datasets n1 = name1 if name1 is not None else ("data1 at " + str(hex(id(data1)))) n2 = name2 if name2 is not None else ("data2 at " + str(hex(id(data2)))) # Determine the axes we should compare response.axes1 = list(range(num_dims_1)) if axes1 is None else ([axes1] if isinstance(axes1, int) else axes1) response.axes2 = list(range(num_dims_2)) if axes2 is None else ([axes2] if isinstance(axes2, int) else axes2) # Validate the array shape # 1) Check the number of dimensions of the arrays if (response.axes1 is None and response.axes2 is None) and num_dims_1 != num_dims_2: response.result = False response.error = 'NUM_DIMS_ERROR' response.message = response.SHAPE_ERROR[response.error] response.message += " %s is %sD and %s is %sD" % (n1, num_dims_1, n2, num_dims_2) # 2) Check that we have the same number of dimensions to compare on both arrays elif len(response.axes1) != len(response.axes2): response.result = False response.error = 'NUM_AXES_ERROR' response.message = response.SHAPE_ERROR[response.error] response.message += " Cannot compare axes %s with %s" % (str(response.axes1), str(response.axes2)) # 3) Check that the datasets have sufficient number of dimensions elif np.max(response.axes1) >= num_dims_1 or np.max(response.axes2) >= num_dims_2: response.result = False response.error = 'AXIS_OUT_OF_BOUNDS' response.message = response.SHAPE_ERROR[response.error] if np.max(response.axes1) >= num_dims_1: response.message += "Insufficient number of dimensions for %s -- Expected %i found %i" % \ (n1, np.max(response.axes1) + 1, num_dims_1) elif np.max(response.axes2) >= num_dims_2: response.message += "Insufficient number of dimensions for %s -- Expected %i found %i" % \ (n2, np.max(response.axes2) + 1, num_dims_2) # 4) Compare the length of the dimensions we should validate else: unmatched = [] ignored = [] for ax in zip(response.axes1, response.axes2): if response.shape1[ax[0]] != response.shape2[ax[1]]: if ignore_undetermined and (response.shape1[ax[0]] is None or response.shape2[ax[1]] is None): ignored.append(ax) else: unmatched.append(ax) response.unmatched = unmatched response.ignored = ignored # Check if everything checked out if len(response.unmatched) == 0: response.result = True response.error = None response.message = response.SHAPE_ERROR[response.error] if len(response.ignored) > 0: response.message += " Ignored undetermined axes %s" % str(response.ignored) else: response.result = False response.error = 'AXIS_LEN_ERROR' response.message = response.SHAPE_ERROR[response.error] response.message += "Axes %s with size %s of %s did not match dimensions %s with sizes %s of %s."
% \ (str([un[0] for un in response.unmatched]), str([response.shape1[un[0]] for un in response.unmatched]), n1, str([un[1] for un in response.unmatched]), str([response.shape2[un[1]] for un in response.unmatched]), n2) if len(response.ignored) > 0: response.message += " Ignored undetermined axes %s" % str(response.ignored) return response class ShapeValidatorResult: """Class for storing results from validating the shape of multi-dimensional arrays. This class is used to store results generated by ShapeValidator :ivar result: Boolean indicating whether results matched or not :type result: bool :ivar message: Message indicating the result of the matching procedure :type message: str, None """ SHAPE_ERROR = {None: 'All required axes matched', 'NUM_DIMS_ERROR': 'Unequal number of dimensions.', 'NUM_AXES_ERROR': "Unequal number of axes for comparison.", 'AXIS_OUT_OF_BOUNDS': "Axis index for comparison out of bounds.", 'AXIS_LEN_ERROR': "Unequal length of axes."} """ Dict where the Keys are the type of errors that may have occurred during shape comparison and the values are strings with default error messages for the type. """ @docval({'name': 'result', 'type': bool, 'doc': 'Result of the shape validation', 'default': False}, {'name': 'message', 'type': str, 'doc': 'Message describing the result of the shape validation', 'default': None}, {'name': 'ignored', 'type': tuple, 'doc': 'Axes that have been ignored in the validation process', 'default': tuple(), 'shape': (None,)}, {'name': 'unmatched', 'type': tuple, 'doc': 'List of axes that did not match during shape validation', 'default': tuple(), 'shape': (None,)}, {'name': 'error', 'type': str, 'doc': 'Error that may have occurred. One of ERROR_TYPE', 'default': None}, {'name': 'shape1', 'type': tuple, 'doc': 'Shape of the first array for comparison', 'default': tuple(), 'shape': (None,)}, {'name': 'shape2', 'type': tuple, 'doc': 'Shape of the second array for comparison', 'default': tuple(), 'shape': (None,)}, {'name': 'axes1', 'type': tuple, 'doc': 'Axes for the first array that should match', 'default': tuple(), 'shape': (None,)}, {'name': 'axes2', 'type': tuple, 'doc': 'Axes for the second array that should match', 'default': tuple(), 'shape': (None,)}, ) def __init__(self, **kwargs): self.result, self.message, self.ignored, self.unmatched, \ self.error, self.shape1, self.shape2, self.axes1, self.axes2 = getargs( 'result', 'message', 'ignored', 'unmatched', 'error', 'shape1', 'shape2', 'axes1', 'axes2', kwargs) def __setattr__(self, key, value): """ Overwrite to ensure that, e.g., error_message is not set to an illegal value. """ if key == 'error': if value not in self.SHAPE_ERROR.keys(): raise ValueError("Illegal error type. Error must be one of ShapeValidatorResult.SHAPE_ERROR: %s" % str(self.SHAPE_ERROR)) else: super().__setattr__(key, value) elif key in ['shape1', 'shape2', 'axes1', 'axes2', 'ignored', 'unmatched']: # Make sure we store tuples super().__setattr__(key, tuple(value)) else: super().__setattr__(key, value) def __getattr__(self, item): """ Overwrite to allow dynamic retrieval of the default message """ if item == 'default_message': return self.SHAPE_ERROR[self.error] return self.__getattribute__(item) @docval_macro('data') class DataIO: """ Base class for wrapping data arrays for I/O. Derived classes of DataIO are typically used to pass dataset-specific I/O parameters to the particular HDMFIO backend.
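A minimal usage sketch (assuming this module is importable as ``hdmf.data_utils``; backend-specific
options such as compression belong on subclasses like ``H5DataIO``, not on this base class)::

    import numpy as np
    from hdmf.data_utils import DataIO

    wrapped = DataIO(data=np.arange(10))   # wrap an existing array for writing
    len(wrapped)                           # 10 -- length is delegated to the wrapped data
    np.asarray(wrapped)                    # conversion supported via __array__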
""" @docval({'name': 'data', 'type': 'array_data', 'doc': 'the data to be written', 'default': None}) def __init__(self, **kwargs): data = popargs('data', kwargs) self.__data = data def get_io_params(self): """ Returns a dict with the I/O parameters specified in this DataIO. """ return dict() @property def data(self): """Get the wrapped data object""" return self.__data @data.setter def data(self, val): """Set the wrapped data object""" if self.__data is not None: raise ValueError("cannot overwrite 'data' on DataIO") self.__data = val def __copy__(self): """ Define a custom copy method for shallow copy.. This is needed due to delegation of __getattr__ to the data to ensure proper copy. :return: Shallow copy of self, ie., a new instance of DataIO wrapping the same self.data object """ newobj = DataIO(data=self.data) return newobj def append(self, arg): self.__data = append_data(self.__data, arg) def extend(self, arg): self.__data = extend_data(self.__data, arg) def __deepcopy__(self, memo): """ Define a custom copy method for deep copy. This is needed due to delegation of __getattr__ to the data to ensure proper copy. :param memo: :return: Deep copy of self, i.e., a new instance of DataIO wrapping a deepcopy of the self.data object. """ result = DataIO(data=copy.deepcopy(self.__data)) memo[id(self)] = result return result def __len__(self): """Number of values in self.data""" if not self.valid: raise InvalidDataIOError("Cannot get length of data. Data is not valid.") return len(self.data) def __bool__(self): if self.valid: if isinstance(self.data, AbstractDataChunkIterator): return True return len(self) > 0 return False def __getattr__(self, attr): """Delegate attribute lookup to data object""" if attr == '__array_struct__' and not self.valid: # np.array() checks __array__ or __array_struct__ attribute dep. on numpy version raise InvalidDataIOError("Cannot convert data to array. Data is not valid.") if not self.valid: raise InvalidDataIOError("Cannot get attribute '%s' of data. Data is not valid." % attr) return getattr(self.data, attr) def __getitem__(self, item): """Delegate slicing to the data object""" if not self.valid: raise InvalidDataIOError("Cannot get item from data. Data is not valid.") return self.data[item] def __array__(self): """ Support conversion of DataIO.data to a numpy array. This function is provided to improve transparent interoperability of DataIO with numpy. :return: An array instance of self.data """ if not self.valid: raise InvalidDataIOError("Cannot convert data to array. Data is not valid.") if hasattr(self.data, '__array__'): return self.data.__array__() elif isinstance(self.data, DataChunkIterator): raise NotImplementedError("Conversion of DataChunkIterator to array not supported") else: # NOTE this may result in a copy of the array return np.asarray(self.data) def __next__(self): """Delegate iteration interface to data object""" if not self.valid: raise InvalidDataIOError("Cannot iterate on data. Data is not valid.") return self.data.__next__() def __iter__(self): """Delegate iteration interface to the data object""" if not self.valid: raise InvalidDataIOError("Cannot iterate on data. 
Data is not valid.") return self.data.__iter__() @property def valid(self): """bool indicating if the data object is valid""" return self.data is not None class InvalidDataIOError(Exception): pass ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/monitor.py0000644000655200065520000000434000000000000016330 0ustar00circlecicirclecifrom abc import ABCMeta, abstractmethod from .data_utils import AbstractDataChunkIterator, DataChunkIterator, DataChunk from .utils import docval, getargs, call_docval_func class NotYetExhausted(Exception): pass class DataChunkProcessor(AbstractDataChunkIterator, metaclass=ABCMeta): @docval({'name': 'data', 'type': DataChunkIterator, 'doc': 'the DataChunkIterator to analyze'}) def __init__(self, **kwargs): """Initialize the DataChunkProcessor""" # Get the user parameters self.__dci = getargs('data', kwargs) self.__done = False def __next__(self): try: dc = self.__dci.__next__() except StopIteration as e: self.__done = True raise e self.process_data_chunk(dc) return dc def __iter__(self): return iter(self.__dci) def recommended_chunk_shape(self): return self.__dci.recommended_chunk_shape() def recommended_data_shape(self): return self.__dci.recommended_data_shape() def get_final_result(self, **kwargs): ''' Return the result of processing data fed by this DataChunkIterator ''' if not self.__done: raise NotYetExhausted() return self.compute_final_result() @abstractmethod @docval({'name': 'data_chunk', 'type': DataChunk, 'doc': 'a chunk to process'}) def process_data_chunk(self, **kwargs): ''' This method should take in a DataChunk, and process it. ''' pass @abstractmethod @docval(returns='the result of processing this stream') def compute_final_result(self, **kwargs): ''' Return the result of processing this stream Should raise NotYetExhausted exception ''' pass class NumSampleCounter(DataChunkProcessor): def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) self.__sample_count = 0 @docval({'name': 'data_chunk', 'type': DataChunk, 'doc': 'a chunk to process'}) def process_data_chunk(self, **kwargs): dc = getargs('data_chunk', kwargs) self.__sample_count += len(dc) @docval(returns='the result of processing this stream') def compute_final_result(self, **kwargs): return self.__sample_count ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/query.py0000644000655200065520000001426100000000000016011 0ustar00circlecicirclecifrom abc import ABCMeta, abstractmethod import numpy as np from .array import Array from .utils import ExtenderMeta, docval_macro, docval, getargs class Query(metaclass=ExtenderMeta): __operations__ = ( '__lt__', '__gt__', '__le__', '__ge__', '__eq__', '__ne__', ) @classmethod def __build_operation(cls, op): def __func(self, arg): return cls(self, op, arg) setattr(__func, '__name__', op) return __func @ExtenderMeta.pre_init def __make_operators(cls, name, bases, classdict): if not isinstance(cls.__operations__, tuple): raise TypeError("'__operations__' must be of type tuple") # add any new operations if len(bases) and 'Query' in globals() and issubclass(bases[-1], Query) \ and bases[-1].__operations__ is not cls.__operations__: new_operations = list(cls.__operations__) new_operations[0:0] = bases[-1].__operations__ cls.__operations__ = tuple(new_operations) for op in cls.__operations__: if not hasattr(cls, op): setattr(cls, op, cls.__build_operation(op)) def __init__(self, obj, op, arg): self.obj = obj self.op = op self.arg = arg self.collapsed = None self.expanded = None
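    # Illustrative usage note (names are placeholders): a Query is normally produced by a
    # comparison operator on an HDMFDataset wrapper rather than constructed directly, e.g.
    #
    #     q = wrapped_dataset > 5    # the generated __gt__ returns Query(wrapped_dataset, '__gt__', 5)
    #     result = q.evaluate()      # applies '__gt__' to the wrapped dataset with the argument 5
    #
    # evaluate() caches its result in self.expanded (or self.collapsed when expand=False).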
@docval({'name': 'expand', 'type': bool, 'help': 'whether or not to expand result', 'default': True}) def evaluate(self, **kwargs): expand = getargs('expand', kwargs) if expand: if self.expanded is None: self.expanded = self.__evalhelper() return self.expanded else: if self.collapsed is None: self.collapsed = self.__collapse(self.__evalhelper()) return self.collapsed def __evalhelper(self): obj = self.obj arg = self.arg if isinstance(obj, Query): obj = obj.evaluate() elif isinstance(obj, HDMFDataset): obj = obj.dataset if isinstance(arg, Query): arg = self.arg.evaluate() return getattr(obj, self.op)(arg) def __collapse(self, result): if isinstance(result, slice): return (result.start, result.stop) elif isinstance(result, list): ret = list() for idx in result: if isinstance(idx, slice) and (idx.step is None or idx.step == 1): ret.append((idx.start, idx.stop)) else: ret.append(idx) return ret else: return result def __and__(self, other): return NotImplemented def __or__(self, other): return NotImplemented def __xor__(self, other): return NotImplemented def __contains__(self, other): return NotImplemented @docval_macro('array_data') class HDMFDataset(metaclass=ExtenderMeta): __operations__ = ( '__lt__', '__gt__', '__le__', '__ge__', '__eq__', '__ne__', ) @classmethod def __build_operation(cls, op): def __func(self, arg): return Query(self, op, arg) setattr(__func, '__name__', op) return __func @ExtenderMeta.pre_init def __make_operators(cls, name, bases, classdict): if not isinstance(cls.__operations__, tuple): raise TypeError("'__operations__' must be of type tuple") # add any new operations if len(bases) and 'Query' in globals() and issubclass(bases[-1], Query) \ and bases[-1].__operations__ is not cls.__operations__: new_operations = list(cls.__operations__) new_operations[0:0] = bases[-1].__operations__ cls.__operations__ = tuple(new_operations) for op in cls.__operations__: setattr(cls, op, cls.__build_operation(op)) def __evaluate_key(self, key): if isinstance(key, tuple) and len(key) == 0: return key if isinstance(key, (tuple, list, np.ndarray)): return list(map(self.__evaluate_key, key)) else: if isinstance(key, Query): return key.evaluate() return key def __getitem__(self, key): idx = self.__evaluate_key(key) return self.dataset[idx] @docval({'name': 'dataset', 'type': ('array_data', Array), 'doc': 'the dataset to be lazily evaluated'}) def __init__(self, **kwargs): super().__init__() self.__dataset = getargs('dataset', kwargs) @property def dataset(self): return self.__dataset @property def dtype(self): return self.__dataset.dtype def __len__(self): return len(self.__dataset) def __iter__(self): return iter(self.dataset) def __next__(self): return next(self.dataset) def next(self): return self.dataset.next() class ReferenceResolver(metaclass=ABCMeta): """ A base class for classes that resolve references """ @classmethod @abstractmethod def get_inverse_class(cls): """ Return the class that represents the ReferenceResolver that resolves references to the opposite type. BuilderResolver.get_inverse_class should return a class that subclasses ContainerResolver. ContainerResolver.get_inverse_class should return a class that subclasses BuilderResolver. """ pass @abstractmethod def invert(self): """ Return an object that defers reference resolution but in the opposite direction.
""" pass class BuilderResolver(ReferenceResolver): """ A reference resolver that resolves references to Builders Subclasses should implement the invert method and the get_inverse_class classmethod BuilderResolver.get_inverse_class should return a class that subclasses ContainerResolver. """ pass class ContainerResolver(ReferenceResolver): """ A reference resolver that resolves references to Containers Subclasses should implement the invert method and the get_inverse_class classmethod ContainerResolver.get_inverse_class should return a class that subclasses BuilderResolver. """ pass ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/region.py0000644000655200065520000000522600000000000016130 0ustar00circlecicirclecifrom abc import ABCMeta, abstractmethod from operator import itemgetter from .container import Data, DataRegion from .utils import docval, getargs class RegionSlicer(DataRegion, metaclass=ABCMeta): ''' A abstract base class to control getting using a region Subclasses must implement `__getitem__` and `__len__` ''' @docval({'name': 'target', 'type': None, 'doc': 'the target to slice'}, {'name': 'slice', 'type': None, 'doc': 'the region to slice'}) def __init__(self, **kwargs): self.__target = getargs('target', kwargs) self.__slice = getargs('slice', kwargs) @property def data(self): """The target data. Same as self.target""" return self.target @property def region(self): """The selected region. Same as self.slice""" return self.slice @property def target(self): """The target data""" return self.__target @property def slice(self): """The selected slice""" return self.__slice @property @abstractmethod def __getitem__(self, idx): """Must be implemented by subclasses""" pass @property @abstractmethod def __len__(self): """Must be implemented by subclasses""" pass class ListSlicer(RegionSlicer): """Implementation of RegionSlicer for slicing Lists and Data""" @docval({'name': 'dataset', 'type': (list, tuple, Data), 'doc': 'the dataset to slice'}, {'name': 'region', 'type': (list, tuple, slice), 'doc': 'the region reference to use to slice'}) def __init__(self, **kwargs): self.__dataset, self.__region = getargs('dataset', 'region', kwargs) super().__init__(self.__dataset, self.__region) if isinstance(self.__region, slice): self.__getter = itemgetter(self.__region) self.__len = len(range(*self.__region.indices(len(self.__dataset)))) else: self.__getter = itemgetter(*self.__region) self.__len = len(self.__region) def __read_region(self): """ Internal helper function used to define self._read """ if not hasattr(self, '_read'): self._read = self.__getter(self.__dataset) del self.__getter def __getitem__(self, idx): """ Get data values from selected data """ self.__read_region() getter = None if isinstance(idx, (list, tuple)): getter = itemgetter(*idx) else: getter = itemgetter(idx) return getter(self._read) def __len__(self): """Number of values in the slice/region""" return self.__len ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1627603655.180627 hdmf-3.1.1/src/hdmf/spec/0000755000655200065520000000000000000000000015220 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/spec/__init__.py0000644000655200065520000000045600000000000017336 0ustar00circlecicirclecifrom .catalog import SpecCatalog from .namespace import NamespaceCatalog, SpecNamespace, SpecReader from .spec import (AttributeSpec, DatasetSpec, 
DtypeHelper, DtypeSpec, GroupSpec, LinkSpec, NAME_WILDCARD, RefSpec, Spec) from .write import NamespaceBuilder, SpecWriter, export_spec ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/spec/catalog.py0000644000655200065520000002207700000000000017214 0ustar00circlecicircleciimport copy from collections import OrderedDict from .spec import BaseStorageSpec, GroupSpec from ..utils import docval, getargs class SpecCatalog: def __init__(self): ''' Create a new catalog for storing specifications ** Private Instance Variables ** :ivar __specs: Dict with the specification of each registered type :ivar __parent_types: Dict with parent types for each registered type :ivar __spec_source_files: Dict with the path to the source files (if available) for each registered type :ivar __hierarchy: Dict describing the hierarchy for each registered type. NOTE: Always use SpecCatalog.get_hierarchy(...) to retrieve the hierarchy as this dictionary is used like a cache, i.e., to avoid repeated calculation of the hierarchy but the contents are computed on first request by SpecCatalog.get_hierarchy(...) ''' self.__specs = OrderedDict() self.__parent_types = dict() self.__hierarchy = dict() self.__spec_source_files = dict() @docval({'name': 'spec', 'type': BaseStorageSpec, 'doc': 'a Spec object'}, {'name': 'source_file', 'type': str, 'doc': 'path to the source file from which the spec was loaded', 'default': None}) def register_spec(self, **kwargs): ''' Associate a specified object type with a specification ''' spec, source_file = getargs('spec', 'source_file', kwargs) ndt = spec.data_type_inc ndt_def = spec.data_type_def if ndt_def is None: raise ValueError('cannot register spec that has no data_type_def') if ndt_def != ndt: self.__parent_types[ndt_def] = ndt type_name = ndt_def if ndt_def is not None else ndt if type_name in self.__specs: if self.__specs[type_name] != spec or self.__spec_source_files[type_name] != source_file: raise ValueError("'%s' - cannot overwrite existing specification" % type_name) self.__specs[type_name] = spec self.__spec_source_files[type_name] = source_file @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to get the Spec for'}, returns="the specification for writing the given object type to HDF5 ", rtype='Spec') def get_spec(self, **kwargs): ''' Get the Spec object for the given type ''' data_type = getargs('data_type', kwargs) return self.__specs.get(data_type, None) @docval(rtype=tuple) def get_registered_types(self, **kwargs): ''' Return all registered specifications ''' # kwargs is not used here but is used by docval return tuple(self.__specs.keys()) @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type of the spec to get the source file for'}, returns="the path to source specification file from which the spec was originally loaded or None ", rtype='str') def get_spec_source_file(self, **kwargs): ''' Return the path to the source file from which the spec for the given type was loaded. None is returned if no file path is available for the spec. Note: The spec in the file may not be identical to the object in case the spec is modified after load.
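For example (type and file names are illustrative)::

            catalog.register_spec(my_spec, source_file='my_types.yaml')   # my_spec defines 'MyType'
            catalog.get_spec_source_file('MyType')                        # returns 'my_types.yaml'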
''' data_type = getargs('data_type', kwargs) return self.__spec_source_files.get(data_type, None) @docval({'name': 'spec', 'type': BaseStorageSpec, 'doc': 'the Spec object to register'}, {'name': 'source_file', 'type': str, 'doc': 'path to the source file from which the spec was loaded', 'default': None}, rtype=tuple, returns='the types that were registered with this spec') def auto_register(self, **kwargs): ''' Register this specification and all sub-specifications using data_type as object type name ''' spec, source_file = getargs('spec', 'source_file', kwargs) ndt = spec.data_type_def ret = list() if ndt is not None: self.register_spec(spec, source_file) ret.append(ndt) if isinstance(spec, GroupSpec): for dataset_spec in spec.datasets: dset_ndt = dataset_spec.data_type_def if dset_ndt is not None and not spec.is_inherited_type(dataset_spec): ret.append(dset_ndt) self.register_spec(dataset_spec, source_file) for group_spec in spec.groups: ret.extend(self.auto_register(group_spec, source_file)) return tuple(ret) @docval({'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the hierarchy of'}, returns="Tuple of strings with the names of the types the given data_type inherits from.", rtype=tuple) def get_hierarchy(self, **kwargs): """ For a given type get the type inheritance hierarchy for that type. E.g., if we have a type MyContainer that inherits from BaseContainer then the result will be a tuple with the strings ('MyContainer', 'BaseContainer') """ data_type = getargs('data_type', kwargs) if isinstance(data_type, type): data_type = data_type.__name__ ret = self.__hierarchy.get(data_type) if ret is None: hierarchy = list() parent = data_type while parent is not None: hierarchy.append(parent) parent = self.__parent_types.get(parent) # store the computed hierarchy for data_type and all types in between it and # the top of the hierarchy tmp_hier = tuple(hierarchy) ret = tmp_hier while len(tmp_hier) > 0: self.__hierarchy[tmp_hier[0]] = tmp_hier tmp_hier = tmp_hier[1:] return tuple(ret) @docval(returns="Hierarchically nested OrderedDict with the hierarchy of all the types", rtype=OrderedDict) def get_full_hierarchy(self): """ Get the complete hierarchy of all types. The function attempts to sort types by name using standard Python sorted. """ # Get the list of all types registered_types = self.get_registered_types() type_hierarchy = OrderedDict() # Internal helper function to recursively construct the hierarchy of types def get_type_hierarchy(data_type, spec_catalog): dtype_hier = OrderedDict() for dtype in sorted(self.get_subtypes(data_type=data_type, recursive=False)): dtype_hier[dtype] = get_type_hierarchy(dtype, spec_catalog) return dtype_hier # Compute the type hierarchy for rt in sorted(registered_types): rt_spec = self.get_spec(rt) if isinstance(rt_spec, BaseStorageSpec): # Only BaseStorageSpec have data_type_inc/def keys if rt_spec.get(rt_spec.inc_key(), None) is None: type_hierarchy[rt] = get_type_hierarchy(rt, self) return type_hierarchy @docval({'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the subtypes for'}, {'name': 'recursive', 'type': bool, 'doc': 'recursively get all subtypes. Set to False to only get the direct subtypes', 'default': True}, returns="Tuple of strings with the names of all types of the given data_type.", rtype=tuple) def get_subtypes(self, **kwargs): """ For a given data type recursively find all the subtypes that inherit from it.
E.g., assume we have the following inheritance hierarchy:: -BaseContainer--+-->AContainer--->ADContainer | +-->BContainer In this case, the subtypes of BaseContainer would be (AContainer, ADContainer, BContainer), the subtypes of AContainer would be (ADContainer), and the subtypes of BContainer would be empty (). """ data_type, recursive = getargs('data_type', 'recursive', kwargs) curr_spec = self.get_spec(data_type) if isinstance(curr_spec, BaseStorageSpec): # Only BaseStorageSpec have data_type_inc/def keys subtypes = [] spec_inc_key = curr_spec.inc_key() spec_def_key = curr_spec.def_key() for rt in self.get_registered_types(): rt_spec = self.get_spec(rt) if rt_spec.get(spec_inc_key, None) == data_type and rt_spec.get(spec_def_key, None) != data_type: subtypes.append(rt) if recursive: subtypes += self.get_subtypes(rt) return tuple(set(subtypes)) # Convert to a set to make sure we don't have any duplicates else: return () def __copy__(self): ret = SpecCatalog() ret.__specs = copy.copy(self.__specs) return ret def __deepcopy__(self, memo): ret = SpecCatalog() ret.__specs = copy.deepcopy(self.__specs, memo) return ret ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/spec/namespace.py0000644000655200065520000006016600000000000017537 0ustar00circlecicircleciimport os.path import ruamel.yaml as yaml import string from abc import ABCMeta, abstractmethod from collections import OrderedDict from copy import copy from datetime import datetime from warnings import warn from .catalog import SpecCatalog from .spec import DatasetSpec, GroupSpec from ..utils import docval, getargs, popargs, get_docval, call_docval_func _namespace_args = [ {'name': 'doc', 'type': str, 'doc': 'a description about what this namespace represents'}, {'name': 'name', 'type': str, 'doc': 'the name of this namespace'}, {'name': 'schema', 'type': list, 'doc': 'location of schema specification files or other Namespaces'}, {'name': 'full_name', 'type': str, 'doc': 'extended full name of this namespace', 'default': None}, {'name': 'version', 'type': (str, tuple, list), 'doc': 'Version number of the namespace', 'default': None}, {'name': 'date', 'type': (datetime, str), 'doc': "Date last modified or released. Formatting is %Y-%m-%d %H:%M:%S, e.g, 2017-04-25 17:14:13", 'default': None}, {'name': 'author', 'type': (str, list), 'doc': 'Author or list of authors.', 'default': None}, {'name': 'contact', 'type': (str, list), 'doc': 'List of emails. Ordering should be the same as for author', 'default': None}, {'name': 'catalog', 'type': SpecCatalog, 'doc': 'The SpecCatalog object for this SpecNamespace', 'default': None} ] class SpecNamespace(dict): """ A namespace for specifications """ __types_key = 'data_types' UNVERSIONED = None # value representing missing version @docval(*_namespace_args) def __init__(self, **kwargs): doc, full_name, name, version, date, author, contact, schema, catalog = \ popargs('doc', 'full_name', 'name', 'version', 'date', 'author', 'contact', 'schema', 'catalog', kwargs) super().__init__() self['doc'] = doc self['schema'] = schema if any(c in string.whitespace for c in name): raise ValueError("'name' must not contain any whitespace") self['name'] = name if full_name is not None: self['full_name'] = full_name if version == str(SpecNamespace.UNVERSIONED): # the unversioned version may be written to file as a string and read from file as a string warn("Loaded namespace '%s' is unversioned. Please notify the extension author." 
% name) version = SpecNamespace.UNVERSIONED if version is None: # version is required on write -- see YAMLSpecWriter.write_namespace -- but can be None on read in order to # be able to read older files with extensions that are missing the version key. warn(("Loaded namespace '%s' is missing the required key 'version'. Version will be set to '%s'. " "Please notify the extension author.") % (name, SpecNamespace.UNVERSIONED)) version = SpecNamespace.UNVERSIONED self['version'] = version if date is not None: self['date'] = date if author is not None: self['author'] = author if contact is not None: self['contact'] = contact self.__catalog = catalog if catalog is not None else SpecCatalog() @classmethod def types_key(cls): ''' Get the key used for specifying types to include from a file or namespace Override this method to use a different name for 'data_types' ''' return cls.__types_key @property def full_name(self): """String with full name or None""" return self.get('full_name', None) @property def contact(self): """String or list of strings with the contacts or None""" return self.get('contact', None) @property def author(self): """String or list of strings with the authors or None""" return self.get('author', None) @property def version(self): """ String, list, or tuple with the version or SpecNamespace.UNVERSIONED if the version is missing or empty """ return self.get('version', None) or SpecNamespace.UNVERSIONED @property def date(self): """Date last modified or released. :return: datetime object, string, or None""" return self.get('date', None) @property def name(self): """String with short name or None""" return self.get('name', None) @property def doc(self): return self['doc'] @property def schema(self): return self['schema'] def get_source_files(self): """ Get the list of names of the source files included the schema of the namespace """ return [item['source'] for item in self.schema if 'source' in item] @docval({'name': 'sourcefile', 'type': str, 'doc': 'Name of the source file'}, returns='Dict with the source file documentation', rtype=dict) def get_source_description(self, sourcefile): """ Get the description of a source file as described in the namespace. 
The result is a dict which contains the 'source' and optionally 'title', 'doc' and 'data_types' imported from the source file """ for item in self.schema: if item.get('source', None) == sourcefile: return item @property def catalog(self): """The SpecCatalog containing all the Specs""" return self.__catalog @docval({'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the spec for'}) def get_spec(self, **kwargs): """Get the Spec object for the given data type""" data_type = getargs('data_type', kwargs) spec = self.__catalog.get_spec(data_type) if spec is None: raise ValueError("No specification for '%s' in namespace '%s'" % (data_type, self.name)) return spec @docval(returns="the a tuple of the available data types", rtype=tuple) def get_registered_types(self, **kwargs): """Get the available types in this namespace""" return self.__catalog.get_registered_types() @docval({'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the hierarchy of'}, returns="a tuple with the type hierarchy", rtype=tuple) def get_hierarchy(self, **kwargs): ''' Get the extension hierarchy for the given data_type in this namespace''' data_type = getargs('data_type', kwargs) return self.__catalog.get_hierarchy(data_type) @classmethod def build_namespace(cls, **spec_dict): kwargs = copy(spec_dict) try: args = [kwargs.pop(x['name']) for x in get_docval(cls.__init__) if 'default' not in x] except KeyError as e: raise KeyError("'%s' not found in %s" % (e.args[0], str(spec_dict))) return cls(*args, **kwargs) class SpecReader(metaclass=ABCMeta): @docval({'name': 'source', 'type': str, 'doc': 'the source from which this reader reads from'}) def __init__(self, **kwargs): self.__source = getargs('source', kwargs) @property def source(self): return self.__source @abstractmethod def read_spec(self): pass @abstractmethod def read_namespace(self): pass class YAMLSpecReader(SpecReader): @docval({'name': 'indir', 'type': str, 'doc': 'the path spec files are relative to', 'default': '.'}) def __init__(self, **kwargs): super_kwargs = {'source': kwargs['indir']} call_docval_func(super().__init__, super_kwargs) def read_namespace(self, namespace_path): namespaces = None with open(namespace_path, 'r') as stream: yaml_obj = yaml.YAML(typ='safe', pure=True) d = yaml_obj.load(stream) namespaces = d.get('namespaces') if namespaces is None: raise ValueError("no 'namespaces' found in %s" % namespace_path) return namespaces def read_spec(self, spec_path): specs = None with open(self.__get_spec_path(spec_path), 'r') as stream: yaml_obj = yaml.YAML(typ='safe', pure=True) specs = yaml_obj.load(stream) if not ('datasets' in specs or 'groups' in specs): raise ValueError("no 'groups' or 'datasets' found in %s" % spec_path) return specs def __get_spec_path(self, spec_path): if os.path.isabs(spec_path): return spec_path return os.path.join(self.source, spec_path) class NamespaceCatalog: @docval({'name': 'group_spec_cls', 'type': type, 'doc': 'the class to use for group specifications', 'default': GroupSpec}, {'name': 'dataset_spec_cls', 'type': type, 'doc': 'the class to use for dataset specifications', 'default': DatasetSpec}, {'name': 'spec_namespace_cls', 'type': type, 'doc': 'the class to use for specification namespaces', 'default': SpecNamespace}) def __init__(self, **kwargs): """Create a catalog for storing multiple Namespaces""" self.__namespaces = OrderedDict() self.__dataset_spec_cls = getargs('dataset_spec_cls', kwargs) self.__group_spec_cls = getargs('group_spec_cls', kwargs) self.__spec_namespace_cls = 
getargs('spec_namespace_cls', kwargs) # keep track of all spec objects ever loaded, so we don't have # multiple object instances of a spec self.__loaded_specs = dict() self.__included_specs = dict() self.__included_sources = dict() self._loaded_specs = self.__loaded_specs def __copy__(self): ret = NamespaceCatalog(self.__group_spec_cls, self.__dataset_spec_cls, self.__spec_namespace_cls) ret.__namespaces = copy(self.__namespaces) ret.__loaded_specs = copy(self.__loaded_specs) ret.__included_specs = copy(self.__included_specs) ret.__included_sources = copy(self.__included_sources) return ret def merge(self, ns_catalog): for name, namespace in ns_catalog.__namespaces.items(): self.add_namespace(name, namespace) @property @docval(returns='a tuple of the available namespaces', rtype=tuple) def namespaces(self): """The namespaces in this NamespaceCatalog""" return tuple(self.__namespaces.keys()) @property def dataset_spec_cls(self): """The DatasetSpec class used in this NamespaceCatalog""" return self.__dataset_spec_cls @property def group_spec_cls(self): """The GroupSpec class used in this NamespaceCatalog""" return self.__group_spec_cls @property def spec_namespace_cls(self): """The SpecNamespace class used in this NamespaceCatalog""" return self.__spec_namespace_cls @docval({'name': 'name', 'type': str, 'doc': 'the name of this namespace'}, {'name': 'namespace', 'type': SpecNamespace, 'doc': 'the SpecNamespace object'}) def add_namespace(self, **kwargs): """Add a namespace to this catalog""" name, namespace = getargs('name', 'namespace', kwargs) if name in self.__namespaces: raise KeyError("namespace '%s' already exists" % name) self.__namespaces[name] = namespace for dt in namespace.catalog.get_registered_types(): source = namespace.catalog.get_spec_source_file(dt) # do not add types that have already been loaded # use dict with None values as ordered set because order of specs does matter self.__loaded_specs.setdefault(source, dict()).update({dt: None}) @docval({'name': 'name', 'type': str, 'doc': 'the name of this namespace'}, returns="the SpecNamespace with the given name", rtype=SpecNamespace) def get_namespace(self, **kwargs): """Get the a SpecNamespace""" name = getargs('name', kwargs) ret = self.__namespaces.get(name) if ret is None: raise KeyError("'%s' not a namespace" % name) return ret @docval({'name': 'namespace', 'type': str, 'doc': 'the name of the namespace'}, {'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the spec for'}, returns="the specification for writing the given object type to HDF5 ", rtype='Spec') def get_spec(self, **kwargs): ''' Get the Spec object for the given type from the given Namespace ''' namespace, data_type = getargs('namespace', 'data_type', kwargs) if namespace not in self.__namespaces: raise KeyError("'%s' not a namespace" % namespace) return self.__namespaces[namespace].get_spec(data_type) @docval({'name': 'namespace', 'type': str, 'doc': 'the name of the namespace'}, {'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to get the spec for'}, returns="a tuple with the type hierarchy", rtype=tuple) def get_hierarchy(self, **kwargs): ''' Get the type hierarchy for a given data_type in a given namespace ''' namespace, data_type = getargs('namespace', 'data_type', kwargs) spec_ns = self.__namespaces.get(namespace) if spec_ns is None: raise KeyError("'%s' not a namespace" % namespace) return spec_ns.get_hierarchy(data_type) @docval({'name': 'namespace', 'type': str, 'doc': 'the name of the namespace containing the 
data_type'}, {'name': 'data_type', 'type': str, 'doc': 'the data_type to check'}, {'name': 'parent_data_type', 'type': str, 'doc': 'the potential parent data_type'}, returns="True if *data_type* is a sub `data_type` of *parent_data_type*, False otherwise", rtype=bool) def is_sub_data_type(self, **kwargs): ''' Return whether or not *data_type* is a sub `data_type` of *parent_data_type* ''' ns, dt, parent_dt = getargs('namespace', 'data_type', 'parent_data_type', kwargs) hier = self.get_hierarchy(ns, dt) return parent_dt in hier @docval(rtype=tuple) def get_sources(self, **kwargs): ''' Get all the source specification files that were loaded in this catalog ''' return tuple(self.__loaded_specs.keys()) @docval({'name': 'namespace', 'type': str, 'doc': 'the name of the namespace'}, rtype=tuple) def get_namespace_sources(self, **kwargs): ''' Get all the source specifications that were loaded for a given namespace ''' namespace = getargs('namespace', kwargs) return tuple(self.__included_sources[namespace]) @docval({'name': 'source', 'type': str, 'doc': 'the name of the source'}, rtype=tuple) def get_types(self, **kwargs): ''' Get the types that were loaded from a given source ''' source = getargs('source', kwargs) ret = self.__loaded_specs.get(source) if ret is not None: ret = tuple(ret) else: ret = tuple() return ret def __load_spec_file(self, reader, spec_source, catalog, types_to_load=None, resolve=True): ret = self.__loaded_specs.get(spec_source) if ret is not None: raise ValueError("spec source '%s' already loaded" % spec_source) def __reg_spec(spec_cls, spec_dict): dt_def = spec_dict.get(spec_cls.def_key()) if dt_def is None: msg = 'No data type def key found in spec %s' % spec_source raise ValueError(msg) if types_to_load and dt_def not in types_to_load: return if resolve: self.__resolve_includes(spec_cls, spec_dict, catalog) spec_obj = spec_cls.build_spec(spec_dict) return catalog.auto_register(spec_obj, spec_source) if ret is None: ret = dict() # this is used as an ordered set -- values are all none d = reader.read_spec(spec_source) specs = d.get('datasets', list()) for spec_dict in specs: self.__convert_spec_cls_keys(GroupSpec, self.__group_spec_cls, spec_dict) temp_dict = {k: None for k in __reg_spec(self.__dataset_spec_cls, spec_dict)} ret.update(temp_dict) specs = d.get('groups', list()) for spec_dict in specs: self.__convert_spec_cls_keys(GroupSpec, self.__group_spec_cls, spec_dict) temp_dict = {k: None for k in __reg_spec(self.__group_spec_cls, spec_dict)} ret.update(temp_dict) self.__loaded_specs[spec_source] = ret return ret def __convert_spec_cls_keys(self, parent_cls, spec_cls, spec_dict): """Replace instances of data_type_def/inc in spec_dict with new values from spec_cls.""" # this is necessary because the def_key and inc_key may be different in each namespace # NOTE: this does not handle more than one custom set of keys if parent_cls.def_key() in spec_dict: spec_dict[spec_cls.def_key()] = spec_dict.pop(parent_cls.def_key()) if parent_cls.inc_key() in spec_dict: spec_dict[spec_cls.inc_key()] = spec_dict.pop(parent_cls.inc_key()) def __resolve_includes(self, spec_cls, spec_dict, catalog): """Replace data type inc strings with the spec definition so the new spec is built with included fields. 
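For example (illustrative type names), a spec dict such as
``{'data_type_def': 'MyTable', 'data_type_inc': 'DynamicTable'}`` has the string
``'DynamicTable'`` replaced with the already-registered spec object for that type, so that
``build_spec`` sees the included attributes, datasets, groups, and links.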
""" dt_def = spec_dict.get(spec_cls.def_key()) dt_inc = spec_dict.get(spec_cls.inc_key()) if dt_inc is not None and dt_def is not None: parent_spec = catalog.get_spec(dt_inc) if parent_spec is None: msg = "Cannot resolve include spec '%s' for type '%s'" % (dt_inc, dt_def) raise ValueError(msg) # replace the inc key value from string to the inc spec so that the spec can be updated with all of the # attributes, datasets, groups, and links of the inc spec when spec_cls.build_spec(spec_dict) is called spec_dict[spec_cls.inc_key()] = parent_spec for subspec_dict in spec_dict.get('groups', list()): self.__resolve_includes(self.__group_spec_cls, subspec_dict, catalog) for subspec_dict in spec_dict.get('datasets', list()): self.__resolve_includes(self.__dataset_spec_cls, subspec_dict, catalog) def __load_namespace(self, namespace, reader, resolve=True): ns_name = namespace['name'] if ns_name in self.__namespaces: # pragma: no cover raise KeyError("namespace '%s' already exists" % ns_name) catalog = SpecCatalog() included_types = dict() for s in namespace['schema']: # types_key may be different in each spec namespace, so check both the __spec_namespace_cls types key # and the parent SpecNamespace types key. NOTE: this does not handle more than one custom types key types_to_load = s.get(self.__spec_namespace_cls.types_key(), s.get(SpecNamespace.types_key())) if types_to_load is not None: # schema specifies specific types from 'source' or 'namespace' types_to_load = set(types_to_load) if 'source' in s: # read specs from file self.__load_spec_file(reader, s['source'], catalog, types_to_load=types_to_load, resolve=resolve) self.__included_sources.setdefault(ns_name, list()).append(s['source']) elif 'namespace' in s: # load specs from namespace try: inc_ns = self.get_namespace(s['namespace']) except KeyError as e: raise ValueError("Could not load namespace '%s'" % s['namespace']) from e if types_to_load is None: types_to_load = inc_ns.get_registered_types() # load all types in namespace registered_types = set() for ndt in types_to_load: self.__register_type(ndt, inc_ns, catalog, registered_types) included_types[s['namespace']] = tuple(sorted(registered_types)) else: raise ValueError("Spec '%s' schema must have either 'source' or 'namespace' key" % ns_name) # construct namespace ns = self.__spec_namespace_cls.build_namespace(catalog=catalog, **namespace) self.__namespaces[ns_name] = ns return included_types def __register_type(self, ndt, inc_ns, catalog, registered_types): spec = inc_ns.get_spec(ndt) spec_file = inc_ns.catalog.get_spec_source_file(ndt) self.__register_dependent_types(spec, inc_ns, catalog, registered_types) if isinstance(spec, DatasetSpec): built_spec = self.dataset_spec_cls.build_spec(spec) else: built_spec = self.group_spec_cls.build_spec(spec) registered_types.add(ndt) catalog.register_spec(built_spec, spec_file) def __register_dependent_types(self, spec, inc_ns, catalog, registered_types): """Ensure that classes for all types used by this type are registered """ # TODO test cross-namespace registration... 
def __register_dependent_types_helper(spec, inc_ns, catalog, registered_types): if isinstance(spec, (GroupSpec, DatasetSpec)): if spec.data_type_inc is not None: # TODO handle recursive definitions self.__register_type(spec.data_type_inc, inc_ns, catalog, registered_types) if spec.data_type_def is not None: # nested type definition self.__register_type(spec.data_type_def, inc_ns, catalog, registered_types) else: # spec is a LinkSpec self.__register_type(spec.target_type, inc_ns, catalog, registered_types) if isinstance(spec, GroupSpec): for child_spec in (spec.groups + spec.datasets + spec.links): __register_dependent_types_helper(child_spec, inc_ns, catalog, registered_types) if spec.data_type_inc is not None: self.__register_type(spec.data_type_inc, inc_ns, catalog, registered_types) if isinstance(spec, GroupSpec): for child_spec in (spec.groups + spec.datasets + spec.links): __register_dependent_types_helper(child_spec, inc_ns, catalog, registered_types) @docval({'name': 'namespace_path', 'type': str, 'doc': 'the path to the file containing the namespaces(s) to load'}, {'name': 'resolve', 'type': bool, 'doc': 'whether or not to include objects from included/parent spec objects', 'default': True}, {'name': 'reader', 'type': SpecReader, 'doc': 'the class to user for reading specifications', 'default': None}, returns='a dictionary describing the dependencies of loaded namespaces', rtype=dict) def load_namespaces(self, **kwargs): """Load the namespaces in the given file""" namespace_path, resolve, reader = getargs('namespace_path', 'resolve', 'reader', kwargs) if reader is None: # load namespace definition from file if not os.path.exists(namespace_path): msg = "namespace file '%s' not found" % namespace_path raise IOError(msg) reader = YAMLSpecReader(indir=os.path.dirname(namespace_path)) ns_path_key = os.path.join(reader.source, os.path.basename(namespace_path)) ret = self.__included_specs.get(ns_path_key) if ret is None: ret = dict() else: return ret namespaces = reader.read_namespace(namespace_path) to_load = list() for ns in namespaces: if ns['name'] in self.__namespaces: if ns['version'] != self.__namespaces.get(ns['name'])['version']: # warn if the cached namespace differs from the already loaded namespace warn("Ignoring cached namespace '%s' version %s because version %s is already loaded." % (ns['name'], ns['version'], self.__namespaces.get(ns['name'])['version'])) else: to_load.append(ns) # now load specs into namespace for ns in to_load: ret[ns['name']] = self.__load_namespace(ns, reader, resolve=resolve) self.__included_specs[ns_path_key] = ret return ret ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/spec/spec.py0000644000655200065520000014753700000000000016545 0ustar00circlecicircleciimport re from abc import ABCMeta from collections import OrderedDict from copy import deepcopy from warnings import warn from ..utils import docval, getargs, popargs, get_docval, fmt_docval_args NAME_WILDCARD = None ZERO_OR_ONE = '?' 
ZERO_OR_MANY = '*' ONE_OR_MANY = '+' DEF_QUANTITY = 1 FLAGS = { 'zero_or_one': ZERO_OR_ONE, 'zero_or_many': ZERO_OR_MANY, 'one_or_many': ONE_OR_MANY } class DtypeHelper: # Dict where the keys are the primary data type and the values are list of strings with synonyms for the dtype # make sure keys are consistent between hdmf.spec.spec.DtypeHelper.primary_dtype_synonyms, # hdmf.build.objectmapper.ObjectMapper.__dtypes, hdmf.build.manager.TypeMap._spec_dtype_map, # hdmf.validate.validator.__allowable, and backend dtype maps # see https://hdmf-schema-language.readthedocs.io/en/latest/description.html#dtype primary_dtype_synonyms = { 'float': ["float", "float32"], 'double': ["double", "float64"], 'short': ["int16", "short"], 'int': ["int32", "int"], 'long': ["int64", "long"], 'utf': ["text", "utf", "utf8", "utf-8"], 'ascii': ["ascii", "bytes"], 'bool': ["bool"], 'int8': ["int8"], 'uint8': ["uint8"], 'uint16': ["uint16"], 'uint32': ["uint32", "uint"], 'uint64': ["uint64"], 'object': ['object'], 'region': ['region'], 'numeric': ['numeric'], 'isodatetime': ["isodatetime", "datetime"] } # List of recommended primary dtype strings. These are the keys of primary_dtype_string_synonyms recommended_primary_dtypes = list(primary_dtype_synonyms.keys()) # List of valid primary data type strings valid_primary_dtypes = set(list(primary_dtype_synonyms.keys()) + [vi for v in primary_dtype_synonyms.values() for vi in v]) @staticmethod def simplify_cpd_type(cpd_type): ''' Transform a list of DtypeSpecs into a list of strings. Use for simple representation of compound type and validation. :param cpd_type: The list of DtypeSpecs to simplify :type cpd_type: list ''' ret = list() for exp in cpd_type: exp_key = exp.dtype if isinstance(exp_key, RefSpec): exp_key = exp_key.reftype ret.append(exp_key) return ret @staticmethod def check_dtype(dtype): """Check that the dtype string is a reference or a valid primary dtype.""" if not isinstance(dtype, RefSpec) and dtype not in DtypeHelper.valid_primary_dtypes: raise ValueError("dtype '%s' is not a valid primary data type. 
Allowed dtypes: %s" % (dtype, str(DtypeHelper.valid_primary_dtypes))) return dtype class ConstructableDict(dict, metaclass=ABCMeta): @classmethod def build_const_args(cls, spec_dict): ''' Build constructor arguments for this ConstructableDict class from a dictionary ''' # main use cases are when spec_dict is a ConstructableDict or a spec dict read from a file return deepcopy(spec_dict) @classmethod def build_spec(cls, spec_dict): ''' Build a Spec object from the given Spec dict ''' # main use cases are when spec_dict is a ConstructableDict or a spec dict read from a file vargs = cls.build_const_args(spec_dict) kwargs = dict() # iterate through the Spec docval and construct kwargs based on matching values in spec_dict for x in get_docval(cls.__init__): if x['name'] in vargs: kwargs[x['name']] = vargs.get(x['name']) return cls(**kwargs) class Spec(ConstructableDict): ''' A base specification class ''' @docval({'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, {'name': 'name', 'type': str, 'doc': 'The name of this attribute', 'default': None}, {'name': 'required', 'type': bool, 'doc': 'whether or not this attribute is required', 'default': True}, {'name': 'parent', 'type': 'Spec', 'doc': 'the parent of this spec', 'default': None}) def __init__(self, **kwargs): name, doc, required, parent = getargs('name', 'doc', 'required', 'parent', kwargs) super().__init__() self['doc'] = doc if name is not None: self['name'] = name if not required: self['required'] = required self._parent = parent @property def doc(self): ''' Documentation on what this Spec is specifying ''' return self.get('doc', None) @property def name(self): ''' The name of the object being specified ''' return self.get('name', None) @property def parent(self): ''' The parent specification of this specification ''' return self._parent @parent.setter def parent(self, spec): ''' Set the parent of this specification ''' if self._parent is not None: raise AttributeError('Cannot re-assign parent.') self._parent = spec @classmethod def build_const_args(cls, spec_dict): ''' Build constructor arguments for this Spec class from a dictionary ''' ret = super().build_const_args(spec_dict) return ret def __hash__(self): return id(self) @property def path(self): stack = list() tmp = self while tmp is not None: name = tmp.name if name is None: name = tmp.data_type_def if name is None: name = tmp.data_type_inc stack.append(name) tmp = tmp.parent return "/".join(reversed(stack)) # def __eq__(self, other): # return id(self) == id(other) _target_type_key = 'target_type' _ref_args = [ {'name': _target_type_key, 'type': str, 'doc': 'the target type GroupSpec or DatasetSpec'}, {'name': 'reftype', 'type': str, 'doc': 'the type of references this is i.e. 
region or object'}, ] class RefSpec(ConstructableDict): __allowable_types = ('object', 'region') @docval(*_ref_args) def __init__(self, **kwargs): target_type, reftype = getargs(_target_type_key, 'reftype', kwargs) self[_target_type_key] = target_type if reftype not in self.__allowable_types: msg = "reftype must be one of the following: %s" % ", ".join(self.__allowable_types) raise ValueError(msg) self['reftype'] = reftype @property def target_type(self): '''The data_type of the target of the reference''' return self[_target_type_key] @property def reftype(self): '''The type of reference''' return self['reftype'] @docval(rtype=bool, returns='True if this RefSpec specifies a region reference, False otherwise') def is_region(self): return self['reftype'] == 'region' _attr_args = [ {'name': 'name', 'type': str, 'doc': 'The name of this attribute'}, {'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, {'name': 'dtype', 'type': (str, RefSpec), 'doc': 'The data type of this attribute'}, {'name': 'shape', 'type': (list, tuple), 'doc': 'the shape of this dataset', 'default': None}, {'name': 'dims', 'type': (list, tuple), 'doc': 'the dimensions of this dataset', 'default': None}, {'name': 'required', 'type': bool, 'doc': 'whether or not this attribute is required. ignored when "value" is specified', 'default': True}, {'name': 'parent', 'type': 'BaseStorageSpec', 'doc': 'the parent of this spec', 'default': None}, {'name': 'value', 'type': None, 'doc': 'a constant value for this attribute', 'default': None}, {'name': 'default_value', 'type': None, 'doc': 'a default value for this attribute', 'default': None} ] class AttributeSpec(Spec): ''' Specification for attributes ''' @docval(*_attr_args) def __init__(self, **kwargs): name, dtype, doc, dims, shape, required, parent, value, default_value = getargs( 'name', 'dtype', 'doc', 'dims', 'shape', 'required', 'parent', 'value', 'default_value', kwargs) super().__init__(doc, name=name, required=required, parent=parent) self['dtype'] = DtypeHelper.check_dtype(dtype) if value is not None: self.pop('required', None) self['value'] = value if default_value is not None: if value is not None: raise ValueError("cannot specify 'value' and 'default_value'") self['default_value'] = default_value self['required'] = False if shape is not None: self['shape'] = shape if dims is not None: self['dims'] = dims if 'shape' not in self: self['shape'] = tuple([None] * len(dims)) if self.shape is not None and self.dims is not None: if len(self['dims']) != len(self['shape']): raise ValueError("'dims' and 'shape' must be the same length") @property def dtype(self): ''' The data type of the attribute ''' return self.get('dtype', None) @property def value(self): ''' The constant value of the attribute. "None" if this attribute is not constant ''' return self.get('value', None) @property def default_value(self): ''' The default value of the attribute. "None" if this attribute has no default value ''' return self.get('default_value', None) @property def required(self): ''' True if this attribute is required, False otherwise. 
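# Construction sketch for the RefSpec and AttributeSpec classes defined above; the
# attribute names and the referenced type name 'Container' are illustrative only.
from hdmf.spec.spec import AttributeSpec, RefSpec

# an optional text attribute with a default value
unit = AttributeSpec('unit', 'Unit of measurement for the stored data', 'text',
                     required=False, default_value='seconds')
# an attribute that holds an object reference to another registered data type
source = AttributeSpec('source', 'The object that produced this data',
                       RefSpec('Container', 'object'))
print(unit.required, unit.default_value)  # -> False seconds
print(source.dtype.is_region())           # -> False (an object reference, not a region)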
''' return self.get('required', True) @property def dims(self): ''' The dimensions of this attribute's value ''' return self.get('dims', None) @property def shape(self): ''' The shape of this attribute's value ''' return self.get('shape', None) @classmethod def build_const_args(cls, spec_dict): ''' Build constructor arguments for this Spec class from a dictionary ''' ret = super().build_const_args(spec_dict) if isinstance(ret['dtype'], dict): ret['dtype'] = RefSpec.build_spec(ret['dtype']) return ret _attrbl_args = [ {'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, {'name': 'name', 'type': str, 'doc': 'the name of this base storage container, allowed only if quantity is not \'%s\' or \'%s\'' % (ONE_OR_MANY, ZERO_OR_MANY), 'default': None}, {'name': 'default_name', 'type': str, 'doc': 'The default name of this base storage container, used only if name is None', 'default': None}, {'name': 'attributes', 'type': list, 'doc': 'the attributes on this group', 'default': list()}, {'name': 'linkable', 'type': bool, 'doc': 'whether or not this group can be linked', 'default': True}, {'name': 'quantity', 'type': (str, int), 'doc': 'the required number of allowed instance', 'default': 1}, {'name': 'data_type_def', 'type': str, 'doc': 'the data type this specification represents', 'default': None}, {'name': 'data_type_inc', 'type': (str, 'BaseStorageSpec'), 'doc': 'the data type this specification extends', 'default': None}, ] class BaseStorageSpec(Spec): ''' A specification for any object that can hold attributes. ''' __inc_key = 'data_type_inc' __def_key = 'data_type_def' __type_key = 'data_type' __id_key = 'object_id' @docval(*_attrbl_args) def __init__(self, **kwargs): name, doc, quantity, attributes, linkable, data_type_def, data_type_inc = \ getargs('name', 'doc', 'quantity', 'attributes', 'linkable', 'data_type_def', 'data_type_inc', kwargs) if name == NAME_WILDCARD and data_type_def is None and data_type_inc is None: raise ValueError("Cannot create Group or Dataset spec with no name " "without specifying '%s' and/or '%s'." 
% (self.def_key(), self.inc_key())) super().__init__(doc, name=name) default_name = getargs('default_name', kwargs) if default_name: if name is not None: warn("found 'default_name' with 'name' - ignoring 'default_name'") else: self['default_name'] = default_name self.__attributes = dict() if quantity in (ONE_OR_MANY, ZERO_OR_MANY): if name != NAME_WILDCARD: raise ValueError("Cannot give specific name to something that can " "exist multiple times: name='%s', quantity='%s'" % (name, quantity)) if quantity != DEF_QUANTITY: self['quantity'] = quantity if not linkable: self['linkable'] = False resolve = False if data_type_inc is not None: if isinstance(data_type_inc, BaseStorageSpec): self[self.inc_key()] = data_type_inc.data_type_def else: self[self.inc_key()] = data_type_inc if data_type_def is not None: self.pop('required', None) self[self.def_key()] = data_type_def # resolve inherited and overridden fields only if data_type_inc is a spec # NOTE: this does not happen when loading specs from a file if data_type_inc is not None and isinstance(data_type_inc, BaseStorageSpec): resolve = True # self.attributes / self['attributes']: tuple/list of attributes # self.__attributes: dict of all attributes, including attributes from parent (data_type_inc) types # self.__new_attributes: set of attribute names that do not exist in the parent type # self.__overridden_attributes: set of attribute names that exist in this spec and the parent type # self.__new_attributes and self.__overridden_attributes are only set properly if resolve = True for attribute in attributes: self.set_attribute(attribute) self.__new_attributes = set(self.__attributes.keys()) self.__overridden_attributes = set() self.__resolved = False if resolve: self.resolve_spec(data_type_inc) @property def default_name(self): '''The default name for this spec''' return self.get('default_name', None) @property def resolved(self): return self.__resolved @property def required(self): ''' Whether or not the this spec represents a required field ''' return self.quantity not in (ZERO_OR_ONE, ZERO_OR_MANY) @docval({'name': 'inc_spec', 'type': 'BaseStorageSpec', 'doc': 'the data type this specification represents'}) def resolve_spec(self, **kwargs): """Add attributes from the inc_spec to this spec and track which attributes are new and overridden.""" inc_spec = getargs('inc_spec', kwargs) for attribute in inc_spec.attributes: self.__new_attributes.discard(attribute.name) if attribute.name in self.__attributes: self.__overridden_attributes.add(attribute.name) else: self.set_attribute(attribute) self.__resolved = True @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to check'}) def is_inherited_spec(self, **kwargs): ''' Return True if this spec was inherited from the parent type, False otherwise. Returns False if the spec is not found. ''' spec = getargs('spec', kwargs) if isinstance(spec, Spec): spec = spec.name if spec in self.__attributes: return self.is_inherited_attribute(spec) return False @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to check'}) def is_overridden_spec(self, **kwargs): ''' Return True if this spec overrides a specification from the parent type, False otherwise. Returns False if the spec is not found. 
''' spec = getargs('spec', kwargs) if isinstance(spec, Spec): spec = spec.name if spec in self.__attributes: return self.is_overridden_attribute(spec) return False @docval({'name': 'name', 'type': str, 'doc': 'the name of the attribute to check'}) def is_inherited_attribute(self, **kwargs): ''' Return True if the attribute was inherited from the parent type, False otherwise. Raises a ValueError if the spec is not found. ''' name = getargs('name', kwargs) if name not in self.__attributes: raise ValueError("Attribute '%s' not found" % name) return name not in self.__new_attributes @docval({'name': 'name', 'type': str, 'doc': 'the name of the attribute to check'}) def is_overridden_attribute(self, **kwargs): ''' Return True if the given attribute overrides the specification from the parent, False otherwise. Raises a ValueError if the spec is not found. ''' name = getargs('name', kwargs) if name not in self.__attributes: raise ValueError("Attribute '%s' not found" % name) return name in self.__overridden_attributes def is_many(self): return self.quantity not in (1, ZERO_OR_ONE) @classmethod def get_data_type_spec(cls, data_type_def): # unused return AttributeSpec(cls.type_key(), 'the data type of this object', 'text', value=data_type_def) @classmethod def get_namespace_spec(cls): # unused return AttributeSpec('namespace', 'the namespace for the data type of this object', 'text', required=False) @property def attributes(self): ''' Tuple of attribute specifications for this specification ''' return tuple(self.get('attributes', tuple())) @property def linkable(self): ''' True if object can be a link, False otherwise ''' return self.get('linkable', True) @classmethod def id_key(cls): ''' Get the key used to store data ID on an instance Override this method to use a different name for 'object_id' ''' return cls.__id_key @classmethod def type_key(cls): ''' Get the key used to store data type on an instance Override this method to use a different name for 'data_type'. HDMF supports combining schema that uses 'data_type' and at most one different name for 'data_type'. ''' return cls.__type_key @classmethod def inc_key(cls): ''' Get the key used to define a data_type include. Override this method to use a different keyword for 'data_type_inc'. HDMF supports combining schema that uses 'data_type_inc' and at most one different name for 'data_type_inc'. ''' return cls.__inc_key @classmethod def def_key(cls): ''' Get the key used to define a data_type definition. Override this method to use a different keyword for 'data_type_def' HDMF supports combining schema that uses 'data_type_def' and at most one different name for 'data_type_def'. 
''' return cls.__def_key @property def data_type_inc(self): ''' The data type this specification inherits ''' return self.get(self.inc_key()) @property def data_type_def(self): ''' The data type this specification defines ''' return self.get(self.def_key(), None) @property def data_type(self): ''' The data type of this specification ''' return self.data_type_def or self.data_type_inc @property def quantity(self): ''' The number of times the object being specified should be present ''' return self.get('quantity', DEF_QUANTITY) @docval(*_attr_args) def add_attribute(self, **kwargs): ''' Add an attribute to this specification ''' pargs, pkwargs = fmt_docval_args(AttributeSpec.__init__, kwargs) spec = AttributeSpec(*pargs, **pkwargs) self.set_attribute(spec) return spec @docval({'name': 'spec', 'type': AttributeSpec, 'doc': 'the specification for the attribute to add'}) def set_attribute(self, **kwargs): ''' Set an attribute on this specification ''' spec = kwargs.get('spec') attributes = self.setdefault('attributes', list()) if spec.parent is not None: spec = AttributeSpec.build_spec(spec) # if attribute name already exists in self.__attributes, # 1. find the attribute in self['attributes'] list and replace it with the given spec # 2. replace the value for the name key in the self.__attributes dict # otherwise, add the attribute spec to the self['attributes'] list and self.__attributes dict # the values of self['attributes'] and self.__attributes should always be the same # the former enables the spec to act like a dict with the 'attributes' key and # the latter is useful for name-based access of attributes if spec.name in self.__attributes: idx = -1 for i, attribute in enumerate(attributes): # pragma: no cover (execution should break) if attribute.name == spec.name: idx = i break if idx >= 0: attributes[idx] = spec else: # pragma: no cover raise ValueError('%s in __attributes but not in spec record' % spec.name) else: attributes.append(spec) self.__attributes[spec.name] = spec spec.parent = self @docval({'name': 'name', 'type': str, 'doc': 'the name of the attribute to the Spec for'}) def get_attribute(self, **kwargs): ''' Get an attribute on this specification ''' name = getargs('name', kwargs) return self.__attributes.get(name) @classmethod def build_const_args(cls, spec_dict): ''' Build constructor arguments for this Spec class from a dictionary ''' ret = super().build_const_args(spec_dict) if 'attributes' in ret: ret['attributes'] = [AttributeSpec.build_spec(sub_spec) for sub_spec in ret['attributes']] return ret _dt_args = [ {'name': 'name', 'type': str, 'doc': 'the name of this column'}, {'name': 'doc', 'type': str, 'doc': 'a description about what this data type is'}, {'name': 'dtype', 'type': (str, list, RefSpec), 'doc': 'the data type of this column'}, ] class DtypeSpec(ConstructableDict): '''A class for specifying a component of a compound type''' @docval(*_dt_args) def __init__(self, **kwargs): doc, name, dtype = getargs('doc', 'name', 'dtype', kwargs) self['doc'] = doc self['name'] = name self.check_valid_dtype(dtype) self['dtype'] = dtype @property def doc(self): '''Documentation about this component''' return self['doc'] @property def name(self): '''The name of this component''' return self['name'] @property def dtype(self): ''' The data type of this component''' return self['dtype'] @staticmethod def assertValidDtype(dtype): # Calls check_valid_dtype. 
This method is maintained for backwards compatibility return DtypeSpec.check_valid_dtype(dtype) @staticmethod def check_valid_dtype(dtype): if isinstance(dtype, dict): if _target_type_key not in dtype: msg = "'dtype' must have the key '%s'" % _target_type_key raise ValueError(msg) else: DtypeHelper.check_dtype(dtype) return True @staticmethod @docval({'name': 'spec', 'type': (str, dict), 'doc': 'the spec object to check'}, is_method=False) def is_ref(**kwargs): spec = getargs('spec', kwargs) spec_is_ref = False if isinstance(spec, dict): if _target_type_key in spec: spec_is_ref = True elif 'dtype' in spec and isinstance(spec['dtype'], dict) and _target_type_key in spec['dtype']: spec_is_ref = True return spec_is_ref @classmethod def build_const_args(cls, spec_dict): ''' Build constructor arguments for this Spec class from a dictionary ''' ret = super().build_const_args(spec_dict) if isinstance(ret['dtype'], list): ret['dtype'] = list(map(cls.build_const_args, ret['dtype'])) elif isinstance(ret['dtype'], dict): ret['dtype'] = RefSpec.build_spec(ret['dtype']) return ret _dataset_args = [ {'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, {'name': 'dtype', 'type': (str, list, RefSpec), 'doc': 'The data type of this attribute. Use a list of DtypeSpecs to specify a compound data type.', 'default': None}, {'name': 'name', 'type': str, 'doc': 'The name of this dataset', 'default': None}, {'name': 'default_name', 'type': str, 'doc': 'The default name of this dataset', 'default': None}, {'name': 'shape', 'type': (list, tuple), 'doc': 'the shape of this dataset', 'default': None}, {'name': 'dims', 'type': (list, tuple), 'doc': 'the dimensions of this dataset', 'default': None}, {'name': 'attributes', 'type': list, 'doc': 'the attributes on this group', 'default': list()}, {'name': 'linkable', 'type': bool, 'doc': 'whether or not this group can be linked', 'default': True}, {'name': 'quantity', 'type': (str, int), 'doc': 'the required number of allowed instance', 'default': 1}, {'name': 'default_value', 'type': None, 'doc': 'a default value for this dataset', 'default': None}, {'name': 'data_type_def', 'type': str, 'doc': 'the data type this specification represents', 'default': None}, {'name': 'data_type_inc', 'type': (str, 'DatasetSpec'), 'doc': 'the data type this specification extends', 'default': None}, ] class DatasetSpec(BaseStorageSpec): ''' Specification for datasets To specify a table-like dataset i.e. a compound data type. 
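# Sketch of a compound (table-like) dtype built from the DtypeSpec class above; the
# column names and the referenced type 'TimeSeries' are illustrative only.
from hdmf.spec.spec import DatasetSpec, DtypeHelper, DtypeSpec, RefSpec

columns = [
    DtypeSpec('start', 'start of the interval', 'float64'),
    DtypeSpec('stop', 'stop of the interval', 'float64'),
    DtypeSpec('source', 'reference to the object that produced the interval',
              RefSpec('TimeSeries', 'object')),
]
interval_table = DatasetSpec(doc='A table of intervals', dtype=columns,
                             data_type_def='IntervalTable')
# DtypeHelper.simplify_cpd_type reduces the compound type to plain dtype strings
print(DtypeHelper.simplify_cpd_type(columns))  # -> ['float64', 'float64', 'object']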
''' @docval(*_dataset_args) def __init__(self, **kwargs): doc, shape, dims, dtype, default_value = popargs('doc', 'shape', 'dims', 'dtype', 'default_value', kwargs) if shape is not None: self['shape'] = shape if dims is not None: self['dims'] = dims if 'shape' not in self: self['shape'] = tuple([None] * len(dims)) if self.shape is not None and self.dims is not None: if len(self['dims']) != len(self['shape']): raise ValueError("'dims' and 'shape' must be the same length") if dtype is not None: if isinstance(dtype, list): # Dtype is a compound data type for _i, col in enumerate(dtype): if not isinstance(col, DtypeSpec): msg = ('must use DtypeSpec if defining compound dtype - found %s at element %d' % (type(col), _i)) raise ValueError(msg) else: DtypeHelper.check_dtype(dtype) self['dtype'] = dtype super().__init__(doc, **kwargs) if default_value is not None: self['default_value'] = default_value if self.name is not None: valid_quant_vals = [1, 'zero_or_one', ZERO_OR_ONE] if self.quantity not in valid_quant_vals: raise ValueError("quantity %s invalid for spec with fixed name. Valid values are: %s" % (self.quantity, str(valid_quant_vals))) @classmethod def __get_prec_level(cls, dtype): m = re.search('[0-9]+', dtype) if m is not None: prec = int(m.group()) else: prec = 32 return (dtype[0], prec) @classmethod def __is_sub_dtype(cls, orig, new): if isinstance(orig, RefSpec): if not isinstance(new, RefSpec): return False return orig == new else: orig_prec = cls.__get_prec_level(orig) new_prec = cls.__get_prec_level(new) if orig_prec[0] != new_prec[0]: # cannot extend int to float and vice-versa return False return new_prec >= orig_prec @docval({'name': 'inc_spec', 'type': 'DatasetSpec', 'doc': 'the data type this specification represents'}) def resolve_spec(self, **kwargs): inc_spec = getargs('inc_spec', kwargs) if isinstance(self.dtype, list): # merge the new types inc_dtype = inc_spec.dtype if isinstance(inc_dtype, str): msg = 'Cannot extend simple data type to compound data type' raise ValueError(msg) order = OrderedDict() if inc_dtype is not None: for dt in inc_dtype: order[dt['name']] = dt for dt in self.dtype: name = dt['name'] if name in order: # verify that the exension has supplied # a valid subtyping of existing type orig = order[name].dtype new = dt.dtype if not self.__is_sub_dtype(orig, new): msg = 'Cannot extend %s to %s' % (str(orig), str(new)) raise ValueError(msg) order[name] = dt self['dtype'] = list(order.values()) super().resolve_spec(inc_spec) @property def dims(self): ''' The dimensions of this Dataset ''' return self.get('dims', None) @property def dtype(self): ''' The data type of the Dataset ''' return self.get('dtype', None) @property def shape(self): ''' The shape of the dataset ''' return self.get('shape', None) @property def default_value(self): '''The default value of the dataset or None if not specified''' return self.get('default_value', None) @classmethod def __check_dim(cls, dim, data): return True @classmethod def dtype_spec_cls(cls): ''' The class to use when constructing DtypeSpec objects Override this if extending to use a class other than DtypeSpec to build dataset specifications ''' return DtypeSpec @classmethod def build_const_args(cls, spec_dict): ''' Build constructor arguments for this Spec class from a dictionary ''' ret = super().build_const_args(spec_dict) if 'dtype' in ret: if isinstance(ret['dtype'], list): ret['dtype'] = list(map(cls.dtype_spec_cls().build_spec, ret['dtype'])) elif isinstance(ret['dtype'], dict): ret['dtype'] = 
RefSpec.build_spec(ret['dtype']) return ret _link_args = [ {'name': 'doc', 'type': str, 'doc': 'a description about what this link represents'}, {'name': _target_type_key, 'type': (str, BaseStorageSpec), 'doc': 'the target type GroupSpec or DatasetSpec'}, {'name': 'quantity', 'type': (str, int), 'doc': 'the required number of allowed instance', 'default': 1}, {'name': 'name', 'type': str, 'doc': 'the name of this link', 'default': None} ] class LinkSpec(Spec): @docval(*_link_args) def __init__(self, **kwargs): doc, target_type, name, quantity = popargs('doc', _target_type_key, 'name', 'quantity', kwargs) super().__init__(doc, name, **kwargs) if isinstance(target_type, BaseStorageSpec): if target_type.data_type_def is None: msg = ("'%s' must be a string or a GroupSpec or DatasetSpec with a '%s' key." % (_target_type_key, target_type.def_key())) raise ValueError(msg) self[_target_type_key] = target_type.data_type_def else: self[_target_type_key] = target_type if quantity != 1: self['quantity'] = quantity @property def target_type(self): ''' The data type of target specification ''' return self.get(_target_type_key) @property def data_type_inc(self): ''' The data type of target specification ''' return self.get(_target_type_key) def is_many(self): return self.quantity not in (1, ZERO_OR_ONE) @property def quantity(self): ''' The number of times the object being specified should be present ''' return self.get('quantity', DEF_QUANTITY) @property def required(self): ''' Whether or not the this spec represents a required field ''' return self.quantity not in (ZERO_OR_ONE, ZERO_OR_MANY) _group_args = [ {'name': 'doc', 'type': str, 'doc': 'a description about what this specification represents'}, { 'name': 'name', 'type': str, 'doc': 'the name of the Group that is written to the file. If this argument is omitted, users will be ' 'required to enter a ``name`` field when creating instances of this data type in the API. Another ' 'option is to specify ``default_name``, in which case this name will be used as the name of the Group ' 'if no other name is provided.', 'default': None, }, {'name': 'default_name', 'type': str, 'doc': 'The default name of this group', 'default': None}, {'name': 'groups', 'type': list, 'doc': 'the subgroups in this group', 'default': list()}, {'name': 'datasets', 'type': list, 'doc': 'the datasets in this group', 'default': list()}, {'name': 'attributes', 'type': list, 'doc': 'the attributes on this group', 'default': list()}, {'name': 'links', 'type': list, 'doc': 'the links in this group', 'default': list()}, {'name': 'linkable', 'type': bool, 'doc': 'whether or not this group can be linked', 'default': True}, { 'name': 'quantity', 'type': (str, int), 'doc': "the allowable number of instance of this group in a certain location. See table of options " "`here `_. 
Note that if you" "specify ``name``, ``quantity`` cannot be ``'*'``, ``'+'``, or an integer greater that 1, because you " "cannot have more than one group of the same name in the same parent group.", 'default': 1, }, {'name': 'data_type_def', 'type': str, 'doc': 'the data type this specification represents', 'default': None}, {'name': 'data_type_inc', 'type': (str, 'GroupSpec'), 'doc': 'the data type this specification data_type_inc', 'default': None}, ] class GroupSpec(BaseStorageSpec): ''' Specification for groups ''' @docval(*_group_args) def __init__(self, **kwargs): doc, groups, datasets, links = popargs('doc', 'groups', 'datasets', 'links', kwargs) self.__data_types = dict() self.__groups = dict() for group in groups: self.set_group(group) self.__datasets = dict() for dataset in datasets: self.set_dataset(dataset) self.__links = dict() for link in links: self.set_link(link) self.__new_data_types = set(self.__data_types.keys()) self.__new_datasets = set(self.__datasets.keys()) self.__overridden_datasets = set() self.__new_links = set(self.__links.keys()) self.__overridden_links = set() self.__new_groups = set(self.__groups.keys()) self.__overridden_groups = set() super().__init__(doc, **kwargs) @docval({'name': 'inc_spec', 'type': 'GroupSpec', 'doc': 'the data type this specification represents'}) def resolve_spec(self, **kwargs): inc_spec = getargs('inc_spec', kwargs) data_types = list() # resolve inherited datasets for dataset in inc_spec.datasets: # if not (dataset.data_type_def is None and dataset.data_type_inc is None): if dataset.name is None: data_types.append(dataset) continue self.__new_datasets.discard(dataset.name) if dataset.name in self.__datasets: self.__datasets[dataset.name].resolve_spec(dataset) self.__overridden_datasets.add(dataset.name) else: self.set_dataset(dataset) # resolve inherited groups for group in inc_spec.groups: # if not (group.data_type_def is None and group.data_type_inc is None): if group.name is None: data_types.append(group) continue self.__new_groups.discard(group.name) if group.name in self.__groups: self.__groups[group.name].resolve_spec(group) self.__overridden_groups.add(group.name) else: self.set_group(group) # resolve inherited links for link in inc_spec.links: if link.name is None: data_types.append(link) self.__new_links.discard(link.name) if link.name in self.__links: self.__overridden_links.add(link.name) else: self.set_link(link) # resolve inherited data_types for dt_spec in data_types: if isinstance(dt_spec, LinkSpec): dt = dt_spec.target_type else: dt = dt_spec.data_type_def if dt is None: dt = dt_spec.data_type_inc self.__new_data_types.discard(dt) existing_dt_spec = self.get_data_type(dt) if existing_dt_spec is None or \ ((isinstance(existing_dt_spec, list) or existing_dt_spec.name is not None)) and \ dt_spec.name is None: if isinstance(dt_spec, DatasetSpec): self.set_dataset(dt_spec) elif isinstance(dt_spec, GroupSpec): self.set_group(dt_spec) else: self.set_link(dt_spec) super().resolve_spec(inc_spec) @docval({'name': 'name', 'type': str, 'doc': 'the name of the dataset'}, raises="ValueError, if 'name' is not part of this spec") def is_inherited_dataset(self, **kwargs): '''Return true if a dataset with the given name was inherited''' name = getargs('name', kwargs) if name not in self.__datasets: raise ValueError("Dataset '%s' not found in spec" % name) return name not in self.__new_datasets @docval({'name': 'name', 'type': str, 'doc': 'the name of the dataset'}, raises="ValueError, if 'name' is not part of this spec") def 
is_overridden_dataset(self, **kwargs): '''Return true if a dataset with the given name overrides a specification from the parent type''' name = getargs('name', kwargs) if name not in self.__datasets: raise ValueError("Dataset '%s' not found in spec" % name) return name in self.__overridden_datasets @docval({'name': 'name', 'type': str, 'doc': 'the name of the group'}, raises="ValueError, if 'name' is not part of this spec") def is_inherited_group(self, **kwargs): '''Return true if a group with the given name was inherited''' name = getargs('name', kwargs) if name not in self.__groups: raise ValueError("Group '%s' not found in spec" % name) return name not in self.__new_groups @docval({'name': 'name', 'type': str, 'doc': 'the name of the group'}, raises="ValueError, if 'name' is not part of this spec") def is_overridden_group(self, **kwargs): '''Return true if a group with the given name overrides a specification from the parent type''' name = getargs('name', kwargs) if name not in self.__groups: raise ValueError("Group '%s' not found in spec" % name) return name in self.__overridden_groups @docval({'name': 'name', 'type': str, 'doc': 'the name of the link'}, raises="ValueError, if 'name' is not part of this spec") def is_inherited_link(self, **kwargs): '''Return true if a link with the given name was inherited''' name = getargs('name', kwargs) if name not in self.__links: raise ValueError("Link '%s' not found in spec" % name) return name not in self.__new_links @docval({'name': 'name', 'type': str, 'doc': 'the name of the link'}, raises="ValueError, if 'name' is not part of this spec") def is_overridden_link(self, **kwargs): '''Return true if a link with the given name overrides a specification from the parent type''' name = getargs('name', kwargs) if name not in self.__links: raise ValueError("Link '%s' not found in spec" % name) return name in self.__overridden_links @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to check'}) def is_inherited_spec(self, **kwargs): ''' Returns 'True' if specification was inherited from a parent type ''' spec = getargs('spec', kwargs) if isinstance(spec, Spec): name = spec.name if name is None: name = spec.data_type_def if name is None: name = spec.data_type_inc if name is None: raise ValueError('received Spec with wildcard name but no data_type_inc or data_type_def') spec = name if spec in self.__links: return self.is_inherited_link(spec) elif spec in self.__groups: return self.is_inherited_group(spec) elif spec in self.__datasets: return self.is_inherited_dataset(spec) elif spec in self.__data_types: return self.is_inherited_type(spec) else: if super().is_inherited_spec(spec): return True else: for s in self.__datasets: if self.is_inherited_dataset(s): if self.__datasets[s].get_attribute(spec) is not None: return True for s in self.__groups: if self.is_inherited_group(s): if self.__groups[s].get_attribute(spec) is not None: return True return False @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to check'}) def is_overridden_spec(self, **kwargs): ''' Returns 'True' if specification was inherited from a parent type ''' spec = getargs('spec', kwargs) if isinstance(spec, Spec): name = spec.name if name is None: if spec.is_many(): # this is a wildcard spec, so it cannot be overridden return False name = spec.data_type_def if name is None: name = spec.data_type_inc if name is None: raise ValueError('received Spec with wildcard name but no data_type_inc or data_type_def') spec = name if spec in self.__links: 
return self.is_overridden_link(spec) elif spec in self.__groups: return self.is_overridden_group(spec) elif spec in self.__datasets: return self.is_overridden_dataset(spec) elif spec in self.__data_types: return self.is_overridden_type(spec) else: if super().is_overridden_spec(spec): # check if overridden attribute return True else: for s in self.__datasets: if self.is_overridden_dataset(s): if self.__datasets[s].is_overridden_spec(spec): return True for s in self.__groups: if self.is_overridden_group(s): if self.__groups[s].is_overridden_spec(spec): return True return False @docval({'name': 'spec', 'type': (BaseStorageSpec, str), 'doc': 'the specification to check'}) def is_inherited_type(self, **kwargs): ''' Returns True if `spec` represents a spec that was inherited from an included data_type ''' spec = getargs('spec', kwargs) if isinstance(spec, BaseStorageSpec): if spec.data_type_def is None: raise ValueError('cannot check if something was inherited if it does not have a %s' % self.def_key()) spec = spec.data_type_def return spec not in self.__new_data_types @docval({'name': 'spec', 'type': (BaseStorageSpec, str), 'doc': 'the specification to check'}, raises="ValueError, if 'name' is not part of this spec") def is_overridden_type(self, **kwargs): ''' Returns True if `spec` represents a spec that was overriden by the subtype''' spec = getargs('spec', kwargs) if isinstance(spec, BaseStorageSpec): if spec.data_type_def is None: raise ValueError('cannot check if something was inherited if it does not have a %s' % self.def_key()) spec = spec.data_type_def return spec not in self.__new_data_types def __add_data_type_inc(self, spec): dt = None if hasattr(spec, 'data_type_def') and spec.data_type_def is not None: dt = spec.data_type_def elif hasattr(spec, 'data_type_inc') and spec.data_type_inc is not None: dt = spec.data_type_inc if not dt: raise TypeError("spec does not have '%s' or '%s' defined" % (self.def_key(), self.inc_key())) if dt in self.__data_types: curr = self.__data_types[dt] if curr is spec: return if spec.name is None: if isinstance(curr, list): self.__data_types[dt] = spec else: if curr.name is None: raise TypeError('Cannot have multiple data types of the same type without specifying name') else: # unnamed data types will be stored as data_types self.__data_types[dt] = spec else: if isinstance(curr, list): self.__data_types[dt].append(spec) else: if curr.name is None: # leave the existing data type as is, since the new one can be retrieved by name return else: # store both specific instances of a data type self.__data_types[dt] = [curr, spec] else: self.__data_types[dt] = spec @docval({'name': 'data_type', 'type': str, 'doc': 'the data_type to retrieve'}) def get_data_type(self, **kwargs): ''' Get a specification by "data_type" ''' ndt = getargs('data_type', kwargs) return self.__data_types.get(ndt, None) @property def groups(self): ''' The groups specificed in this GroupSpec ''' return tuple(self.get('groups', tuple())) @property def datasets(self): ''' The datasets specificed in this GroupSpec ''' return tuple(self.get('datasets', tuple())) @property def links(self): ''' The links specificed in this GroupSpec ''' return tuple(self.get('links', tuple())) @docval(*_group_args) def add_group(self, **kwargs): ''' Add a new specification for a subgroup to this group specification ''' doc = kwargs.pop('doc') spec = self.__class__(doc, **kwargs) self.set_group(spec) return spec @docval({'name': 'spec', 'type': ('GroupSpec'), 'doc': 'the specification for the subgroup'}) def 
set_group(self, **kwargs): ''' Add the given specification for a subgroup to this group specification ''' spec = getargs('spec', kwargs) if spec.parent is not None: spec = self.build_spec(spec) if spec.name == NAME_WILDCARD: if spec.data_type_inc is not None or spec.data_type_def is not None: self.__add_data_type_inc(spec) else: raise TypeError("must specify 'name' or 'data_type_inc' in Group spec") else: if spec.data_type_inc is not None or spec.data_type_def is not None: self.__add_data_type_inc(spec) self.__groups[spec.name] = spec self.setdefault('groups', list()).append(spec) spec.parent = self @docval({'name': 'name', 'type': str, 'doc': 'the name of the group to the Spec for'}) def get_group(self, **kwargs): ''' Get a specification for a subgroup to this group specification ''' name = getargs('name', kwargs) return self.__groups.get(name, self.__links.get(name)) @docval(*_dataset_args) def add_dataset(self, **kwargs): ''' Add a new specification for a dataset to this group specification ''' doc = kwargs.pop('doc') spec = self.dataset_spec_cls()(doc, **kwargs) self.set_dataset(spec) return spec @docval({'name': 'spec', 'type': 'DatasetSpec', 'doc': 'the specification for the dataset'}) def set_dataset(self, **kwargs): ''' Add the given specification for a dataset to this group specification ''' spec = getargs('spec', kwargs) if spec.parent is not None: spec = self.dataset_spec_cls().build_spec(spec) if spec.name == NAME_WILDCARD: if spec.data_type_inc is not None or spec.data_type_def is not None: self.__add_data_type_inc(spec) else: raise TypeError("must specify 'name' or 'data_type_inc' in Dataset spec") else: if spec.data_type_inc is not None or spec.data_type_def is not None: self.__add_data_type_inc(spec) self.__datasets[spec.name] = spec self.setdefault('datasets', list()).append(spec) spec.parent = self @docval({'name': 'name', 'type': str, 'doc': 'the name of the dataset to the Spec for'}) def get_dataset(self, **kwargs): ''' Get a specification for a dataset to this group specification ''' name = getargs('name', kwargs) return self.__datasets.get(name, self.__links.get(name)) @docval(*_link_args) def add_link(self, **kwargs): ''' Add a new specification for a link to this group specification ''' doc, target_type = popargs('doc', _target_type_key, kwargs) spec = self.link_spec_cls()(doc, target_type, **kwargs) self.set_link(spec) return spec @docval({'name': 'spec', 'type': 'LinkSpec', 'doc': 'the specification for the object to link to'}) def set_link(self, **kwargs): ''' Add a given specification for a link to this group specification ''' spec = getargs('spec', kwargs) if spec.parent is not None: spec = self.link_spec_cls().build_spec(spec) if spec.name != NAME_WILDCARD: self.__links[spec.name] = spec self.setdefault('links', list()).append(spec) spec.parent = self @docval({'name': 'name', 'type': str, 'doc': 'the name of the link to the Spec for'}) def get_link(self, **kwargs): ''' Get a specification for a link to this group specification ''' name = getargs('name', kwargs) return self.__links.get(name) @classmethod def dataset_spec_cls(cls): ''' The class to use when constructing DatasetSpec objects Override this if extending to use a class other than DatasetSpec to build dataset specifications ''' return DatasetSpec @classmethod def link_spec_cls(cls): ''' The class to use when constructing LinkSpec objects Override this if extending to use a class other than LinkSpec to build link specifications ''' return LinkSpec @classmethod def build_const_args(cls, spec_dict): ''' 
Build constructor arguments for this Spec class from a dictionary ''' ret = super().build_const_args(spec_dict) if 'datasets' in ret: ret['datasets'] = list(map(cls.dataset_spec_cls().build_spec, ret['datasets'])) if 'groups' in ret: ret['groups'] = list(map(cls.build_spec, ret['groups'])) if 'links' in ret: ret['links'] = list(map(cls.link_spec_cls().build_spec, ret['links'])) return ret ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/spec/write.py0000644000655200065520000002561400000000000016734 0ustar00circlecicircleciimport copy import json import os.path import warnings from abc import ABCMeta, abstractmethod from collections import OrderedDict from datetime import datetime import ruamel.yaml as yaml from .catalog import SpecCatalog from .namespace import SpecNamespace from .spec import GroupSpec, DatasetSpec from ..utils import docval, getargs, popargs class SpecWriter(metaclass=ABCMeta): @abstractmethod def write_spec(self, spec_file_dict, path): pass @abstractmethod def write_namespace(self, namespace, path): pass class YAMLSpecWriter(SpecWriter): @docval({'name': 'outdir', 'type': str, 'doc': 'the path to write the directory to output the namespace and specs too', 'default': '.'}) def __init__(self, **kwargs): self.__outdir = getargs('outdir', kwargs) def __dump_spec(self, specs, stream): specs_plain_dict = json.loads(json.dumps(specs)) yaml_obj = yaml.YAML(typ='safe', pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(specs_plain_dict, stream) def write_spec(self, spec_file_dict, path): out_fullpath = os.path.join(self.__outdir, path) spec_plain_dict = json.loads(json.dumps(spec_file_dict)) sorted_data = self.sort_keys(spec_plain_dict) with open(out_fullpath, 'w') as fd_write: yaml_obj = yaml.YAML(pure=True) yaml_obj.dump(sorted_data, fd_write) def write_namespace(self, namespace, path): """Write the given namespace key-value pairs as YAML to the given path. :param namespace: SpecNamespace holding the key-value pairs that define the namespace :param path: File path to write the namespace to as YAML under the key 'namespaces' """ with open(os.path.join(self.__outdir, path), 'w') as stream: # Convert the date to a string if necessary ns = namespace if 'date' in namespace and isinstance(namespace['date'], datetime): ns = copy.copy(ns) # copy the namespace to avoid side-effects ns['date'] = ns['date'].isoformat() self.__dump_spec({'namespaces': [ns]}, stream) def reorder_yaml(self, path): """ Open a YAML file, load it as python data, sort the data alphabetically, and write it back out to the same path. 
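# Composition sketch for the GroupSpec/DatasetSpec classes defined in spec.py above;
# the type name 'ExampleSeries' and the member names are illustrative only.
from hdmf.spec.spec import AttributeSpec, DatasetSpec, GroupSpec

values = DatasetSpec(doc='Recorded values', dtype='float', name='values',
                     dims=('num_samples',), shape=(None,),
                     attributes=[AttributeSpec('unit', 'Unit of measurement', 'text')])
series = GroupSpec(doc='A container of recorded values', data_type_def='ExampleSeries',
                   datasets=[values])
series.add_attribute(name='help', doc='Short description of this type', dtype='text')
print(series.get_dataset('values').dtype)   # -> 'float'
print([a.name for a in series.attributes])  # -> ['help']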
""" with open(path, 'rb') as fd_read: yaml_obj = yaml.YAML(pure=True) data = yaml_obj.load(fd_read) self.write_spec(data, path) def sort_keys(self, obj): # Represent None as null def my_represent_none(self, data): return self.represent_scalar(u'tag:yaml.org,2002:null', u'null') yaml.representer.RoundTripRepresenter.add_representer(type(None), my_represent_none) order = ['neurodata_type_def', 'neurodata_type_inc', 'data_type_def', 'data_type_inc', 'name', 'default_name', 'dtype', 'target_type', 'dims', 'shape', 'default_value', 'value', 'doc', 'required', 'quantity', 'attributes', 'datasets', 'groups', 'links'] if isinstance(obj, dict): keys = list(obj.keys()) for k in order[::-1]: if k in keys: keys.remove(k) keys.insert(0, k) if 'neurodata_type_def' not in keys and 'name' in keys: keys.remove('name') keys.insert(0, 'name') return yaml.comments.CommentedMap( yaml.compat.ordereddict([(k, self.sort_keys(obj[k])) for k in keys]) ) elif isinstance(obj, list): return [self.sort_keys(v) for v in obj] elif isinstance(obj, tuple): return (self.sort_keys(v) for v in obj) else: return obj class NamespaceBuilder: ''' A class for building namespace and spec files ''' @docval({'name': 'doc', 'type': str, 'doc': 'Description about what the namespace represents'}, {'name': 'name', 'type': str, 'doc': 'Name of the namespace'}, {'name': 'full_name', 'type': str, 'doc': 'Extended full name of the namespace', 'default': None}, {'name': 'version', 'type': (str, tuple, list), 'doc': 'Version number of the namespace', 'default': None}, {'name': 'author', 'type': (str, list), 'doc': 'Author or list of authors.', 'default': None}, {'name': 'contact', 'type': (str, list), 'doc': 'List of emails. Ordering should be the same as for author', 'default': None}, {'name': 'date', 'type': (datetime, str), 'doc': "Date last modified or released. Formatting is %Y-%m-%d %H:%M:%S, e.g, 2017-04-25 17:14:13", 'default': None}, {'name': 'namespace_cls', 'type': type, 'doc': 'the SpecNamespace type', 'default': SpecNamespace}) def __init__(self, **kwargs): ns_cls = popargs('namespace_cls', kwargs) if kwargs['version'] is None: # version is required on write as of HDMF 1.5. this check should prevent the writing of namespace files # without a verison raise ValueError("Namespace '%s' missing key 'version'. Please specify a version for the extension." 
% kwargs['name']) self.__ns_args = copy.deepcopy(kwargs) self.__namespaces = OrderedDict() self.__sources = OrderedDict() self.__catalog = SpecCatalog() self.__dt_key = ns_cls.types_key() @docval({'name': 'source', 'type': str, 'doc': 'the path to write the spec to'}, {'name': 'spec', 'type': (GroupSpec, DatasetSpec), 'doc': 'the Spec to add'}) def add_spec(self, **kwargs): ''' Add a Spec to the namespace ''' source, spec = getargs('source', 'spec', kwargs) self.__catalog.auto_register(spec, source) self.add_source(source) self.__sources[source].setdefault(self.__dt_key, list()).append(spec) @docval({'name': 'source', 'type': str, 'doc': 'the path to write the spec to'}, {'name': 'doc', 'type': str, 'doc': 'additional documentation for the source file', 'default': None}, {'name': 'title', 'type': str, 'doc': 'optional heading to be used for the source', 'default': None}) def add_source(self, **kwargs): ''' Add a source file to the namespace ''' source, doc, title = getargs('source', 'doc', 'title', kwargs) if '/' in source or source[0] == '.': raise ValueError('source must be a base file') source_dict = {'source': source} self.__sources.setdefault(source, source_dict) # Update the doc and title if given if doc is not None: self.__sources[source]['doc'] = doc if title is not None: self.__sources[source]['title'] = doc @docval({'name': 'data_type', 'type': str, 'doc': 'the data type to include'}, {'name': 'source', 'type': str, 'doc': 'the source file to include the type from', 'default': None}, {'name': 'namespace', 'type': str, 'doc': 'the namespace from which to include the data type', 'default': None}) def include_type(self, **kwargs): ''' Include a data type from an existing namespace or source ''' dt, src, ns = getargs('data_type', 'source', 'namespace', kwargs) if src is not None: self.add_source(src) self.__sources[src].setdefault(self.__dt_key, list()).append(dt) elif ns is not None: self.include_namespace(ns) self.__namespaces[ns].setdefault(self.__dt_key, list()).append(dt) else: raise ValueError("must specify 'source' or 'namespace' when including type") @docval({'name': 'namespace', 'type': str, 'doc': 'the namespace to include'}) def include_namespace(self, **kwargs): ''' Include an entire namespace ''' namespace = getargs('namespace', kwargs) self.__namespaces.setdefault(namespace, {'namespace': namespace}) @docval({'name': 'path', 'type': str, 'doc': 'the path to write the spec to'}, {'name': 'outdir', 'type': str, 'doc': 'the path to write the directory to output the namespace and specs too', 'default': '.'}, {'name': 'writer', 'type': SpecWriter, 'doc': 'the SpecWriter to use to write the namespace', 'default': None}) def export(self, **kwargs): ''' Export the namespace to the given path. All new specification source files will be written in the same directory as the given path. 
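# End-to-end sketch for the NamespaceBuilder class above; the namespace name, file
# names, and author details are illustrative only.
from hdmf.spec.spec import GroupSpec
from hdmf.spec.write import NamespaceBuilder

ns_builder = NamespaceBuilder(doc='Example extension', name='ndx-example',
                              version='0.1.0',  # required; see the version check in __init__
                              author='Jane Doe', contact='jane.doe@example.com')
ns_builder.include_namespace('hdmf-common')
ext_type = GroupSpec(doc='An example container type', data_type_def='ExampleContainer',
                     data_type_inc='Container')
ns_builder.add_spec('ndx-example.extensions.yaml', ext_type)
# writes ndx-example.namespace.yaml and ndx-example.extensions.yaml to the current directory
ns_builder.export('ndx-example.namespace.yaml', outdir='.')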
''' ns_path, writer = getargs('path', 'writer', kwargs) if writer is None: writer = YAMLSpecWriter(outdir=getargs('outdir', kwargs)) ns_args = copy.copy(self.__ns_args) ns_args['schema'] = list() for ns, info in self.__namespaces.items(): ns_args['schema'].append(info) for path, info in self.__sources.items(): out = SpecFileBuilder() dts = list() for spec in info[self.__dt_key]: if isinstance(spec, str): dts.append(spec) else: out.add_spec(spec) item = {'source': path} if 'doc' in info: item['doc'] = info['doc'] if 'title' in info: item['title'] = info['title'] if out and dts: raise ValueError('cannot include from source if writing to source') elif dts: item[self.__dt_key] = dts elif out: writer.write_spec(out, path) ns_args['schema'].append(item) namespace = SpecNamespace.build_namespace(**ns_args) writer.write_namespace(namespace, ns_path) @property def name(self): return self.__ns_args['name'] class SpecFileBuilder(dict): @docval({'name': 'spec', 'type': (GroupSpec, DatasetSpec), 'doc': 'the Spec to add'}) def add_spec(self, **kwargs): spec = getargs('spec', kwargs) if isinstance(spec, GroupSpec): self.setdefault('groups', list()).append(spec) elif isinstance(spec, DatasetSpec): self.setdefault('datasets', list()).append(spec) def export_spec(ns_builder, new_data_types, output_dir): """ Create YAML specification files for a new namespace and extensions with the given data type specs. Args: ns_builder - NamespaceBuilder instance used to build the namespace and extension new_data_types - Iterable of specs that represent new data types to be added """ if len(new_data_types) == 0: warnings.warn('No data types specified. Exiting.') return ns_path = ns_builder.name + '.namespace.yaml' ext_path = ns_builder.name + '.extensions.yaml' for data_type in new_data_types: ns_builder.add_spec(ext_path, data_type) ns_builder.export(ns_path, outdir=output_dir) ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1627603655.180627 hdmf-3.1.1/src/hdmf/testing/0000755000655200065520000000000000000000000015743 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/testing/__init__.py0000644000655200065520000000016100000000000020052 0ustar00circlecicirclecifrom .testcase import TestCase, H5RoundTripMixin # noqa: F401 from .utils import remove_test_file # noqa: F401 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/testing/testcase.py0000644000655200065520000002632700000000000020142 0ustar00circlecicircleciimport h5py import numpy as np import os import re import unittest from abc import ABCMeta, abstractmethod from .utils import remove_test_file from ..backends.hdf5 import HDF5IO from ..build import Builder from ..common import validate as common_validate, get_manager from ..container import AbstractContainer, Container, Data from ..query import HDMFDataset class TestCase(unittest.TestCase): """ Extension of unittest's TestCase to add useful functions for unit testing in HDMF. """ def assertRaisesWith(self, exc_type, exc_msg, *args, **kwargs): """ Asserts the given invocation raises the expected exception. This is similar to unittest's assertRaises and assertRaisesRegex, but checks for an exact match. """ return self.assertRaisesRegex(exc_type, '^%s$' % re.escape(exc_msg), *args, **kwargs) def assertWarnsWith(self, warn_type, exc_msg, *args, **kwargs): """ Asserts the given invocation raises the expected warning. 
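# Sketch of the exact-match assertion helpers above used in a unit test; the expected
# message is taken verbatim from RefSpec.__init__ in spec.py, and the test class name
# is illustrative only.
from hdmf.spec.spec import RefSpec
from hdmf.testing import TestCase

class TestRefSpecErrors(TestCase):
    def test_invalid_reftype(self):
        # unlike assertRaisesRegex, assertRaisesWith requires the full message to match
        msg = "reftype must be one of the following: object, region"
        with self.assertRaisesWith(ValueError, msg):
            RefSpec('ExampleType', 'link')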
This is similar to unittest's assertWarns and assertWarnsRegex, but checks for an exact match. """ return self.assertWarnsRegex(warn_type, '^%s$' % re.escape(exc_msg), *args, **kwargs) def assertContainerEqual(self, container1, container2, ignore_name=False, ignore_hdmf_attrs=False): """ Asserts that the two AbstractContainers have equal contents. This applies to both Container and Data types. ignore_name - whether to ignore testing equality of name of the top-level container ignore_hdmf_attrs - whether to ignore testing equality of HDMF container attributes, such as container_source and object_id """ self.assertTrue(isinstance(container1, AbstractContainer)) self.assertTrue(isinstance(container2, AbstractContainer)) type1 = type(container1) type2 = type(container2) self.assertEqual(type1, type2) if not ignore_name: self.assertEqual(container1.name, container2.name) if not ignore_hdmf_attrs: self.assertEqual(container1.container_source, container2.container_source) self.assertEqual(container1.object_id, container2.object_id) # NOTE: parent is not tested because it can lead to infinite loops if isinstance(container1, Container): self.assertEqual(len(container1.children), len(container2.children)) # do not actually check the children values here. all children *should* also be fields, which is checked below. # this is in case non-field children are added to one and not the other for field in getattr(container1, type1._fieldsname): with self.subTest(field=field, container_type=type1.__name__): f1 = getattr(container1, field) f2 = getattr(container2, field) self._assert_field_equal(f1, f2, ignore_hdmf_attrs=ignore_hdmf_attrs) def _assert_field_equal(self, f1, f2, ignore_hdmf_attrs=False): if (isinstance(f1, (tuple, list, np.ndarray, h5py.Dataset)) or isinstance(f2, (tuple, list, np.ndarray, h5py.Dataset))): self._assert_array_equal(f1, f2, ignore_hdmf_attrs=ignore_hdmf_attrs) elif isinstance(f1, dict) and len(f1) and isinstance(f1.values()[0], Container): self.assertIsInstance(f2, dict) f1_keys = set(f1.keys()) f2_keys = set(f2.keys()) self.assertSetEqual(f1_keys, f2_keys) for k in f1_keys: with self.subTest(module_name=k): self.assertContainerEqual(f1[k], f2[k], ignore_hdmf_attrs=ignore_hdmf_attrs) elif isinstance(f1, Container): self.assertContainerEqual(f1, f2, ignore_hdmf_attrs=ignore_hdmf_attrs) elif isinstance(f1, Data): self._assert_data_equal(f1, f2, ignore_hdmf_attrs=ignore_hdmf_attrs) elif isinstance(f1, (float, np.floating)): np.testing.assert_allclose(f1, f2) else: self.assertEqual(f1, f2) def _assert_data_equal(self, data1, data2, ignore_hdmf_attrs=False): self.assertTrue(isinstance(data1, Data)) self.assertTrue(isinstance(data2, Data)) self.assertEqual(len(data1), len(data2)) self._assert_array_equal(data1.data, data2.data, ignore_hdmf_attrs=ignore_hdmf_attrs) self.assertContainerEqual(data1, data2, ignore_hdmf_attrs=ignore_hdmf_attrs) def _assert_array_equal(self, arr1, arr2, ignore_hdmf_attrs=False): if isinstance(arr1, (h5py.Dataset, HDMFDataset)): arr1 = arr1[()] if isinstance(arr2, (h5py.Dataset, HDMFDataset)): arr2 = arr2[()] if not isinstance(arr1, (tuple, list, np.ndarray)) and not isinstance(arr2, (tuple, list, np.ndarray)): if isinstance(arr1, (float, np.floating)): np.testing.assert_allclose(arr1, arr2) else: if isinstance(arr1, bytes): arr1 = arr1.decode('utf-8') if isinstance(arr2, bytes): arr2 = arr2.decode('utf-8') self.assertEqual(arr1, arr2) # scalar else: self.assertEqual(len(arr1), len(arr2)) if isinstance(arr1, np.ndarray) and len(arr1.dtype) > 1: # compound 
type arr1 = arr1.tolist() if isinstance(arr2, np.ndarray) and len(arr2.dtype) > 1: # compound type arr2 = arr2.tolist() if isinstance(arr1, np.ndarray) and isinstance(arr2, np.ndarray): if np.issubdtype(arr1.dtype, np.number): np.testing.assert_allclose(arr1, arr2) else: np.testing.assert_array_equal(arr1, arr2) else: for sub1, sub2 in zip(arr1, arr2): if isinstance(sub1, Container): self.assertContainerEqual(sub1, sub2, ignore_hdmf_attrs=ignore_hdmf_attrs) elif isinstance(sub1, Data): self._assert_data_equal(sub1, sub2, ignore_hdmf_attrs=ignore_hdmf_attrs) else: self._assert_array_equal(sub1, sub2, ignore_hdmf_attrs=ignore_hdmf_attrs) def assertBuilderEqual(self, builder1, builder2, check_path=True, check_source=True): """Test whether two builders are equal. Like assertDictEqual but also checks type, name, path, and source. """ self.assertTrue(isinstance(builder1, Builder)) self.assertTrue(isinstance(builder2, Builder)) self.assertEqual(type(builder1), type(builder2)) self.assertEqual(builder1.name, builder2.name) if check_path: self.assertEqual(builder1.path, builder2.path) if check_source: self.assertEqual(builder1.source, builder2.source) self.assertDictEqual(builder1, builder2) class H5RoundTripMixin(metaclass=ABCMeta): """ Mixin class for methods to run a roundtrip test writing a container to and reading the container from an HDF5 file. The setUp, test_roundtrip, and tearDown methods will be run by unittest. The abstract method setUpContainer needs to be implemented by classes that include this mixin. Example:: class TestMyContainerRoundTrip(H5RoundTripMixin, TestCase): def setUpContainer(self): # return the Container to read/write NOTE: This class is a mix-in and not a subclass of TestCase so that unittest does not discover it, try to run it, and skip it. 
""" def setUp(self): self.__manager = get_manager() self.container = self.setUpContainer() self.container_type = self.container.__class__.__name__ self.filename = 'test_%s.h5' % self.container_type self.export_filename = 'test_export_%s.h5' % self.container_type self.writer = None self.reader = None self.export_reader = None def tearDown(self): if self.writer is not None: self.writer.close() if self.reader is not None: self.reader.close() if self.export_reader is not None: self.export_reader.close() remove_test_file(self.filename) remove_test_file(self.export_filename) @abstractmethod def setUpContainer(self): """Return the Container to read/write.""" raise NotImplementedError('Cannot run test unless setUpContainer is implemented') def test_roundtrip(self): """Test whether the container read from a written file is the same as the original file.""" read_container = self.roundtripContainer() self._test_roundtrip(read_container, export=False) def test_roundtrip_export(self): """Test whether the container read from a written and then exported file is the same as the original file.""" read_container = self.roundtripExportContainer() self._test_roundtrip(read_container, export=True) def _test_roundtrip(self, read_container, export=False): self.assertIsNotNone(str(self.container)) # added as a test to make sure printing works self.assertIsNotNone(str(read_container)) # make sure we get a completely new object self.assertNotEqual(id(self.container), id(read_container)) # the name of the root container of a file is always 'root' (see h5tools.py ROOT_NAME) # thus, ignore the name of the container when comparing original container vs read container if not export: self.assertContainerEqual(read_container, self.container, ignore_name=True) else: self.assertContainerEqual(read_container, self.container, ignore_name=True, ignore_hdmf_attrs=True) self.validate(read_container._experimental) def roundtripContainer(self, cache_spec=False): """Write the container to an HDF5 file, read the container from the file, and return it.""" with HDF5IO(self.filename, manager=get_manager(), mode='w') as write_io: write_io.write(self.container, cache_spec=cache_spec) self.reader = HDF5IO(self.filename, manager=get_manager(), mode='r') return self.reader.read() def roundtripExportContainer(self, cache_spec=False): """Write the container to an HDF5 file, read it, export it to a new file, read that file, and return it.""" self.roundtripContainer(cache_spec=cache_spec) HDF5IO.export_io( src_io=self.reader, path=self.export_filename, cache_spec=cache_spec, ) self.export_reader = HDF5IO(self.export_filename, manager=get_manager(), mode='r') return self.export_reader.read() def validate(self, experimental=False): """Validate the written and exported files, if they exist.""" if os.path.exists(self.filename): with HDF5IO(self.filename, manager=get_manager(), mode='r') as io: errors = common_validate(io, experimental=experimental) if errors: for err in errors: raise Exception(err) if os.path.exists(self.export_filename): with HDF5IO(self.filename, manager=get_manager(), mode='r') as io: errors = common_validate(io, experimental=experimental) if errors: for err in errors: raise Exception(err) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/testing/utils.py0000644000655200065520000000070300000000000017455 0ustar00circlecicircleciimport os def remove_test_file(path): """A helper function for removing intermediate test files This checks if the environment variable 
CLEAN_HDMF has been set to False before removing the file. If CLEAN_HDMF is set to False, it does not remove the file. """ clean_flag_set = os.getenv('CLEAN_HDMF', True) not in ('False', 'false', 'FALSE', '0', 0, False) if os.path.exists(path) and clean_flag_set: os.remove(path) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/testing/validate_spec.py0000755000655200065520000000340200000000000021122 0ustar00circlecicircleciimport json import os from argparse import ArgumentParser from glob import glob import jsonschema import ruamel.yaml as yaml def validate_spec(fpath_spec, fpath_schema): """ Validate a yaml specification file against the json schema file that defines the specification language. Can be used to validate changes to the NWB and HDMF core schemas, as well as any extensions to either. :param fpath_spec: path-like :param fpath_schema: path-like """ schemaAbs = 'file://' + os.path.abspath(fpath_schema) f_schema = open(fpath_schema, 'r') schema = json.load(f_schema) class FixResolver(jsonschema.RefResolver): def __init__(self): jsonschema.RefResolver.__init__(self, base_uri=schemaAbs, referrer=None) self.store[schemaAbs] = schema new_resolver = FixResolver() f_nwb = open(fpath_spec, 'r') instance = yaml.safe_load(f_nwb) jsonschema.validate(instance, schema, resolver=new_resolver) def main(): parser = ArgumentParser(description="Validate an HDMF/NWB specification") parser.add_argument("paths", type=str, nargs='+', help="yaml file paths") parser.add_argument("-m", "--metaschema", type=str, help=".json.schema file used to validate yaml files") args = parser.parse_args() for path in args.paths: if os.path.isfile(path): validate_spec(path, args.metaschema) elif os.path.isdir(path): for ipath in glob(os.path.join(path, '*.yaml')): validate_spec(ipath, args.metaschema) else: raise ValueError('path must be a valid file or directory') if __name__ == "__main__": main() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/utils.py0000644000655200065520000012633100000000000016006 0ustar00circlecicircleciimport collections import copy as _copy import types import warnings from abc import ABCMeta from enum import Enum import h5py import numpy as np __macros = { 'array_data': [np.ndarray, list, tuple, h5py.Dataset], 'scalar_data': [str, int, float, bytes, bool], 'data': [] } # code to signify how to handle positional arguments in docval AllowPositional = Enum('AllowPositional', 'ALLOWED WARNING ERROR') __supported_bool_types = (bool, np.bool_) __supported_uint_types = (np.uint8, np.uint16, np.uint32, np.uint64) __supported_int_types = (int, np.int8, np.int16, np.int32, np.int64) __supported_float_types = [float, np.float16, np.float32, np.float64] if hasattr(np, "float128"): # pragma: no cover __supported_float_types.append(np.float128) if hasattr(np, "longdouble"): # pragma: no cover # on windows python<=3.5, h5py floats resolve float64s as either np.float64 or np.longdouble # non-deterministically. a future version of h5py will fix this. 
see #112 __supported_float_types.append(np.longdouble) __supported_float_types = tuple(__supported_float_types) __allowed_enum_types = (__supported_bool_types + __supported_uint_types + __supported_int_types + __supported_float_types + (str,)) def docval_macro(macro): """Class decorator to add the class to a list of types associated with the key macro in the __macros dict """ def _dec(cls): if macro not in __macros: __macros[macro] = list() __macros[macro].append(cls) return cls return _dec def get_docval_macro(key=None): """ Return a deepcopy of the docval macros, i.e., strings that represent a customizable list of types for use in docval. :param key: Name of the macro. If key=None, then a dictionary of all macros is returned. Otherwise, a tuple of the types associated with the key is returned. """ if key is None: return _copy.deepcopy(__macros) else: return tuple(__macros[key]) def __type_okay(value, argtype, allow_none=False): """Check a value against a type The difference between this function and :py:func:`isinstance` is that it allows specifying a type as a string. Furthermore, strings allow for specifying more general types, such as a simple numeric type (i.e. ``argtype``="num"). Args: value (any): the value to check argtype (type, str): the type to check for allow_none (bool): whether or not to allow None as a valid value Returns: bool: True if value is a valid instance of argtype """ if value is None: return allow_none if isinstance(argtype, str): if argtype in __macros: return __type_okay(value, __macros[argtype], allow_none=allow_none) elif argtype == 'uint': return __is_uint(value) elif argtype == 'int': return __is_int(value) elif argtype == 'float': return __is_float(value) elif argtype == 'bool': return __is_bool(value) return argtype in [cls.__name__ for cls in value.__class__.__mro__] elif isinstance(argtype, type): if argtype is int: return __is_int(value) elif argtype is float: return __is_float(value) elif argtype is bool: return __is_bool(value) return isinstance(value, argtype) elif isinstance(argtype, tuple) or isinstance(argtype, list): return any(__type_okay(value, i) for i in argtype) else: # argtype is None return True def __shape_okay_multi(value, argshape): if type(argshape[0]) in (tuple, list): # if multiple shapes are present return any(__shape_okay(value, a) for a in argshape) else: return __shape_okay(value, argshape) def __shape_okay(value, argshape): valshape = get_data_shape(value) if not len(valshape) == len(argshape): return False for a, b in zip(valshape, argshape): if b not in (a, None): return False return True def __is_uint(value): return isinstance(value, __supported_uint_types) def __is_int(value): return isinstance(value, __supported_int_types) def __is_float(value): return isinstance(value, __supported_float_types) def __is_bool(value): return isinstance(value, __supported_bool_types) def __format_type(argtype): if isinstance(argtype, str): return argtype elif isinstance(argtype, type): return argtype.__name__ elif isinstance(argtype, tuple) or isinstance(argtype, list): types = [__format_type(i) for i in argtype] if len(types) > 1: return "%s or %s" % (", ".join(types[:-1]), types[-1]) else: return types[0] elif argtype is None: return "any type" else: raise ValueError("argtype must be a type, str, list, or tuple") def __check_enum(argval, arg): """ Helper function to check whether the given argument value validates against the enum specification. 
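For example (the argument name ``mode`` is illustrative)::

    arg = {'name': 'mode', 'type': str, 'enum': ('r', 'w')}
    __check_enum('a', arg)  # "forbidden value for 'mode' (got 'a', expected ('r', 'w'))"
    __check_enum('r', arg)  # None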
:param argval: argument value passed to the function/method :param arg: argument validator - the specification dictionary for this argument :return: None if the value validates successfully, error message if the value does not. """ if argval not in arg['enum']: return "forbidden value for '{}' (got {}, expected {})".format(arg['name'], __fmt_str_quotes(argval), arg['enum']) def __fmt_str_quotes(x): """Return a string or list of strings where the input string or list of strings have single quotes around strings""" if isinstance(x, (list, tuple)): return '{}'.format(x) if isinstance(x, str): return "'%s'" % x return str(x) def __parse_args(validator, args, kwargs, enforce_type=True, enforce_shape=True, allow_extra=False, # noqa: C901 allow_positional=AllowPositional.ALLOWED): """ Internal helper function used by the docval decorator to parse and validate function arguments :param validator: List of dicts from docval with the description of the arguments :param args: List of the values of positional arguments supplied by the caller :param kwargs: Dict keyword arguments supplied by the caller where keys are the argument name and values are the argument value. :param enforce_type: Boolean indicating whether the type of arguments should be enforced :param enforce_shape: Boolean indicating whether the dimensions of array arguments should be enforced if possible. :param allow_extra: Boolean indicating whether extra keyword arguments are allowed (if False and extra keyword arguments are specified, then an error is raised). :param allow_positional: integer code indicating whether positional arguments are allowed: AllowPositional.ALLOWED: positional arguments are allowed AllowPositional.WARNING: return warning if positional arguments are supplied AllowPositional.ERROR: return error if positional arguments are supplied :return: Dict with: * 'args' : Dict all arguments where keys are the names and values are the values of the arguments. * 'errors' : List of string with error messages """ ret = dict() syntax_errors = list() type_errors = list() value_errors = list() future_warnings = list() argsi = 0 extras = dict() # has to be initialized to empty here, to avoid spurious errors reported upon early raises try: # check for duplicates in docval names = [x['name'] for x in validator] duplicated = [item for item, count in collections.Counter(names).items() if count > 1] if duplicated: raise ValueError( 'The following names are duplicated: {}'.format(duplicated)) if allow_extra: # extra keyword arguments are allowed so do not consider them when checking number of args if len(args) > len(validator): raise TypeError( 'Expected at most %d arguments %r, got %d positional' % (len(validator), names, len(args)) ) else: # allow for keyword args if len(args) + len(kwargs) > len(validator): raise TypeError( 'Expected at most %d arguments %r, got %d: %d positional and %d keyword %s' % (len(validator), names, len(args) + len(kwargs), len(args), len(kwargs), sorted(kwargs)) ) if args: if allow_positional == AllowPositional.WARNING: msg = 'Positional arguments are discouraged and may be forbidden in a future release.' future_warnings.append(msg) elif allow_positional == AllowPositional.ERROR: msg = 'Only keyword arguments (e.g., func(argname=value, ...)) are allowed.' 
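# note: the message is only collected here; _check_args in the docval wrapper raises it later as a SyntaxError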
syntax_errors.append(msg) # iterate through the docval specification and find a matching value in args / kwargs it = iter(validator) arg = next(it) # process positional arguments of the docval specification (no default value) extras = dict(kwargs) while True: if 'default' in arg: break argname = arg['name'] argval_set = False if argname in kwargs: # if this positional arg is specified by a keyword arg and there are remaining positional args that # have not yet been matched, then it is undetermined what those positional args match to. thus, raise # an error if argsi < len(args): type_errors.append("got multiple values for argument '%s'" % argname) argval = kwargs.get(argname) extras.pop(argname, None) argval_set = True elif argsi < len(args): argval = args[argsi] argval_set = True if not argval_set: type_errors.append("missing argument '%s'" % argname) else: if enforce_type: if not __type_okay(argval, arg['type']): if argval is None: fmt_val = (argname, __format_type(arg['type'])) type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val) else: fmt_val = (argname, type(argval).__name__, __format_type(arg['type'])) type_errors.append("incorrect type for '%s' (got '%s', expected '%s')" % fmt_val) if enforce_shape and 'shape' in arg: valshape = get_data_shape(argval) while valshape is None: if argval is None: break if not hasattr(argval, argname): fmt_val = (argval, argname, arg['shape']) value_errors.append("cannot check shape of object '%s' for argument '%s' " "(expected shape '%s')" % fmt_val) break # unpack, e.g. if TimeSeries is passed for arg 'data', then TimeSeries.data is checked argval = getattr(argval, argname) valshape = get_data_shape(argval) if valshape is not None and not __shape_okay_multi(argval, arg['shape']): fmt_val = (argname, valshape, arg['shape']) value_errors.append("incorrect shape for '%s' (got '%s', expected '%s')" % fmt_val) if 'enum' in arg: err = __check_enum(argval, arg) if err: value_errors.append(err) ret[argname] = argval argsi += 1 arg = next(it) # process arguments of the docval specification with a default value while True: argname = arg['name'] if argname in kwargs: ret[argname] = kwargs.get(argname) extras.pop(argname, None) elif len(args) > argsi: ret[argname] = args[argsi] argsi += 1 else: ret[argname] = _copy.deepcopy(arg['default']) argval = ret[argname] if enforce_type: if not __type_okay(argval, arg['type'], arg['default'] is None): if argval is None and arg['default'] is None: fmt_val = (argname, __format_type(arg['type'])) type_errors.append("None is not allowed for '%s' (expected '%s', not None)" % fmt_val) else: fmt_val = (argname, type(argval).__name__, __format_type(arg['type'])) type_errors.append("incorrect type for '%s' (got '%s', expected '%s')" % fmt_val) if enforce_shape and 'shape' in arg and argval is not None: valshape = get_data_shape(argval) while valshape is None: if argval is None: break if not hasattr(argval, argname): fmt_val = (argval, argname, arg['shape']) value_errors.append("cannot check shape of object '%s' for argument '%s' (expected shape '%s')" % fmt_val) break # unpack, e.g. 
if TimeSeries is passed for arg 'data', then TimeSeries.data is checked argval = getattr(argval, argname) valshape = get_data_shape(argval) if valshape is not None and not __shape_okay_multi(argval, arg['shape']): fmt_val = (argname, valshape, arg['shape']) value_errors.append("incorrect shape for '%s' (got '%s', expected '%s')" % fmt_val) if 'enum' in arg and argval is not None: err = __check_enum(argval, arg) if err: value_errors.append(err) arg = next(it) except StopIteration: pass except TypeError as e: type_errors.append(str(e)) except ValueError as e: value_errors.append(str(e)) if not allow_extra: for key in extras.keys(): type_errors.append("unrecognized argument: '%s'" % key) else: # TODO: Extras get stripped out if function arguments are composed with fmt_docval_args. # allow_extra needs to be tracked on a function so that fmt_docval_args doesn't strip them out for key in extras.keys(): ret[key] = extras[key] return {'args': ret, 'future_warnings': future_warnings, 'type_errors': type_errors, 'value_errors': value_errors, 'syntax_errors': syntax_errors} docval_idx_name = '__dv_idx__' docval_attr_name = '__docval__' __docval_args_loc = 'args' def get_docval(func, *args): '''Get a copy of docval arguments for a function. If args are supplied, return only docval arguments with value for 'name' key equal to the args ''' func_docval = getattr(func, docval_attr_name, None) if func_docval: if args: docval_idx = getattr(func, docval_idx_name, None) try: return tuple(docval_idx[name] for name in args) except KeyError as ke: raise ValueError('Function %s does not have docval argument %s' % (func.__name__, str(ke))) return tuple(func_docval[__docval_args_loc]) else: if args: raise ValueError('Function %s has no docval arguments' % func.__name__) return tuple() # def docval_wrap(func, is_method=True): # if is_method: # @docval(*get_docval(func)) # def method(self, **kwargs): # # return call_docval_args(func, kwargs) # return method # else: # @docval(*get_docval(func)) # def static_method(**kwargs): # return call_docval_args(func, kwargs) # return method def fmt_docval_args(func, kwargs): ''' Separate positional and keyword arguments Useful for methods that wrap other methods ''' func_docval = getattr(func, docval_attr_name, None) ret_args = list() ret_kwargs = dict() kwargs_copy = _copy.copy(kwargs) if func_docval: for arg in func_docval[__docval_args_loc]: val = kwargs_copy.pop(arg['name'], None) if 'default' in arg: if val is not None: ret_kwargs[arg['name']] = val else: ret_args.append(val) if func_docval['allow_extra']: ret_kwargs.update(kwargs_copy) else: raise ValueError('no docval found on %s' % str(func)) return ret_args, ret_kwargs def call_docval_func(func, kwargs): fargs, fkwargs = fmt_docval_args(func, kwargs) return func(*fargs, **fkwargs) def __resolve_type(t): if t is None: return t if isinstance(t, str): if t in __macros: return tuple(__macros[t]) else: return t elif isinstance(t, type): return t elif isinstance(t, (list, tuple)): ret = list() for i in t: resolved = __resolve_type(i) if isinstance(resolved, tuple): ret.extend(resolved) else: ret.append(resolved) return tuple(ret) else: msg = "argtype must be a type, a str, a list, a tuple, or None - got %s" % type(t) raise ValueError(msg) def __check_enum_argtype(argtype): """Return True/False whether the given argtype or list/tuple of argtypes is a supported docval enum type""" if isinstance(argtype, (list, tuple)): return all(x in __allowed_enum_types for x in argtype) return argtype in __allowed_enum_types def 
docval(*validator, **options): # noqa: C901 '''A decorator for documenting and enforcing type for instance method arguments. This decorator takes a list of dictionaries that specify the method parameters. These dictionaries are used for enforcing type and building a Sphinx docstring. The first arguments are dictionaries that specify the positional arguments and keyword arguments of the decorated function. These dictionaries must contain the following keys: ``'name'``, ``'type'``, and ``'doc'``. This will define a positional argument. To define a keyword argument, specify a default value using the key ``'default'``. To validate the dimensions of an input array add the optional ``'shape'`` parameter. The decorated method must take ``self`` and ``**kwargs`` as arguments. When using this decorator, the functions :py:func:`getargs` and :py:func:`popargs` can be used for easily extracting arguments from kwargs. The following code example demonstrates the use of this decorator: .. code-block:: python @docval({'name': 'arg1':, 'type': str, 'doc': 'this is the first positional argument'}, {'name': 'arg2':, 'type': int, 'doc': 'this is the second positional argument'}, {'name': 'kwarg1':, 'type': (list, tuple), 'doc': 'this is a keyword argument', 'default': list()}, returns='foo object', rtype='Foo')) def foo(self, **kwargs): arg1, arg2, kwarg1 = getargs('arg1', 'arg2', 'kwarg1', **kwargs) ... :param enforce_type: Enforce types of input parameters (Default=True) :param returns: String describing the return values :param rtype: String describing the data type of the return values :param is_method: True if this is decorating an instance or class method, False otherwise (Default=True) :param enforce_shape: Enforce the dimensions of input arrays (Default=True) :param validator: :py:func:`dict` objects specifying the method parameters :param allow_extra: Allow extra arguments (Default=False) :param allow_positional: Allow positional arguments (Default=True) :param options: additional options for documenting and validating method parameters ''' enforce_type = options.pop('enforce_type', True) enforce_shape = options.pop('enforce_shape', True) returns = options.pop('returns', None) rtype = options.pop('rtype', None) is_method = options.pop('is_method', True) allow_extra = options.pop('allow_extra', False) allow_positional = options.pop('allow_positional', True) def dec(func): _docval = _copy.copy(options) _docval['allow_extra'] = allow_extra _docval['allow_positional'] = allow_positional func.__name__ = _docval.get('func_name', func.__name__) func.__doc__ = _docval.get('doc', func.__doc__) pos = list() kw = list() for a in validator: # catch unsupported keys allowable_terms = ('name', 'doc', 'type', 'shape', 'enum', 'default', 'help') unsupported_terms = set(a.keys()) - set(allowable_terms) if unsupported_terms: raise Exception('docval for {}: keys {} are not supported by docval'.format(a['name'], sorted(unsupported_terms))) # check that arg type is valid try: a['type'] = __resolve_type(a['type']) except Exception as e: msg = "docval for %s: error parsing argument type: %s" % (a['name'], e.args[0]) raise Exception(msg) if 'enum' in a: # check that value for enum key is a list or tuple (cannot have only one allowed value) if not isinstance(a['enum'], (list, tuple)): msg = ('docval for %s: enum value must be a list or tuple (received %s)' % (a['name'], type(a['enum']))) raise Exception(msg) # check that arg type is compatible with enum if not __check_enum_argtype(a['type']): msg = 'docval for {}: enum 
checking cannot be used with arg type {}'.format(a['name'], a['type']) raise Exception(msg) # check that enum allowed values are allowed by arg type if any([not __type_okay(x, a['type']) for x in a['enum']]): msg = ('docval for {}: enum values are of types not allowed by arg type (got {}, ' 'expected {})'.format(a['name'], [type(x) for x in a['enum']], a['type'])) raise Exception(msg) if 'default' in a: kw.append(a) else: pos.append(a) loc_val = pos + kw _docval[__docval_args_loc] = loc_val def _check_args(args, kwargs): """Parse and check arguments to decorated function. Raise warnings and errors as appropriate.""" # this function was separated from func_call() in order to make stepping through lines of code using pdb # easier parsed = __parse_args( loc_val, args[1:] if is_method else args, kwargs, enforce_type=enforce_type, enforce_shape=enforce_shape, allow_extra=allow_extra, allow_positional=allow_positional ) parse_warnings = parsed.get('future_warnings') if parse_warnings: msg = '%s: %s' % (func.__qualname__, ', '.join(parse_warnings)) warnings.warn(msg, FutureWarning) for error_type, ExceptionType in (('type_errors', TypeError), ('value_errors', ValueError), ('syntax_errors', SyntaxError)): parse_err = parsed.get(error_type) if parse_err: msg = '%s: %s' % (func.__qualname__, ', '.join(parse_err)) raise ExceptionType(msg) return parsed['args'] # this code is intentionally separated to make stepping through lines of code using pdb easier if is_method: def func_call(*args, **kwargs): pargs = _check_args(args, kwargs) return func(args[0], **pargs) else: def func_call(*args, **kwargs): pargs = _check_args(args, kwargs) return func(**pargs) _rtype = rtype if isinstance(rtype, type): _rtype = rtype.__name__ docstring = __googledoc(func, _docval[__docval_args_loc], returns=returns, rtype=_rtype) docval_idx = {a['name']: a for a in _docval[__docval_args_loc]} # cache a name-indexed dictionary of args setattr(func_call, '__doc__', docstring) setattr(func_call, '__name__', func.__name__) setattr(func_call, docval_attr_name, _docval) setattr(func_call, docval_idx_name, docval_idx) setattr(func_call, '__module__', func.__module__) return func_call return dec def __sig_arg(argval): if 'default' in argval: default = argval['default'] if isinstance(default, str): default = "'%s'" % default else: default = str(default) return "%s=%s" % (argval['name'], default) else: return argval['name'] def __builddoc(func, validator, docstring_fmt, arg_fmt, ret_fmt=None, returns=None, rtype=None): '''Generate a Spinxy docstring''' def to_str(argtype): if isinstance(argtype, type): module = argtype.__module__ name = argtype.__name__ if module.startswith("h5py") or module.startswith("pandas") or module.startswith("builtins"): return ":py:class:`~{name}`".format(name=name) else: return ":py:class:`~{module}.{name}`".format(name=name, module=module) return argtype def __sphinx_arg(arg): fmt = dict() fmt['name'] = arg.get('name') fmt['doc'] = arg.get('doc') if isinstance(arg['type'], tuple) or isinstance(arg['type'], list): fmt['type'] = " or ".join(map(to_str, arg['type'])) else: fmt['type'] = to_str(arg['type']) return arg_fmt.format(**fmt) sig = "%s(%s)\n\n" % (func.__name__, ", ".join(map(__sig_arg, validator))) desc = func.__doc__.strip() if func.__doc__ is not None else "" sig += docstring_fmt.format(description=desc, args="\n".join(map(__sphinx_arg, validator))) if not (ret_fmt is None or returns is None or rtype is None): sig += ret_fmt.format(returns=returns, rtype=rtype) return sig def __sphinxdoc(func, 
validator, returns=None, rtype=None): arg_fmt = (":param {name}: {doc}\n" ":type {name}: {type}") docstring_fmt = ("{description}\n\n" "{args}\n") ret_fmt = (":returns: {returns}\n" ":rtype: {rtype}") return __builddoc(func, validator, docstring_fmt, arg_fmt, ret_fmt=ret_fmt, returns=returns, rtype=rtype) def __googledoc(func, validator, returns=None, rtype=None): arg_fmt = " {name} ({type}): {doc}" docstring_fmt = "{description}\n\n" if len(validator) > 0: docstring_fmt += "Args:\n{args}\n" ret_fmt = ("\nReturns:\n" " {rtype}: {returns}") return __builddoc(func, validator, docstring_fmt, arg_fmt, ret_fmt=ret_fmt, returns=returns, rtype=rtype) def getargs(*argnames): """getargs(*argnames, argdict) Convenience function to retrieve arguments from a dictionary in batch. The last argument should be a dictionary, and the other arguments should be the keys (argument names) for which to retrieve the values. :raises ValueError: if a argument name is not found in the dictionary or there is only one argument passed to this function or the last argument is not a dictionary :return: a single value if there is only one argument, or a list of values corresponding to the given argument names """ if len(argnames) < 2: raise ValueError('Must supply at least one key and a dict') if not isinstance(argnames[-1], dict): raise ValueError('Last argument must be a dict') kwargs = argnames[-1] if len(argnames) == 2: if argnames[0] not in kwargs: raise ValueError("Argument not found in dict: '%s'" % argnames[0]) return kwargs.get(argnames[0]) ret = [] for arg in argnames[:-1]: if arg not in kwargs: raise ValueError("Argument not found in dict: '%s'" % arg) ret.append(kwargs.get(arg)) return ret def popargs(*argnames): """popargs(*argnames, argdict) Convenience function to retrieve and remove arguments from a dictionary in batch. The last argument should be a dictionary, and the other arguments should be the keys (argument names) for which to retrieve the values. :raises ValueError: if a argument name is not found in the dictionary or there is only one argument passed to this function or the last argument is not a dictionary :return: a single value if there is only one argument, or a list of values corresponding to the given argument names """ if len(argnames) < 2: raise ValueError('Must supply at least one key and a dict') if not isinstance(argnames[-1], dict): raise ValueError('Last argument must be a dict') kwargs = argnames[-1] if len(argnames) == 2: try: ret = kwargs.pop(argnames[0]) except KeyError as ke: raise ValueError('Argument not found in dict: %s' % str(ke)) return ret try: ret = [kwargs.pop(arg) for arg in argnames[:-1]] except KeyError as ke: raise ValueError('Argument not found in dict: %s' % str(ke)) return ret class ExtenderMeta(ABCMeta): """A metaclass that will extend the base class initialization routine by executing additional functions defined in classes that use this metaclass In general, this class should only be used by core developers. """ __preinit = '__preinit' @classmethod def pre_init(cls, func): setattr(func, cls.__preinit, True) return classmethod(func) __postinit = '__postinit' @classmethod def post_init(cls, func): '''A decorator for defining a routine to run after creation of a type object. An example use of this method would be to define a classmethod that gathers any defined methods or attributes after the base Python type construction (i.e. 
after :py:func:`type` has been called) ''' setattr(func, cls.__postinit, True) return classmethod(func) def __init__(cls, name, bases, classdict): it = (getattr(cls, n) for n in dir(cls)) it = (a for a in it if hasattr(a, cls.__preinit)) for func in it: func(name, bases, classdict) super().__init__(name, bases, classdict) it = (getattr(cls, n) for n in dir(cls)) it = (a for a in it if hasattr(a, cls.__postinit)) for func in it: func(name, bases, classdict) def get_data_shape(data, strict_no_data_load=False): """ Helper function used to determine the shape of the given array. In order to determine the shape of nested tuples, lists, and sets, this function recursively inspects elements along the dimensions, assuming that the data has a regular, rectangular shape. In the case of out-of-core iterators, this means that the first item along each dimension would potentially be loaded into memory. Set strict_no_data_load=True to enforce that this does not happen, at the cost that we may not be able to determine the shape of the array. :param data: Array for which we should determine the shape. :type data: List, numpy.ndarray, DataChunkIterator, any object that support __len__ or .shape. :param strict_no_data_load: If True and data is an out-of-core iterator, None may be returned. If False (default), the first element of data may be loaded into memory. :return: Tuple of ints indicating the size of known dimensions. Dimensions for which the size is unknown will be set to None. """ def __get_shape_helper(local_data): shape = list() if hasattr(local_data, '__len__'): shape.append(len(local_data)) if len(local_data): el = next(iter(local_data)) if not isinstance(el, (str, bytes)): shape.extend(__get_shape_helper(el)) return tuple(shape) # NOTE: data.maxshape will fail on empty h5py.Dataset without shape or maxshape. this will be fixed in h5py 3.0 if hasattr(data, 'maxshape'): return data.maxshape if hasattr(data, 'shape'): return data.shape if isinstance(data, dict): return None if hasattr(data, '__len__') and not isinstance(data, (str, bytes)): if not strict_no_data_load or isinstance(data, (list, tuple, set)): return __get_shape_helper(data) return None def pystr(s): """ Convert a string of characters to Python str object """ if isinstance(s, bytes): return s.decode('utf-8') else: return s def to_uint_array(arr): """ Convert a numpy array or array-like object to a numpy array of unsigned integers with the same dtype itemsize. For example, a list of int32 values is converted to a numpy array with dtype uint32. :raises ValueError: if input array contains values that are not unsigned integers or non-negative integers. """ if not isinstance(arr, np.ndarray): arr = np.array(arr) if np.issubdtype(arr.dtype, np.unsignedinteger): return arr if np.issubdtype(arr.dtype, np.integer): if (arr < 0).any(): raise ValueError('Cannot convert negative integer values to uint.') dt = np.dtype('uint' + str(int(arr.dtype.itemsize*8))) # keep precision return arr.astype(dt) raise ValueError('Cannot convert array of dtype %s to uint.' % arr.dtype) class LabelledDict(dict): """A dict wrapper that allows querying by an attribute of the values and running a callable on removed items. For example, if the key attribute is set as 'name' in __init__, then all objects added to the LabelledDict must have a 'name' attribute and a particular object in the LabelledDict can be accessed using the syntax ['object_name'] if the object.name == 'object_name'. 
In this way, LabelledDict acts like a set where values can be retrieved using square brackets around the value of the key attribute. An 'add' method makes clear the association between the key attribute of the LabelledDict and the values of the LabelledDict. LabelledDict also supports retrieval of values with the syntax my_dict['attr == val'], which returns a set of objects in the LabelledDict which have an attribute 'attr' with a string value 'val'. If no objects match that condition, a KeyError is raised. Note that if 'attr' equals the key attribute, then the single matching value is returned, not a set. LabelledDict does not support changing items that have already been set. A TypeError will be raised when using __setitem__ on keys that already exist in the dict. The setdefault and update methods are not supported. A TypeError will be raised when these are called. A callable function may be passed to the constructor to be run on an item after adding it to this dict using the __setitem__ and add methods. A callable function may be passed to the constructor to be run on an item after removing it from this dict using the __delitem__ (the del operator), pop, and popitem methods. It will also be run on each removed item when using the clear method. Usage: LabelledDict(label='my_objects', key_attr='name') my_dict[obj.name] = obj my_dict.add(obj) # simpler syntax Example: # MyTestClass is a class with attributes 'prop1' and 'prop2'. MyTestClass.__init__ sets those attributes. ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') obj2 = MyTestClass('d', 'b') ld[obj1.prop1] = obj1 # obj1 is added to the LabelledDict with the key obj1.prop1. Any other key is not allowed. ld.add(obj2) # Simpler 'add' syntax enforces the required relationship ld['a'] # Returns obj1 ld['prop1 == a'] # Also returns obj1 ld['prop2 == b'] # Returns set([obj1, obj2]) - the set of all values v in ld where v.prop2 == 'b' """ @docval({'name': 'label', 'type': str, 'doc': 'the label on this dictionary'}, {'name': 'key_attr', 'type': str, 'doc': 'the attribute name to use as the key', 'default': 'name'}, {'name': 'add_callable', 'type': types.FunctionType, 'doc': 'function to call on an element after adding it to this dict using the add or __setitem__ methods', 'default': None}, {'name': 'remove_callable', 'type': types.FunctionType, 'doc': ('function to call on an element after removing it from this dict using the pop, popitem, clear, ' 'or __delitem__ methods'), 'default': None}) def __init__(self, **kwargs): label, key_attr, add_callable, remove_callable = getargs('label', 'key_attr', 'add_callable', 'remove_callable', kwargs) self.__label = label self.__key_attr = key_attr self.__add_callable = add_callable self.__remove_callable = remove_callable @property def label(self): """Return the label of this LabelledDict""" return self.__label @property def key_attr(self): """Return the attribute used as the key for values in this LabelledDict""" return self.__key_attr def __getitem__(self, args): """Get a value from the LabelledDict with the given key. Supports syntax my_dict['attr == val'], which returns a set of objects in the LabelledDict which have an attribute 'attr' with a string value 'val'. If no objects match that condition, an empty set is returned. Note that if 'attr' equals the key attribute of this LabelledDict, then the single matching value is returned, not a set. 
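For example, reusing obj1 (prop1='a', prop2='b') and obj2 (prop1='d', prop2='b') from the class docstring above::

    ld['prop2 == b']    # returns {obj1, obj2}
    ld['prop2 == zzz']  # no matches: returns an empty set
    ld['prop1 == a']    # 'prop1' is the key attribute, so the single value obj1 is returned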
""" key = args if '==' in args: key, val = args.split("==") key = key.strip() val = val.strip() # val is a string if not key: raise ValueError("An attribute name is required before '=='.") if not val: raise ValueError("A value is required after '=='.") if key != self.key_attr: ret = set() for item in self.values(): if getattr(item, key, None) == val: ret.add(item) return ret else: return super().__getitem__(val) else: return super().__getitem__(key) def __setitem__(self, key, value): """Set a value in the LabelledDict with the given key. The key must equal value.key_attr. See LabelledDict.add for a simpler syntax since the key is redundant. Raises TypeError is key already exists. Raises ValueError if value does not have attribute key_attr. """ if key in self: raise TypeError("Key '%s' is already in this dict. Cannot reset items in a %s." % (key, self.__class__.__name__)) self.__check_value(value) if key != getattr(value, self.key_attr): raise KeyError("Key '%s' must equal attribute '%s' of '%s'." % (key, self.key_attr, value)) super().__setitem__(key, value) if self.__add_callable: self.__add_callable(value) def add(self, value): """Add a value to the dict with the key value.key_attr. Raises ValueError if value does not have attribute key_attr. """ self.__check_value(value) self.__setitem__(getattr(value, self.key_attr), value) def __check_value(self, value): if not hasattr(value, self.key_attr): raise ValueError("Cannot set value '%s' in %s. Value must have attribute '%s'." % (value, self.__class__.__name__, self.key_attr)) def pop(self, k): """Remove an item that matches the key. If remove_callable was initialized, call that on the returned value.""" ret = super().pop(k) if self.__remove_callable: self.__remove_callable(ret) return ret def popitem(self): """Remove the last added item. If remove_callable was initialized, call that on the returned value. Note: popitem returns a tuple (key, value) but the remove_callable will be called only on the value. Note: in Python 3.5 and earlier, dictionaries are not ordered, so popitem removes an arbitrary item. """ ret = super().popitem() if self.__remove_callable: self.__remove_callable(ret[1]) # execute callable only on dict value return ret def clear(self): """Remove all items. If remove_callable was initialized, call that on each returned value. The order of removal depends on the popitem method. """ while len(self): self.popitem() def __delitem__(self, k): """Remove an item that matches the key. If remove_callable was initialized, call that on the matching value.""" item = self[k] super().__delitem__(k) if self.__remove_callable: self.__remove_callable(item) def setdefault(self, k): """setdefault is not supported. A TypeError will be raised.""" raise TypeError('setdefault is not supported for %s' % self.__class__.__name__) def update(self, other): """update is not supported. 
A TypeError will be raised.""" raise TypeError('update is not supported for %s' % self.__class__.__name__) @docval_macro('array_data') class StrDataset(h5py.Dataset): """Wrapper to decode strings on reading the dataset""" def __init__(self, dset, encoding, errors='strict'): self.dset = dset if encoding is None: encoding = h5py.h5t.check_string_dtype(dset.dtype).encoding self.encoding = encoding self.errors = errors def __getattr__(self, name): return getattr(self.dset, name) def __repr__(self): return '' % repr(self.dset)[1:-1] def __len__(self): return len(self.dset) def __getitem__(self, args): bytes_arr = self.dset[args] # numpy.char.decode() seems like the obvious thing to use. But it only # accepts numpy string arrays, not object arrays of bytes (which we # return from HDF5 variable-length strings). And the numpy # implementation is not faster than doing it with a loop; in fact, by # not converting the result to a numpy unicode array, the # naive way can be faster! (Comparing with numpy 1.18.4, June 2020) if np.isscalar(bytes_arr): return bytes_arr.decode(self.encoding, self.errors) return np.array([ b.decode(self.encoding, self.errors) for b in bytes_arr.flat ], dtype=object).reshape(bytes_arr.shape) ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1627603655.180627 hdmf-3.1.1/src/hdmf/validate/0000755000655200065520000000000000000000000016057 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/validate/__init__.py0000644000655200065520000000023600000000000020171 0ustar00circlecicirclecifrom . import errors from .errors import * # noqa: F403 from .validator import ValidatorMap, Validator, AttributeValidator, DatasetValidator, GroupValidator ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/validate/errors.py0000644000655200065520000002130300000000000017744 0ustar00circlecicirclecifrom numpy import dtype from ..spec.spec import DtypeHelper from ..utils import docval, getargs __all__ = [ "Error", "DtypeError", "MissingError", "ExpectedArrayError", "ShapeError", "MissingDataType", "IllegalLinkError", "IncorrectDataType", "IncorrectQuantityError" ] class Error: @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, {'name': 'reason', 'type': str, 'doc': 'the reason for the error'}, {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) def __init__(self, **kwargs): self.__name = getargs('name', kwargs) self.__reason = getargs('reason', kwargs) self.__location = getargs('location', kwargs) @property def name(self): return self.__name @property def reason(self): return self.__reason @property def location(self): return self.__location @location.setter def location(self, loc): self.__location = loc def __str__(self): return self.__format_str(self.name, self.location, self.reason) @staticmethod def __format_str(name, location, reason): if location is not None: return "%s (%s): %s" % (name, location, reason) else: return "%s: %s" % (name, reason) def __repr__(self): return self.__str__() def __hash__(self): """Returns the hash value of this Error Note: if the location property is set after creation, the hash value will change. Therefore, it is important to finalize the value of location before getting the hash value. 
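For example (the names and location are illustrative)::

    e1 = Error('MyType/attr1', 'argument missing', location='/root/thing')
    e2 = Error('MyType/attr1', 'argument missing', location='/root/thing')
    e1 == e2       # True -- equality is defined in terms of this hash
    len({e1, e2})  # 1  -- duplicate errors collapse when collected in a set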
""" return hash(self.__equatable_str()) def __equatable_str(self): """A string representation of the error which can be used to check for equality For a single error, name can end up being different depending on whether it is generated from a base data type spec or from an inner type definition. These errors should still be considered equal because they are caused by the same problem. When a location is provided, we only consider the name of the field and drop the rest of the spec name. However, when a location is not available, then we need to use the fully-provided name. """ if self.location is not None: equatable_name = self.name.split('/')[-1] else: equatable_name = self.name return self.__format_str(equatable_name, self.location, self.reason) def __eq__(self, other): return hash(self) == hash(other) class DtypeError(Error): @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, {'name': 'expected', 'type': (dtype, type, str, list), 'doc': 'the expected dtype'}, {'name': 'received', 'type': (dtype, type, str, list), 'doc': 'the received dtype'}, {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) def __init__(self, **kwargs): name = getargs('name', kwargs) expected = getargs('expected', kwargs) received = getargs('received', kwargs) if isinstance(expected, list): expected = DtypeHelper.simplify_cpd_type(expected) reason = "incorrect type - expected '%s', got '%s'" % (expected, received) loc = getargs('location', kwargs) super().__init__(name, reason, location=loc) class MissingError(Error): @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) def __init__(self, **kwargs): name = getargs('name', kwargs) reason = "argument missing" loc = getargs('location', kwargs) super().__init__(name, reason, location=loc) class MissingDataType(Error): @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, {'name': 'data_type', 'type': str, 'doc': 'the missing data type'}, {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}, {'name': 'missing_dt_name', 'type': str, 'doc': 'the name of the missing data type', 'default': None}) def __init__(self, **kwargs): name, data_type, missing_dt_name = getargs('name', 'data_type', 'missing_dt_name', kwargs) self.__data_type = data_type if missing_dt_name is not None: reason = "missing data type %s (%s)" % (self.__data_type, missing_dt_name) else: reason = "missing data type %s" % self.__data_type loc = getargs('location', kwargs) super().__init__(name, reason, location=loc) @property def data_type(self): return self.__data_type class IncorrectQuantityError(Error): """A validation error indicating that a child group/dataset/link has the incorrect quantity of matching elements""" @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, {'name': 'data_type', 'type': str, 'doc': 'the data type which has the incorrect quantity'}, {'name': 'expected', 'type': (str, int), 'doc': 'the expected quantity'}, {'name': 'received', 'type': (str, int), 'doc': 'the received quantity'}, {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) def __init__(self, **kwargs): name, data_type, expected, received = getargs('name', 'data_type', 'expected', 'received', kwargs) reason = "expected a quantity of %s for data type %s, received %s" % 
(str(expected), data_type, str(received)) loc = getargs('location', kwargs) super().__init__(name, reason, location=loc) class ExpectedArrayError(Error): @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, {'name': 'expected', 'type': (tuple, list), 'doc': 'the expected shape'}, {'name': 'received', 'type': str, 'doc': 'the received data'}, {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) def __init__(self, **kwargs): name = getargs('name', kwargs) expected = getargs('expected', kwargs) received = getargs('received', kwargs) reason = "incorrect shape - expected an array of shape '%s', got non-array data '%s'" % (expected, received) loc = getargs('location', kwargs) super().__init__(name, reason, location=loc) class ShapeError(Error): @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, {'name': 'expected', 'type': (tuple, list), 'doc': 'the expected shape'}, {'name': 'received', 'type': (tuple, list), 'doc': 'the received shape'}, {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) def __init__(self, **kwargs): name = getargs('name', kwargs) expected = getargs('expected', kwargs) received = getargs('received', kwargs) reason = "incorrect shape - expected '%s', got '%s'" % (expected, received) loc = getargs('location', kwargs) super().__init__(name, reason, location=loc) class IllegalLinkError(Error): """ A validation error for indicating that a link was used where an actual object (i.e. a dataset or a group) must be used """ @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) def __init__(self, **kwargs): name = getargs('name', kwargs) reason = "illegal use of link (linked object will not be validated)" loc = getargs('location', kwargs) super().__init__(name, reason, location=loc) class IncorrectDataType(Error): """ A validation error for indicating that the incorrect data_type (not dtype) was used. 
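For example (the type and builder names are illustrative)::

    err = IncorrectDataType('MyTable/col1', 'VectorData', 'DynamicTable', location='/root/MyTable')
    str(err)  # "MyTable/col1 (/root/MyTable): incorrect data_type - expected 'VectorData', got 'DynamicTable'"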
""" @docval({'name': 'name', 'type': str, 'doc': 'the name of the component that is erroneous'}, {'name': 'expected', 'type': str, 'doc': 'the expected data_type'}, {'name': 'received', 'type': str, 'doc': 'the received data_type'}, {'name': 'location', 'type': str, 'doc': 'the location of the error', 'default': None}) def __init__(self, **kwargs): name = getargs('name', kwargs) expected = getargs('expected', kwargs) received = getargs('received', kwargs) reason = "incorrect data_type - expected '%s', got '%s'" % (expected, received) loc = getargs('location', kwargs) super().__init__(name, reason, location=loc) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/src/hdmf/validate/validator.py0000644000655200065520000006673600000000000020440 0ustar00circlecicircleciimport re from abc import ABCMeta, abstractmethod from copy import copy from itertools import chain from collections import defaultdict, OrderedDict import numpy as np from .errors import Error, DtypeError, MissingError, MissingDataType, ShapeError, IllegalLinkError, IncorrectDataType from .errors import ExpectedArrayError, IncorrectQuantityError from ..build import GroupBuilder, DatasetBuilder, LinkBuilder, ReferenceBuilder, RegionBuilder from ..build.builders import BaseBuilder from ..spec import Spec, AttributeSpec, GroupSpec, DatasetSpec, RefSpec, LinkSpec from ..spec import SpecNamespace from ..spec.spec import BaseStorageSpec, DtypeHelper from ..utils import docval, getargs, call_docval_func, pystr, get_data_shape from ..query import ReferenceResolver __synonyms = DtypeHelper.primary_dtype_synonyms __additional = { 'float': ['double'], 'int8': ['short', 'int', 'long'], 'short': ['int', 'long'], 'int': ['long'], 'uint8': ['uint16', 'uint32', 'uint64'], 'uint16': ['uint32', 'uint64'], 'uint32': ['uint64'], 'utf': ['ascii'] } # if the spec dtype is a key in __allowable, then all types in __allowable[key] are valid __allowable = dict() for dt, dt_syn in __synonyms.items(): allow = copy(dt_syn) if dt in __additional: for addl in __additional[dt]: allow.extend(__synonyms[addl]) for syn in dt_syn: __allowable[syn] = allow __allowable['numeric'] = set(chain.from_iterable(__allowable[k] for k in __allowable if 'int' in k or 'float' in k)) def check_type(expected, received): ''' *expected* should come from the spec *received* should come from the data ''' if isinstance(expected, list): if len(expected) > len(received): raise ValueError('compound type shorter than expected') for i, exp in enumerate(DtypeHelper.simplify_cpd_type(expected)): rec = received[i] if rec not in __allowable[exp]: return False return True else: if isinstance(received, np.dtype): if received.char == 'O': if 'vlen' in received.metadata: received = received.metadata['vlen'] else: raise ValueError("Unrecognized type: '%s'" % received) received = 'utf' if received is str else 'ascii' elif received.char == 'U': received = 'utf' elif received.char == 'S': received = 'ascii' else: received = received.name elif isinstance(received, type): received = received.__name__ if isinstance(expected, RefSpec): expected = expected.reftype elif isinstance(expected, type): expected = expected.__name__ return received in __allowable[expected] def get_iso8601_regex(): isodate_re = (r'^(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):' r'([0-5][0-9]):([0-5][0-9])(\.[0-9]+)?(Z|[+-](?:2[0-3]|[01][0-9]):[0-5][0-9])?$') return re.compile(isodate_re) _iso_re = get_iso8601_regex() def 
_check_isodatetime(s, default=None): try: if _iso_re.match(pystr(s)) is not None: return 'isodatetime' except Exception: pass return default class EmptyArrayError(Exception): pass def get_type(data): if isinstance(data, str): return _check_isodatetime(data, 'utf') elif isinstance(data, bytes): return _check_isodatetime(data, 'ascii') elif isinstance(data, RegionBuilder): return 'region' elif isinstance(data, ReferenceBuilder): return 'object' elif isinstance(data, ReferenceResolver): return data.dtype elif isinstance(data, np.ndarray): if data.size == 0: raise EmptyArrayError() return get_type(data[0]) elif isinstance(data, np.bool_): return 'bool' if not hasattr(data, '__len__'): return type(data).__name__ else: if hasattr(data, 'dtype'): if isinstance(data.dtype, list): return [get_type(data[0][i]) for i in range(len(data.dtype))] if data.dtype.metadata is not None and data.dtype.metadata.get('vlen') is not None: return get_type(data[0]) return data.dtype if len(data) == 0: raise EmptyArrayError() return get_type(data[0]) def check_shape(expected, received): ret = False if expected is None: ret = True else: if isinstance(expected, (list, tuple)): if isinstance(expected[0], (list, tuple)): for sub in expected: if check_shape(sub, received): ret = True break else: if len(expected) > 0 and received is None: ret = False elif len(expected) == len(received): ret = True for e, r in zip(expected, received): if not check_shape(e, r): ret = False break elif isinstance(expected, int): ret = expected == received return ret class ValidatorMap: """A class for keeping track of Validator objects for all data types in a namespace""" @docval({'name': 'namespace', 'type': SpecNamespace, 'doc': 'the namespace to builder map for'}) def __init__(self, **kwargs): ns = getargs('namespace', kwargs) self.__ns = ns tree = defaultdict(list) types = ns.get_registered_types() self.__type_key = ns.get_spec(types[0]).type_key() for dt in types: spec = ns.get_spec(dt) parent = spec.data_type_inc child = spec.data_type_def tree[child] = list() if parent is not None: tree[parent].append(child) for t in tree: self.__rec(tree, t) self.__valid_types = dict() self.__validators = dict() for dt, children in tree.items(): _list = list() for t in children: spec = self.__ns.get_spec(t) if isinstance(spec, GroupSpec): val = GroupValidator(spec, self) else: val = DatasetValidator(spec, self) if t == dt: self.__validators[t] = val _list.append(val) self.__valid_types[dt] = tuple(_list) def __rec(self, tree, node): if not isinstance(tree[node], tuple): sub_types = {node} for child in tree[node]: sub_types.update(self.__rec(tree, child)) tree[node] = tuple(sub_types) return tree[node] @property def namespace(self): return self.__ns @docval({'name': 'spec', 'type': (Spec, str), 'doc': 'the specification to use to validate'}, returns='all valid sub data types for the given spec', rtype=tuple) def valid_types(self, **kwargs): '''Get all valid types for a given data type''' spec = getargs('spec', kwargs) if isinstance(spec, Spec): spec = spec.data_type_def try: return self.__valid_types[spec] except KeyError: raise ValueError("no children for '%s'" % spec) @docval({'name': 'data_type', 'type': (BaseStorageSpec, str), 'doc': 'the data type to get the validator for'}, returns='the validator ``data_type``') def get_validator(self, **kwargs): """Return the validator for a given data type""" dt = getargs('data_type', kwargs) if isinstance(dt, BaseStorageSpec): dt_tmp = dt.data_type_def if dt_tmp is None: dt_tmp = dt.data_type_inc dt = dt_tmp 
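# at this point dt is the data type name string (data_type_def, or data_type_inc as a fallback), which indexes the validator registry below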
try: return self.__validators[dt] except KeyError: msg = "data type '%s' not found in namespace %s" % (dt, self.__ns.name) raise ValueError(msg) @docval({'name': 'builder', 'type': BaseBuilder, 'doc': 'the builder to validate'}, returns="a list of errors found", rtype=list) def validate(self, **kwargs): """Validate a builder against a Spec ``builder`` must have the attribute used to specifying data type by the namespace used to construct this ValidatorMap. """ builder = getargs('builder', kwargs) dt = builder.attributes.get(self.__type_key) if dt is None: msg = "builder must have data type defined with attribute '%s'" % self.__type_key raise ValueError(msg) validator = self.get_validator(dt) return validator.validate(builder) class Validator(metaclass=ABCMeta): '''A base class for classes that will be used to validate against Spec subclasses''' @docval({'name': 'spec', 'type': Spec, 'doc': 'the specification to use to validate'}, {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) def __init__(self, **kwargs): self.__spec = getargs('spec', kwargs) self.__vmap = getargs('validator_map', kwargs) @property def spec(self): return self.__spec @property def vmap(self): return self.__vmap @abstractmethod @docval({'name': 'value', 'type': None, 'doc': 'either in the form of a value or a Builder'}, returns='a list of Errors', rtype=list) def validate(self, **kwargs): pass @classmethod def get_spec_loc(cls, spec): return spec.path @classmethod def get_builder_loc(cls, builder): stack = list() tmp = builder while tmp is not None and tmp.name != 'root': stack.append(tmp.name) tmp = tmp.parent return "/".join(reversed(stack)) class AttributeValidator(Validator): '''A class for validating values against AttributeSpecs''' @docval({'name': 'spec', 'type': AttributeSpec, 'doc': 'the specification to use to validate'}, {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) @docval({'name': 'value', 'type': None, 'doc': 'the value to validate'}, returns='a list of Errors', rtype=list) def validate(self, **kwargs): value = getargs('value', kwargs) ret = list() spec = self.spec if spec.required and value is None: ret.append(MissingError(self.get_spec_loc(spec))) else: if spec.dtype is None: ret.append(Error(self.get_spec_loc(spec))) elif isinstance(spec.dtype, RefSpec): if not isinstance(value, BaseBuilder): expected = '%s reference' % spec.dtype.reftype try: value_type = get_type(value) ret.append(DtypeError(self.get_spec_loc(spec), expected, value_type)) except EmptyArrayError: # do not validate dtype of empty array. HDMF does not yet set dtype when writing a list/tuple pass else: target_spec = self.vmap.namespace.catalog.get_spec(spec.dtype.target_type) data_type = value.attributes.get(target_spec.type_key()) hierarchy = self.vmap.namespace.catalog.get_hierarchy(data_type) if spec.dtype.target_type not in hierarchy: ret.append(IncorrectDataType(self.get_spec_loc(spec), spec.dtype.target_type, data_type)) else: try: dtype = get_type(value) if not check_type(spec.dtype, dtype): ret.append(DtypeError(self.get_spec_loc(spec), spec.dtype, dtype)) except EmptyArrayError: # do not validate dtype of empty array. 
HDMF does not yet set dtype when writing a list/tuple pass shape = get_data_shape(value) if not check_shape(spec.shape, shape): if shape is None: ret.append(ExpectedArrayError(self.get_spec_loc(self.spec), self.spec.shape, str(value))) else: ret.append(ShapeError(self.get_spec_loc(spec), spec.shape, shape)) return ret class BaseStorageValidator(Validator): '''A base class for validating against Spec objects that have attributes i.e. BaseStorageSpec''' @docval({'name': 'spec', 'type': BaseStorageSpec, 'doc': 'the specification to use to validate'}, {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) self.__attribute_validators = dict() for attr in self.spec.attributes: self.__attribute_validators[attr.name] = AttributeValidator(attr, self.vmap) @docval({"name": "builder", "type": BaseBuilder, "doc": "the builder to validate"}, returns='a list of Errors', rtype=list) def validate(self, **kwargs): builder = getargs('builder', kwargs) attributes = builder.attributes ret = list() for attr, validator in self.__attribute_validators.items(): attr_val = attributes.get(attr) if attr_val is None: if validator.spec.required: ret.append(MissingError(self.get_spec_loc(validator.spec), location=self.get_builder_loc(builder))) else: errors = validator.validate(attr_val) for err in errors: err.location = self.get_builder_loc(builder) + ".%s" % validator.spec.name ret.extend(errors) return ret class DatasetValidator(BaseStorageValidator): '''A class for validating DatasetBuilders against DatasetSpecs''' @docval({'name': 'spec', 'type': DatasetSpec, 'doc': 'the specification to use to validate'}, {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) @docval({"name": "builder", "type": DatasetBuilder, "doc": "the builder to validate"}, returns='a list of Errors', rtype=list) def validate(self, **kwargs): builder = getargs('builder', kwargs) ret = super().validate(builder) data = builder.data if self.spec.dtype is not None: try: dtype = get_type(data) if not check_type(self.spec.dtype, dtype): ret.append(DtypeError(self.get_spec_loc(self.spec), self.spec.dtype, dtype, location=self.get_builder_loc(builder))) except EmptyArrayError: # do not validate dtype of empty array. 
HDMF does not yet set dtype when writing a list/tuple pass shape = get_data_shape(data) if not check_shape(self.spec.shape, shape): if shape is None: ret.append(ExpectedArrayError(self.get_spec_loc(self.spec), self.spec.shape, str(data), location=self.get_builder_loc(builder))) else: ret.append(ShapeError(self.get_spec_loc(self.spec), self.spec.shape, shape, location=self.get_builder_loc(builder))) return ret def _resolve_data_type(spec): if isinstance(spec, LinkSpec): return spec.target_type return spec.data_type class GroupValidator(BaseStorageValidator): '''A class for validating GroupBuilders against GroupSpecs''' @docval({'name': 'spec', 'type': GroupSpec, 'doc': 'the specification to use to validate'}, {'name': 'validator_map', 'type': ValidatorMap, 'doc': 'the ValidatorMap to use during validation'}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) @docval({"name": "builder", "type": GroupBuilder, "doc": "the builder to validate"}, # noqa: C901 returns='a list of Errors', rtype=list) def validate(self, **kwargs): # noqa: C901 builder = getargs('builder', kwargs) errors = super().validate(builder) errors.extend(self.__validate_children(builder)) return self._remove_duplicates(errors) def __validate_children(self, parent_builder): """Validates the children of the group builder against the children in the spec. Children are defined as datasets, groups, and links. Validation works by first assigning builder children to spec children in a many-to-one relationship using a SpecMatcher (this matching is non-trivial due to inheritance, which is why it is isolated in a separate class). Once the matching is complete, it is a straightforward procedure for validating the set of matching builders against each child spec. """ spec_children = chain(self.spec.datasets, self.spec.groups, self.spec.links) matcher = SpecMatcher(self.vmap, spec_children) builder_children = chain(parent_builder.datasets.values(), parent_builder.groups.values(), parent_builder.links.values()) matcher.assign_to_specs(builder_children) for child_spec, matched_builders in matcher.spec_matches: yield from self.__validate_presence_and_quantity(child_spec, len(matched_builders), parent_builder) for child_builder in matched_builders: yield from self.__validate_child_builder(child_spec, child_builder, parent_builder) def __validate_presence_and_quantity(self, child_spec, n_builders, parent_builder): """Validate that at least one matching builder exists if the spec is required and that the number of builders agrees with the spec quantity """ if n_builders == 0 and child_spec.required: yield self.__construct_missing_child_error(child_spec, parent_builder) elif self.__incorrect_quantity(n_builders, child_spec): yield self.__construct_incorrect_quantity_error(child_spec, parent_builder, n_builders) def __construct_missing_child_error(self, child_spec, parent_builder): """Returns either a MissingDataType or a MissingError depending on whether or not a specific data type can be resolved from the spec """ data_type = _resolve_data_type(child_spec) builder_loc = self.get_builder_loc(parent_builder) if data_type is not None: name_of_erroneous = self.get_spec_loc(self.spec) return MissingDataType(name_of_erroneous, data_type, location=builder_loc, missing_dt_name=child_spec.name) else: name_of_erroneous = self.get_spec_loc(child_spec) return MissingError(name_of_erroneous, location=builder_loc) @staticmethod def __incorrect_quantity(n_found, spec): """Returns a boolean indicating whether the number of builder 
elements matches the specified quantity""" if not spec.is_many() and n_found > 1: return True elif isinstance(spec.quantity, int) and n_found != spec.quantity: return True return False def __construct_incorrect_quantity_error(self, child_spec, parent_builder, n_builders): name_of_erroneous = self.get_spec_loc(self.spec) data_type = _resolve_data_type(child_spec) builder_loc = self.get_builder_loc(parent_builder) return IncorrectQuantityError(name_of_erroneous, data_type, expected=child_spec.quantity, received=n_builders, location=builder_loc) def __validate_child_builder(self, child_spec, child_builder, parent_builder): """Validate a child builder against a child spec considering links""" if isinstance(child_builder, LinkBuilder): if self.__cannot_be_link(child_spec): yield self.__construct_illegal_link_error(child_spec, parent_builder) return # do not validate illegally linked objects child_builder = child_builder.builder for child_validator in self.__get_child_validators(child_spec): yield from child_validator.validate(child_builder) def __construct_illegal_link_error(self, child_spec, parent_builder): name_of_erroneous = self.get_spec_loc(child_spec) builder_loc = self.get_builder_loc(parent_builder) return IllegalLinkError(name_of_erroneous, location=builder_loc) @staticmethod def __cannot_be_link(spec): return not isinstance(spec, LinkSpec) and not spec.linkable def __get_child_validators(self, spec): """Returns the appropriate list of validators for a child spec Due to the fact that child specs can both inherit a data type via data_type_inc and also modify the type without defining a new data type via data_type_def, we need to validate against both the spec for the base data type and the spec at the current hierarchy of the data type in case there have been any modifications. If a specific data type can be resolved, a validator for that type is acquired from the ValidatorMap and included in the returned validators. If the spec is a GroupSpec or a DatasetSpec, then a new Validator is created and also returned. If the spec is a LinkSpec, no additional Validator is returned because the LinkSpec cannot add or modify fields and the target_type will be validated by the Validator returned from the ValidatorMap. """ if _resolve_data_type(spec) is not None: yield self.vmap.get_validator(_resolve_data_type(spec)) if isinstance(spec, GroupSpec): yield GroupValidator(spec, self.vmap) elif isinstance(spec, DatasetSpec): yield DatasetValidator(spec, self.vmap) elif isinstance(spec, LinkSpec): return else: msg = "Unable to resolve a validator for spec %s" % spec raise ValueError(msg) @staticmethod def _remove_duplicates(errors): """Return a list of validation errors where duplicates have been removed In some cases a child of a group to be validated against two specs which can redundantly define the same fields/children. If the builder doesn't match the spec, it is possible for duplicate errors to be generated. """ ordered_errors = OrderedDict() for error in errors: ordered_errors[error] = error return list(ordered_errors) class SpecMatches: """A utility class to hold a spec and the builders matched to it""" def __init__(self, spec): self.spec = spec self.builders = list() def add(self, builder): self.builders.append(builder) class SpecMatcher: """Matches a set of builders against a set of specs This class is intended to isolate the task of choosing which spec a builder should be validated against from the task of performing that validation. 
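A hedged usage sketch (illustrative only, not taken from the calling code): given a
ValidatorMap ``vmap`` and the child specs of a parent group spec::

    matcher = SpecMatcher(vmap, child_specs)
    matcher.assign_to_specs(builder_children)
    for spec, builders in matcher.spec_matches:
        pass  # validate each matched builder against spec
    extra = matcher.unmatched_builders  # builders that matched no spec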
""" def __init__(self, vmap, specs): self.vmap = vmap self._spec_matches = [SpecMatches(spec) for spec in specs] self._unmatched_builders = SpecMatches(None) @property def unmatched_builders(self): """Returns the builders for which no matching spec was found These builders can be considered superfluous, and will generate a warning in the future. """ return self._unmatched_builders.builders @property def spec_matches(self): """Returns a list of tuples of: (spec, assigned builders)""" return [(sm.spec, sm.builders) for sm in self._spec_matches] def assign_to_specs(self, builders): """Assigns a set of builders against a set of specs (many-to-one) In the case that no matching spec is found, a builder will be added to a list of unmatched builders. """ for builder in builders: spec_match = self._best_matching_spec(builder) if spec_match is None: self._unmatched_builders.add(builder) else: spec_match.add(builder) def _best_matching_spec(self, builder): """Finds the best matching spec for builder The current algorithm is: 1. filter specs which meet the minimum requirements of consistent name and data type 2. if more than one candidate meets the minimum requirements, find the candidates which do not yet have a sufficient number of builders assigned (based on the spec quantity) 3. return the first unsatisfied candidate if any, otherwise return the first candidate Note that the current algorithm will give different results depending on the order of the specs or builders, and also does not consider inheritance hierarchy. Future improvements to this matching algorithm should resolve these discrepancies. """ candidates = self._filter_by_name(self._spec_matches, builder) candidates = self._filter_by_type(candidates, builder) if len(candidates) == 0: return None elif len(candidates) == 1: return candidates[0] else: unsatisfied_candidates = self._filter_by_unsatisfied(candidates) if len(unsatisfied_candidates) == 0: return candidates[0] else: return unsatisfied_candidates[0] def _filter_by_name(self, candidates, builder): """Returns the candidate specs that either have the same name as the builder or do not specify a name. """ def name_is_consistent(spec_matches): spec = spec_matches.spec return spec.name is None or spec.name == builder.name return list(filter(name_is_consistent, candidates)) def _filter_by_type(self, candidates, builder): """Returns the candidate specs which have a data type consistent with the builder's data type. """ def compatible_type(spec_matches): spec = spec_matches.spec if isinstance(spec, LinkSpec): validator = self.vmap.get_validator(spec.target_type) spec = validator.spec if spec.data_type is None: return True valid_validators = self.vmap.valid_types(spec.data_type) valid_types = [v.spec.data_type for v in valid_validators] if isinstance(builder, LinkBuilder): dt = builder.builder.attributes.get(spec.type_key()) else: dt = builder.attributes.get(spec.type_key()) return dt in valid_types return list(filter(compatible_type, candidates)) def _filter_by_unsatisfied(self, candidates): """Returns the candidate specs which are not yet matched against a number of builders which fulfils the quantity for the spec. 
""" def is_unsatisfied(spec_matches): spec = spec_matches.spec n_match = len(spec_matches.builders) if spec.required and n_match == 0: return True if isinstance(spec.quantity, int) and n_match < spec.quantity: return True return False return list(filter(is_unsatisfied, candidates)) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1766272 hdmf-3.1.1/src/hdmf.egg-info/0000755000655200065520000000000000000000000015760 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603655.0 hdmf-3.1.1/src/hdmf.egg-info/PKG-INFO0000644000655200065520000001745300000000000017067 0ustar00circlecicircleciMetadata-Version: 2.1 Name: hdmf Version: 3.1.1 Summary: A package for standardizing hierarchical object data Home-page: https://github.com/hdmf-dev/hdmf Author: Andrew Tritt Author-email: ajtritt@lbl.gov License: BSD Keywords: python HDF HDF5 cross-platform open-data data-format open-source open-science reproducible-research Platform: UNKNOWN Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: License :: OSI Approved :: BSD License Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: Intended Audience :: Science/Research Classifier: Operating System :: Microsoft :: Windows Classifier: Operating System :: MacOS Classifier: Operating System :: Unix Classifier: Topic :: Scientific/Engineering :: Medical Science Apps. Requires-Python: >=3.7 Description-Content-Type: text/x-rst; charset=UTF-8 ======================================== The Hierarchical Data Modeling Framework ======================================== The Hierarchical Data Modeling Framework, or *HDMF*, is a Python package for working with hierarchical data. It provides APIs for specifying data models, reading and writing data to different storage backends, and representing data with Python object. Documentation of HDMF can be found at https://hdmf.readthedocs.io Latest Release ============== .. image:: https://badge.fury.io/py/hdmf.svg :target: https://badge.fury.io/py/hdmf .. image:: https://anaconda.org/conda-forge/hdmf/badges/version.svg :target: https://anaconda.org/conda-forge/hdmf Build Status ============ .. table:: +---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ | Linux | Windows and macOS | +=====================================================================+==================================================================================================+ | .. image:: https://circleci.com/gh/hdmf-dev/hdmf.svg?style=shield | .. image:: https://dev.azure.com/hdmf-dev/hdmf/_apis/build/status/hdmf-dev.hdmf?branchName=dev | | :target: https://circleci.com/gh/hdmf-dev/hdmf | :target: https://dev.azure.com/hdmf-dev/hdmf/_build/latest?definitionId=1&branchName=dev | +---------------------------------------------------------------------+--------------------------------------------------------------------------------------------------+ **Conda** .. image:: https://circleci.com/gh/conda-forge/hdmf-feedstock.svg?style=shield :target: https://circleci.com/gh/conda-forge/hdmf-feedstock Overall Health ============== .. 
image:: https://github.com/hdmf-dev/hdmf/workflows/Run%20coverage/badge.svg :target: https://github.com/hdmf-dev/hdmf/actions?query=workflow%3A%22Run+coverage%22 .. image:: https://codecov.io/gh/hdmf-dev/hdmf/branch/dev/graph/badge.svg :target: https://codecov.io/gh/hdmf-dev/hdmf .. image:: https://requires.io/github/hdmf-dev/hdmf/requirements.svg?branch=dev :target: https://requires.io/github/hdmf-dev/hdmf/requirements/?branch=dev :alt: Requirements Status .. image:: https://readthedocs.org/projects/hdmf/badge/?version=latest :target: https://hdmf.readthedocs.io/en/latest/?badge=latest :alt: Documentation Status Installation ============ See the HDMF documentation for details http://hdmf.readthedocs.io/en/latest/getting_started.html#installation Code of Conduct =============== This project and everyone participating in it is governed by our `code of conduct guidelines <.github/CODE_OF_CONDUCT.md>`_. By participating, you are expected to uphold this code. Contributing ============ For details on how to contribute to HDMF see our `contribution guidelines `_. Citing HDMF =========== * **Manuscript:** .. code-block:: bibtex @INPROCEEDINGS{9005648, author={A. J. {Tritt} and O. {Rübel} and B. {Dichter} and R. {Ly} and D. {Kang} and E. F. {Chang} and L. M. {Frank} and K. {Bouchard}}, booktitle={2019 IEEE International Conference on Big Data (Big Data)}, title={HDMF: Hierarchical Data Modeling Framework for Modern Science Data Standards}, year={2019}, volume={}, number={}, pages={165-179}, doi={10.1109/BigData47090.2019.9005648}, note={}} * **RRID:** (Hierarchical Data Modeling Framework, RRID:SCR_021303) LICENSE ======= "hdmf" Copyright (c) 2017-2021, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: (1) Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. (2) Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. (3) Neither the name of the University of California, Lawrence Berkeley National Laboratory, U.S. Dept. of Energy nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
You are under no obligation whatsoever to provide any bug fixes, patches, or upgrades to the features, functionality or performance of the source code ("Enhancements") to anyone; however, if you choose to make your Enhancements available either publicly, or directly to Lawrence Berkeley National Laboratory, without imposing a separate written license agreement for such Enhancements, then you hereby grant the following license: a non-exclusive, royalty-free perpetual license to install, use, modify, prepare derivative works, incorporate into other computer software, distribute, and sublicense such enhancements or derivative works thereof, in binary and source code form. COPYRIGHT ========= "hdmf" Copyright (c) 2017-2021, The Regents of the University of California, through Lawrence Berkeley National Laboratory (subject to receipt of any required approvals from the U.S. Dept. of Energy). All rights reserved. If you have questions about your rights to use or distribute this software, please contact Berkeley Lab's Innovation & Partnerships Office at IPO@lbl.gov. NOTICE. This Software was developed under funding from the U.S. Department of Energy and the U.S. Government consequently retains certain rights. As such, the U.S. Government has been granted for itself and others acting on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the Software to reproduce, distribute copies to the public, prepare derivative works, and perform publicly and display publicly, and to permit other to do so. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603655.0 hdmf-3.1.1/src/hdmf.egg-info/SOURCES.txt0000644000655200065520000001023500000000000017645 0ustar00circlecicircleciLegal.txt MANIFEST.in README.rst license.txt requirements-dev.txt requirements-doc.txt requirements-min.txt requirements.txt setup.cfg setup.py test.py tox.ini versioneer.py src/hdmf/__init__.py src/hdmf/_due.py src/hdmf/_version.py src/hdmf/array.py src/hdmf/container.py src/hdmf/data_utils.py src/hdmf/monitor.py src/hdmf/query.py src/hdmf/region.py src/hdmf/utils.py src/hdmf.egg-info/PKG-INFO src/hdmf.egg-info/SOURCES.txt src/hdmf.egg-info/dependency_links.txt src/hdmf.egg-info/entry_points.txt src/hdmf.egg-info/not-zip-safe src/hdmf.egg-info/requires.txt src/hdmf.egg-info/top_level.txt src/hdmf/backends/__init__.py src/hdmf/backends/io.py src/hdmf/backends/warnings.py src/hdmf/backends/hdf5/__init__.py src/hdmf/backends/hdf5/h5_utils.py src/hdmf/backends/hdf5/h5tools.py src/hdmf/build/__init__.py src/hdmf/build/builders.py src/hdmf/build/classgenerator.py src/hdmf/build/errors.py src/hdmf/build/manager.py src/hdmf/build/map.py src/hdmf/build/objectmapper.py src/hdmf/build/warnings.py src/hdmf/common/__init__.py src/hdmf/common/alignedtable.py src/hdmf/common/hierarchicaltable.py src/hdmf/common/multi.py src/hdmf/common/resources.py src/hdmf/common/sparse.py src/hdmf/common/table.py src/hdmf/common/hdmf-common-schema/common/base.yaml src/hdmf/common/hdmf-common-schema/common/experimental.yaml src/hdmf/common/hdmf-common-schema/common/namespace.yaml src/hdmf/common/hdmf-common-schema/common/resources.yaml src/hdmf/common/hdmf-common-schema/common/sparse.yaml src/hdmf/common/hdmf-common-schema/common/table.yaml src/hdmf/common/io/__init__.py src/hdmf/common/io/alignedtable.py src/hdmf/common/io/multi.py src/hdmf/common/io/resources.py src/hdmf/common/io/table.py src/hdmf/spec/__init__.py src/hdmf/spec/catalog.py src/hdmf/spec/namespace.py src/hdmf/spec/spec.py src/hdmf/spec/write.py 
src/hdmf/testing/__init__.py src/hdmf/testing/testcase.py src/hdmf/testing/utils.py src/hdmf/testing/validate_spec.py src/hdmf/validate/__init__.py src/hdmf/validate/errors.py src/hdmf/validate/validator.py tests/__init__.py tests/coverage/runCoverage tests/unit/__init__.py tests/unit/test_container.py tests/unit/test_io_hdf5.py tests/unit/test_io_hdf5_h5tools.py tests/unit/test_multicontainerinterface.py tests/unit/test_query.py tests/unit/test_table.py tests/unit/utils.py tests/unit/back_compat_tests/1.0.5.h5 tests/unit/back_compat_tests/__init__.py tests/unit/back_compat_tests/test_1_1_0.py tests/unit/build_tests/__init__.py tests/unit/build_tests/test_builder.py tests/unit/build_tests/test_classgenerator.py tests/unit/build_tests/test_convert_dtype.py tests/unit/build_tests/test_io_manager.py tests/unit/build_tests/test_io_map.py tests/unit/build_tests/test_io_map_data.py tests/unit/build_tests/mapper_tests/__init__.py tests/unit/build_tests/mapper_tests/test_build.py tests/unit/build_tests/mapper_tests/test_build_quantity.py tests/unit/common/__init__.py tests/unit/common/test_alignedtable.py tests/unit/common/test_common.py tests/unit/common/test_common_io.py tests/unit/common/test_generate_table.py tests/unit/common/test_linkedtables.py tests/unit/common/test_multi.py tests/unit/common/test_resources.py tests/unit/common/test_sparse.py tests/unit/common/test_table.py tests/unit/spec_tests/__init__.py tests/unit/spec_tests/test-ext.base.yaml tests/unit/spec_tests/test-ext.namespace.yaml tests/unit/spec_tests/test.base.yaml tests/unit/spec_tests/test.namespace.yaml tests/unit/spec_tests/test_attribute_spec.py tests/unit/spec_tests/test_dataset_spec.py tests/unit/spec_tests/test_dtype_spec.py tests/unit/spec_tests/test_group_spec.py tests/unit/spec_tests/test_link_spec.py tests/unit/spec_tests/test_load_namespace.py tests/unit/spec_tests/test_ref_spec.py tests/unit/spec_tests/test_spec_catalog.py tests/unit/spec_tests/test_spec_write.py tests/unit/utils_test/__init__.py tests/unit/utils_test/test_core_DataChunk.py tests/unit/utils_test/test_core_DataChunkIterator.py tests/unit/utils_test/test_core_DataIO.py tests/unit/utils_test/test_core_ShapeValidator.py tests/unit/utils_test/test_docval.py tests/unit/utils_test/test_labelleddict.py tests/unit/utils_test/test_utils.py tests/unit/validator_tests/__init__.py tests/unit/validator_tests/test_errors.py tests/unit/validator_tests/test_validate.py././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603655.0 hdmf-3.1.1/src/hdmf.egg-info/dependency_links.txt0000644000655200065520000000000100000000000022026 0ustar00circlecicircleci ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603655.0 hdmf-3.1.1/src/hdmf.egg-info/entry_points.txt0000644000655200065520000000011000000000000021246 0ustar00circlecicircleci[console_scripts] validate_hdmf_spec = hdmf.testing.validate_spec:main ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603623.0 hdmf-3.1.1/src/hdmf.egg-info/not-zip-safe0000644000655200065520000000000100000000000020206 0ustar00circlecicircleci ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603655.0 hdmf-3.1.1/src/hdmf.egg-info/requires.txt0000644000655200065520000000016400000000000020361 0ustar00circlecicirclecih5py<4,>=2.10 numpy<1.22,>=1.16 scipy<2,>=1.1 pandas<2,>=1.0.5 ruamel.yaml<1,>=0.16 jsonschema<4,>=2.6.0 setuptools ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 
mtime=1627603655.0 hdmf-3.1.1/src/hdmf.egg-info/top_level.txt0000644000655200065520000000000500000000000020505 0ustar00circlecicirclecihdmf ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/test.py0000755000655200065520000001214600000000000014121 0ustar00circlecicircleci#!/usr/bin/env python from __future__ import print_function import warnings import re import argparse import logging import os.path import os import sys import traceback import unittest flags = {'hdmf': 1, 'example': 4} TOTAL = 0 FAILURES = 0 ERRORS = 0 class SuccessRecordingResult(unittest.TextTestResult): '''A unittest test result class that stores successful test cases as well as failures and skips. ''' def addSuccess(self, test): if not hasattr(self, 'successes'): self.successes = [test] else: self.successes.append(test) def get_all_cases_run(self): '''Return a list of each test case which failed or succeeded ''' cases = [] if hasattr(self, 'successes'): cases.extend(self.successes) cases.extend([failure[0] for failure in self.failures]) return cases def run_test_suite(directory, description="", verbose=True): global TOTAL, FAILURES, ERRORS logging.info("running %s" % description) directory = os.path.join(os.path.dirname(__file__), directory) runner = unittest.TextTestRunner(verbosity=verbose, resultclass=SuccessRecordingResult) test_result = runner.run(unittest.TestLoader().discover(directory)) TOTAL += test_result.testsRun FAILURES += len(test_result.failures) ERRORS += len(test_result.errors) return test_result def _import_from_file(script): import imp return imp.load_source(os.path.basename(script), script) warning_re = re.compile("Parent module '[a-zA-Z0-9]+' not found while handling absolute import") def run_example_tests(): global TOTAL, FAILURES, ERRORS logging.info('running example tests') examples_scripts = list() for root, dirs, files in os.walk(os.path.join(os.path.dirname(__file__), "docs", "gallery")): for f in files: if f.endswith(".py"): examples_scripts.append(os.path.join(root, f)) TOTAL += len(examples_scripts) for script in examples_scripts: try: logging.info("Executing %s" % script) ws = list() with warnings.catch_warnings(record=True) as tmp: _import_from_file(script) for w in tmp: # ignore RunTimeWarnings about importing if isinstance(w.message, RuntimeWarning) and not warning_re.match(str(w.message)): ws.append(w) for w in ws: warnings.showwarning(w.message, w.category, w.filename, w.lineno, w.line) except Exception: print(traceback.format_exc()) FAILURES += 1 ERRORS += 1 def main(): # setup and parse arguments parser = argparse.ArgumentParser('python test.py [options]') parser.set_defaults(verbosity=1, suites=[]) parser.add_argument('-v', '--verbose', const=2, dest='verbosity', action='store_const', help='run in verbose mode') parser.add_argument('-q', '--quiet', const=0, dest='verbosity', action='store_const', help='run disabling output') parser.add_argument('-u', '--unit', action='append_const', const=flags['hdmf'], dest='suites', help='run unit tests for hdmf package') parser.add_argument('-e', '--example', action='append_const', const=flags['example'], dest='suites', help='run example tests') args = parser.parse_args() if not args.suites: args.suites = list(flags.values()) args.suites.pop(args.suites.index(flags['example'])) # remove example as a suite run by default # set up logger root = logging.getLogger() root.setLevel(logging.INFO) ch = logging.StreamHandler(sys.stdout) ch.setLevel(logging.INFO) formatter = 
logging.Formatter('======================================================================\n' '%(asctime)s - %(levelname)s - %(message)s') ch.setFormatter(formatter) root.addHandler(ch) warnings.simplefilter('always') # many tests use NamespaceCatalog.add_namespace, which is deprecated, to set up tests. # ignore these warnings for now. warnings.filterwarnings("ignore", category=DeprecationWarning, module="hdmf.spec.namespace", message=("NamespaceCatalog.add_namespace has been deprecated. " "SpecNamespaces should be added with load_namespaces.")) # Run unit tests for hdmf package if flags['hdmf'] in args.suites: run_test_suite("tests/unit", "hdmf unit tests", verbose=args.verbosity) # Run example tests if flags['example'] in args.suites: run_example_tests() final_message = 'Ran %s tests' % TOTAL exitcode = 0 if ERRORS > 0 or FAILURES > 0: exitcode = 1 _list = list() if ERRORS > 0: _list.append('errors=%d' % ERRORS) if FAILURES > 0: _list.append('failures=%d' % FAILURES) final_message = '%s - FAILED (%s)' % (final_message, ','.join(_list)) else: final_message = '%s - OK' % final_message logging.info(final_message) return exitcode if __name__ == "__main__": sys.exit(main())
././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1627603655.180627 hdmf-3.1.1/tests/0000755000655200065520000000000000000000000013723 5ustar00circlecicircleci
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/__init__.py0000644000655200065520000000000000000000000016022 0ustar00circlecicircleci
././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1627603655.180627 hdmf-3.1.1/tests/coverage/0000755000655200065520000000000000000000000015516 5ustar00circlecicircleci
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/coverage/runCoverage0000755000655200065520000000044600000000000017730 0ustar00circlecicircleci#!/bin/ksh # use default coverage name COV=coverage3 cd ../.. echo "" echo "Running Tests with Coverage:" ${COV} run --source=. test.py echo "" echo "Creating HTML output:" ${COV} html -d tests/coverage/htmlcov echo "" echo "Open /coverage/htmlcov/index.html to see results." echo ""
././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1627603655.180627 hdmf-3.1.1/tests/unit/0000755000655200065520000000000000000000000014702 5ustar00circlecicircleci
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/__init__.py0000644000655200065520000000000000000000000017001 0ustar00circlecicircleci
././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1846273 hdmf-3.1.1/tests/unit/back_compat_tests/0000755000655200065520000000000000000000000020367 5ustar00circlecicircleci
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/back_compat_tests/1.0.5.h50000644000655200065520000004312000000000000021266 0ustar00circlecicircleci [binary HDF5 test fixture written with hdmf 1.0.5; it embeds the cached 'test_core' specifications for Foo, FooBucket, and FooFile]
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/back_compat_tests/__init__.py0000644000655200065520000000000000000000000022466 0ustar00circlecicircleci
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/back_compat_tests/test_1_1_0.py0000644000655200065520000000361600000000000022605 0ustar00circlecicircleciimport os from shutil import copyfile from hdmf.backends.hdf5.h5tools import HDF5IO from hdmf.testing import TestCase from tests.unit.test_io_hdf5_h5tools import _get_manager from tests.unit.utils import Foo, FooBucket class Test1_1_0(TestCase): def setUp(self): # created using manager in test_io_hdf5_h5tools self.orig_1_0_5 = 'tests/unit/back_compat_tests/1.0.5.h5' self.path_1_0_5 = 'test_1.0.5.h5' copyfile(self.orig_1_0_5, self.path_1_0_5) # note: this may break if the current manager is different from the old manager # better to save a spec file self.manager = _get_manager() def tearDown(self): if os.path.exists(self.path_1_0_5): os.remove(self.path_1_0_5) def test_read_1_0_5(self): '''Test whether we can read files made by hdmf version 1.0.5''' with HDF5IO(self.path_1_0_5, manager=self.manager, mode='r') as io: read_foofile = io.read() self.assertTrue(len(read_foofile.buckets) == 1) self.assertListEqual(read_foofile.buckets['test_bucket'].foos['foo1'].my_data[:].tolist(), [0, 1, 2, 3, 4]) self.assertListEqual(read_foofile.buckets['test_bucket'].foos['foo2'].my_data[:].tolist(), [5, 6, 7, 8, 9]) def test_append_1_0_5(self): '''Test whether we can append to files made by hdmf version 1.0.5''' foo = Foo('foo3', [10, 20, 30, 40, 50], "I am foo3", 17, 3.14) foobucket = FooBucket('foobucket2', [foo]) with HDF5IO(self.path_1_0_5, manager=self.manager, mode='a') as io: read_foofile = io.read() read_foofile.add_bucket(foobucket)
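# The appended FooBucket is written back into the copied 1.0.5 file below, and the file is
# then re-read in a fresh HDF5IO session to confirm that the newly added data round-trips.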
io.write(read_foofile) with HDF5IO(self.path_1_0_5, manager=self.manager, mode='r') as io: read_foofile = io.read() self.assertListEqual(read_foofile.buckets['foobucket2'].foos['foo3'].my_data[:].tolist(), foo.my_data) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1846273 hdmf-3.1.1/tests/unit/build_tests/0000755000655200065520000000000000000000000017223 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/build_tests/__init__.py0000644000655200065520000000000000000000000021322 0ustar00circlecicircleci././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1846273 hdmf-3.1.1/tests/unit/build_tests/mapper_tests/0000755000655200065520000000000000000000000021731 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/build_tests/mapper_tests/__init__.py0000644000655200065520000000000000000000000024030 0ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/build_tests/mapper_tests/test_build.py0000644000655200065520000005554300000000000024455 0ustar00circlecicirclecifrom abc import ABCMeta, abstractmethod import numpy as np from hdmf import Container, Data from hdmf.build import ObjectMapper, BuildManager, TypeMap, GroupBuilder, DatasetBuilder from hdmf.build.warnings import DtypeConversionWarning from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, Spec from hdmf.testing import TestCase from hdmf.utils import docval, getargs from tests.unit.utils import CORE_NAMESPACE # TODO: test build of extended group/dataset that modifies an attribute dtype (commented out below), shape, value, etc. # by restriction. also check that attributes cannot be deleted or scope expanded. # TODO: test build of extended dataset that modifies shape by restriction. 
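# Hedged sketch (not part of the original test module): the wiring that the setUp methods
# below perform, distilled into one helper. The parameter names ``my_spec``, ``MyContainer``,
# and ``MyMapper`` are hypothetical placeholders; the API calls mirror the ones used in the
# tests below, and all referenced names come from the imports at the top of this module.
def _example_type_map_wiring(my_spec, MyContainer, MyMapper):
    spec_catalog = SpecCatalog()
    spec_catalog.register_spec(my_spec, 'test.yaml')
    namespace = SpecNamespace(doc='a test namespace', name=CORE_NAMESPACE,
                              schema=[{'source': 'test.yaml'}], version='0.1.0',
                              catalog=spec_catalog)
    namespace_catalog = NamespaceCatalog()
    namespace_catalog.add_namespace(CORE_NAMESPACE, namespace)
    type_map = TypeMap(namespace_catalog)
    type_map.register_container_type(CORE_NAMESPACE, my_spec.data_type_def, MyContainer)
    type_map.register_map(MyContainer, MyMapper)
    # manager.build(container, source='test.h5') then yields a GroupBuilder/DatasetBuilder
    return BuildManager(type_map)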
class Bar(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this Bar'}, {'name': 'attr1', 'type': str, 'doc': 'a string attribute'}, {'name': 'attr2', 'type': 'int', 'doc': 'an int attribute', 'default': None}, {'name': 'ext_attr', 'type': bool, 'doc': 'a boolean attribute', 'default': True}) def __init__(self, **kwargs): name, attr1, attr2, ext_attr = getargs('name', 'attr1', 'attr2', 'ext_attr', kwargs) super().__init__(name=name) self.__attr1 = attr1 self.__attr2 = attr2 self.__ext_attr = kwargs['ext_attr'] @property def data_type(self): return 'Bar' @property def attr1(self): return self.__attr1 @property def attr2(self): return self.__attr2 @property def ext_attr(self): return self.__ext_attr class BarHolder(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this BarHolder'}, {'name': 'bars', 'type': ('data', 'array_data'), 'doc': 'bars', 'default': list()}) def __init__(self, **kwargs): name, bars = getargs('name', 'bars', kwargs) super().__init__(name=name) self.__bars = bars for b in bars: if b is not None and b.parent is None: b.parent = self @property def data_type(self): return 'BarHolder' @property def bars(self): return self.__bars class ExtBarMapper(ObjectMapper): @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, {"name": "container", "type": Bar, "doc": "the container to get the attribute value from"}, {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, returns='the value of the attribute') def get_attr_value(self, **kwargs): ''' Get the value of the attribute corresponding to this spec from the given container ''' spec, container, manager = getargs('spec', 'container', 'manager', kwargs) # handle custom mapping of field 'ext_attr' within container BarHolder/Bar -> spec BarHolder/Bar.ext_attr if isinstance(container.parent, BarHolder): if spec.name == 'ext_attr': return container.ext_attr return super().get_attr_value(**kwargs) class BuildGroupExtAttrsMixin(TestCase, metaclass=ABCMeta): def setUp(self): self.setUpBarSpec() self.setUpBarHolderSpec() spec_catalog = SpecCatalog() spec_catalog.register_spec(self.bar_spec, 'test.yaml') spec_catalog.register_spec(self.bar_holder_spec, 'test.yaml') namespace = SpecNamespace( doc='a test namespace', name=CORE_NAMESPACE, schema=[{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) type_map.register_container_type(CORE_NAMESPACE, 'BarHolder', BarHolder) type_map.register_map(Bar, ExtBarMapper) type_map.register_map(BarHolder, ObjectMapper) self.manager = BuildManager(type_map) def setUpBarSpec(self): attr1_attr = AttributeSpec( name='attr1', dtype='text', doc='an example string attribute', ) attr2_attr = AttributeSpec( name='attr2', dtype='int', doc='an example int attribute', ) self.bar_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Bar', attributes=[attr1_attr, attr2_attr], ) @abstractmethod def setUpBarHolderSpec(self): pass class TestBuildGroupAddedAttr(BuildGroupExtAttrsMixin, TestCase): """ If the spec defines a group data_type A (Bar) using 'data_type_def' and defines another data_type B (BarHolder) that includes A using 'data_type_inc', then the included A spec is an extended (or refined) spec of A - call it A'. 
The spec of A' can refine or add attributes to the spec of A. This test ensures that *added attributes* in A' are handled properly. """ def setUpBarHolderSpec(self): ext_attr = AttributeSpec( name='ext_attr', dtype='bool', doc='A boolean attribute', ) bar_ext_no_name_spec = GroupSpec( doc='A Bar extended with attribute ext_attr', data_type_inc='Bar', quantity='*', attributes=[ext_attr], ) self.bar_holder_spec = GroupSpec( doc='A container of multiple extended Bar objects', data_type_def='BarHolder', groups=[bar_ext_no_name_spec], ) def test_build_added_attr(self): """ Test build of BarHolder which can contain multiple extended Bar objects, which have a new attribute. """ ext_bar_inst = Bar( name='my_bar', attr1='a string', attr2=10, ext_attr=False, ) bar_holder_inst = BarHolder( name='my_bar_holder', bars=[ext_bar_inst], ) expected_inner = GroupBuilder( name='my_bar', attributes={ 'attr1': 'a string', 'attr2': 10, 'data_type': 'Bar', 'ext_attr': False, 'namespace': CORE_NAMESPACE, 'object_id': ext_bar_inst.object_id, }, ) expected = GroupBuilder( name='my_bar_holder', groups={'my_bar': expected_inner}, attributes={ 'data_type': 'BarHolder', 'namespace': CORE_NAMESPACE, 'object_id': bar_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field builder = self.manager.build(bar_holder_inst, source='test.h5') self.assertDictEqual(builder, expected) class TestBuildGroupRefinedAttr(BuildGroupExtAttrsMixin, TestCase): """ If the spec defines a group data_type A (Bar) using 'data_type_def' and defines another data_type B (BarHolder) that includes A using 'data_type_inc', then the included A spec is an extended (or refined) spec of A - call it A'. The spec of A' can refine or add attributes to the spec of A. This test ensures that *refine attributes* in A' are handled properly. """ def setUpBarHolderSpec(self): int_attr2 = AttributeSpec( name='attr2', dtype='int64', doc='Refine Bar spec from int to int64', ) bar_ext_no_name_spec = GroupSpec( doc='A Bar extended with modified attribute attr2', data_type_inc='Bar', quantity='*', attributes=[int_attr2], ) self.bar_holder_spec = GroupSpec( doc='A container of multiple extended Bar objects', data_type_def='BarHolder', groups=[bar_ext_no_name_spec], ) def test_build_refined_attr(self): """ Test build of BarHolder which can contain multiple extended Bar objects, which have a modified attr2. """ ext_bar_inst = Bar( name='my_bar', attr1='a string', attr2=np.int64(10), ) bar_holder_inst = BarHolder( name='my_bar_holder', bars=[ext_bar_inst], ) expected_inner = GroupBuilder( name='my_bar', attributes={ 'attr1': 'a string', 'attr2': np.int64(10), 'data_type': 'Bar', 'namespace': CORE_NAMESPACE, 'object_id': ext_bar_inst.object_id, } ) expected = GroupBuilder( name='my_bar_holder', groups={'my_bar': expected_inner}, attributes={ 'data_type': 'BarHolder', 'namespace': CORE_NAMESPACE, 'object_id': bar_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field builder = self.manager.build(bar_holder_inst, source='test.h5') self.assertDictEqual(builder, expected) # def test_build_refined_attr_wrong_type(self): # """ # Test build of BarHolder which contains a Bar that has the wrong dtype for an attr. 
# """ # ext_bar_inst = Bar( # name='my_bar', # attr1='a string', # attr2=10, # spec specifies attr2 should be an int64 for Bars within BarHolder # ) # bar_holder_inst = BarHolder( # name='my_bar_holder', # bars=[ext_bar_inst], # ) # # expected_inner = GroupBuilder( # name='my_bar', # attributes={ # 'attr1': 'a string', # 'attr2': np.int64(10), # 'data_type': 'Bar', # 'namespace': CORE_NAMESPACE, # 'object_id': ext_bar_inst.object_id, # } # ) # expected = GroupBuilder( # name='my_bar_holder', # groups={'my_bar': expected_inner}, # attributes={ # 'data_type': 'BarHolder', # 'namespace': CORE_NAMESPACE, # 'object_id': bar_holder_inst.object_id, # }, # ) # # # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field # # # TODO build should raise a conversion warning for converting 10 (int32) to np.int64 # builder = self.manager.build(bar_holder_inst, source='test.h5') # self.assertDictEqual(builder, expected) class BarData(Data): @docval({'name': 'name', 'type': str, 'doc': 'the name of this BarData'}, {'name': 'data', 'type': ('data', 'array_data'), 'doc': 'the data'}, {'name': 'attr1', 'type': str, 'doc': 'a string attribute'}, {'name': 'attr2', 'type': 'int', 'doc': 'an int attribute', 'default': None}, {'name': 'ext_attr', 'type': bool, 'doc': 'a boolean attribute', 'default': True}) def __init__(self, **kwargs): name, data, attr1, attr2, ext_attr = getargs('name', 'data', 'attr1', 'attr2', 'ext_attr', kwargs) super().__init__(name=name, data=data) self.__attr1 = attr1 self.__attr2 = attr2 self.__ext_attr = kwargs['ext_attr'] @property def data_type(self): return 'BarData' @property def attr1(self): return self.__attr1 @property def attr2(self): return self.__attr2 @property def ext_attr(self): return self.__ext_attr class BarDataHolder(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this BarDataHolder'}, {'name': 'bar_datas', 'type': ('data', 'array_data'), 'doc': 'bar_datas', 'default': list()}) def __init__(self, **kwargs): name, bar_datas = getargs('name', 'bar_datas', kwargs) super().__init__(name=name) self.__bar_datas = bar_datas for b in bar_datas: if b is not None and b.parent is None: b.parent = self @property def data_type(self): return 'BarDataHolder' @property def bar_datas(self): return self.__bar_datas class ExtBarDataMapper(ObjectMapper): @docval({"name": "spec", "type": Spec, "doc": "the spec to get the attribute value for"}, {"name": "container", "type": BarData, "doc": "the container to get the attribute value from"}, {"name": "manager", "type": BuildManager, "doc": "the BuildManager used for managing this build"}, returns='the value of the attribute') def get_attr_value(self, **kwargs): ''' Get the value of the attribute corresponding to this spec from the given container ''' spec, container, manager = getargs('spec', 'container', 'manager', kwargs) # handle custom mapping of field 'ext_attr' within container # BardataHolder/BarData -> spec BarDataHolder/BarData.ext_attr if isinstance(container.parent, BarDataHolder): if spec.name == 'ext_attr': return container.ext_attr return super().get_attr_value(**kwargs) class BuildDatasetExtAttrsMixin(TestCase, metaclass=ABCMeta): def setUp(self): self.set_up_specs() spec_catalog = SpecCatalog() spec_catalog.register_spec(self.bar_data_spec, 'test.yaml') spec_catalog.register_spec(self.bar_data_holder_spec, 'test.yaml') namespace = SpecNamespace( doc='a test namespace', name=CORE_NAMESPACE, schema=[{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog ) 
namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) type_map.register_container_type(CORE_NAMESPACE, 'BarData', BarData) type_map.register_container_type(CORE_NAMESPACE, 'BarDataHolder', BarDataHolder) type_map.register_map(BarData, ExtBarDataMapper) type_map.register_map(BarDataHolder, ObjectMapper) self.manager = BuildManager(type_map) def set_up_specs(self): attr1_attr = AttributeSpec( name='attr1', dtype='text', doc='an example string attribute', ) attr2_attr = AttributeSpec( name='attr2', dtype='int', doc='an example int attribute', ) self.bar_data_spec = DatasetSpec( doc='A test dataset specification with a data type', data_type_def='BarData', dtype='int', shape=[[None], [None, None]], attributes=[attr1_attr, attr2_attr], ) self.bar_data_holder_spec = GroupSpec( doc='A container of multiple extended BarData objects', data_type_def='BarDataHolder', datasets=[self.get_refined_bar_data_spec()], ) @abstractmethod def get_refined_bar_data_spec(self): pass class TestBuildDatasetAddedAttrs(BuildDatasetExtAttrsMixin, TestCase): """ If the spec defines a dataset data_type A (BarData) using 'data_type_def' and defines another data_type B (BarHolder) that includes A using 'data_type_inc', then the included A spec is an extended (or refined) spec of A - call it A'. The spec of A' can refine or add attributes, refine the dtype, refine the shape, or set a fixed value to the spec of A. This test ensures that *added attributes* in A' are handled properly. This is similar to how the spec for a subtype of DynamicTable can contain a VectorData that has a new attribute. """ def get_refined_bar_data_spec(self): ext_attr = AttributeSpec( name='ext_attr', dtype='bool', doc='A boolean attribute', ) refined_spec = DatasetSpec( doc='A BarData extended with attribute ext_attr', data_type_inc='BarData', quantity='*', attributes=[ext_attr], ) return refined_spec def test_build_added_attr(self): """ Test build of BarHolder which can contain multiple extended BarData objects, which have a new attribute. """ ext_bar_data_inst = BarData( name='my_bar', data=list(range(10)), attr1='a string', attr2=10, ext_attr=False, ) bar_data_holder_inst = BarDataHolder( name='my_bar_holder', bar_datas=[ext_bar_data_inst], ) expected_inner = DatasetBuilder( name='my_bar', data=list(range(10)), attributes={ 'attr1': 'a string', 'attr2': 10, 'data_type': 'BarData', 'ext_attr': False, 'namespace': CORE_NAMESPACE, 'object_id': ext_bar_data_inst.object_id, }, ) expected = GroupBuilder( name='my_bar_holder', datasets={'my_bar': expected_inner}, attributes={ 'data_type': 'BarDataHolder', 'namespace': CORE_NAMESPACE, 'object_id': bar_data_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field builder = self.manager.build(bar_data_holder_inst, source='test.h5') self.assertDictEqual(builder, expected) class TestBuildDatasetRefinedDtype(BuildDatasetExtAttrsMixin, TestCase): """ If the spec defines a dataset data_type A (BarData) using 'data_type_def' and defines another data_type B (BarHolder) that includes A using 'data_type_inc', then the included A spec is an extended (or refined) spec of A - call it A'. The spec of A' can refine or add attributes, refine the dtype, refine the shape, or set a fixed value to the spec of A. This test ensures that if A' refines the dtype of A, the build process uses the correct dtype for conversion. 
""" def get_refined_bar_data_spec(self): refined_spec = DatasetSpec( doc='A BarData with refined int64 dtype', data_type_inc='BarData', dtype='int64', quantity='*', ) return refined_spec def test_build_refined_dtype_convert(self): """ Test build of BarDataHolder which contains a BarData with data that needs to be converted to the refined dtype. """ ext_bar_data_inst = BarData( name='my_bar', data=np.array([1, 2], dtype=np.int32), # the refined spec says data should be int64s attr1='a string', attr2=10, ) bar_data_holder_inst = BarDataHolder( name='my_bar_holder', bar_datas=[ext_bar_data_inst], ) expected_inner = DatasetBuilder( name='my_bar', data=np.array([1, 2], dtype=np.int64), # the objectmapper should convert the given data to int64s attributes={ 'attr1': 'a string', 'attr2': 10, 'data_type': 'BarData', 'namespace': CORE_NAMESPACE, 'object_id': ext_bar_data_inst.object_id, }, ) expected = GroupBuilder( name='my_bar_holder', datasets={'my_bar': expected_inner}, attributes={ 'data_type': 'BarDataHolder', 'namespace': CORE_NAMESPACE, 'object_id': bar_data_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field msg = ("Spec 'BarDataHolder/BarData': Value with data type int32 is being converted to data type int64 " "as specified.") with self.assertWarnsWith(DtypeConversionWarning, msg): builder = self.manager.build(bar_data_holder_inst, source='test.h5') np.testing.assert_array_equal(builder.datasets['my_bar'].data, expected.datasets['my_bar'].data) self.assertEqual(builder.datasets['my_bar'].data.dtype, np.int64) class TestBuildDatasetNotRefinedDtype(BuildDatasetExtAttrsMixin, TestCase): """ If the spec defines a dataset data_type A (BarData) using 'data_type_def' and defines another data_type B (BarHolder) that includes A using 'data_type_inc', then the included A spec is an extended (or refined) spec of A - call it A'. The spec of A' can refine or add attributes, refine the dtype, refine the shape, or set a fixed value to the spec of A. This test ensures that if A' does not refine the dtype of A, the build process uses the correct dtype for conversion. """ def get_refined_bar_data_spec(self): refined_spec = DatasetSpec( doc='A BarData', data_type_inc='BarData', quantity='*', ) return refined_spec def test_build_correct_dtype(self): """ Test build of BarDataHolder which contains a BarData. 
""" ext_bar_data_inst = BarData( name='my_bar', data=[1, 2], attr1='a string', attr2=10, ) bar_data_holder_inst = BarDataHolder( name='my_bar_holder', bar_datas=[ext_bar_data_inst], ) expected_inner = DatasetBuilder( name='my_bar', data=[1, 2], attributes={ 'attr1': 'a string', 'attr2': 10, 'data_type': 'BarData', 'namespace': CORE_NAMESPACE, 'object_id': ext_bar_data_inst.object_id, }, ) expected = GroupBuilder( name='my_bar_holder', datasets={'my_bar': expected_inner}, attributes={ 'data_type': 'BarDataHolder', 'namespace': CORE_NAMESPACE, 'object_id': bar_data_holder_inst.object_id, }, ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field builder = self.manager.build(bar_data_holder_inst, source='test.h5') self.assertDictEqual(builder, expected) def test_build_incorrect_dtype(self): """ Test build of BarDataHolder which contains a BarData """ ext_bar_data_inst = BarData( name='my_bar', data=['a', 'b'], attr1='a string', attr2=10, ) bar_data_holder_inst = BarDataHolder( name='my_bar_holder', bar_datas=[ext_bar_data_inst], ) # the object mapper automatically maps the spec of extended Bars to the 'BarMapper.bars' field msg = "could not resolve dtype for BarData 'my_bar'" with self.assertRaisesWith(Exception, msg): self.manager.build(bar_data_holder_inst, source='test.h5') ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/build_tests/mapper_tests/test_build_quantity.py0000644000655200065520000013253000000000000026403 0ustar00circlecicirclecifrom hdmf import Container, Data from hdmf.build import (BuildManager, TypeMap, GroupBuilder, DatasetBuilder, LinkBuilder, ObjectMapper, MissingRequiredBuildWarning, IncorrectQuantityBuildWarning) from hdmf.spec import GroupSpec, DatasetSpec, LinkSpec, SpecCatalog, SpecNamespace, NamespaceCatalog from hdmf.spec.spec import ZERO_OR_MANY, ONE_OR_MANY, ZERO_OR_ONE, DEF_QUANTITY from hdmf.testing import TestCase from hdmf.utils import docval, getargs from tests.unit.utils import CORE_NAMESPACE ########################## # test all crosses: # { # untyped, named group with data-type-included groups / data-type-included datasets / links # nested, type definition # included groups / included datasets / links # } # x # group/dataset/link with quantity {'*', '+', 1, 2, '?'} # x # builder with 2, 1, or 0 instances of the type, or 0 instances of the type with some instances of a mismatched type class SimpleFoo(Container): pass class NotSimpleFoo(Container): pass class SimpleQux(Data): pass class NotSimpleQux(Data): pass class SimpleBucket(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this SimpleBucket'}, {'name': 'foos', 'type': list, 'doc': 'the SimpleFoo objects', 'default': list()}, {'name': 'quxs', 'type': list, 'doc': 'the SimpleQux objects', 'default': list()}, {'name': 'links', 'type': list, 'doc': 'another way to store SimpleFoo objects', 'default': list()}) def __init__(self, **kwargs): name, foos, quxs, links = getargs('name', 'foos', 'quxs', 'links', kwargs) super().__init__(name=name) # note: collections of groups are unordered in HDF5, so make these dictionaries for keyed access self.foos = {f.name: f for f in foos} for f in foos: f.parent = self self.quxs = {q.name: q for q in quxs} for q in quxs: q.parent = self self.links = {i.name: i for i in links} for i in links: i.parent = self class BasicBucket(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this BasicBucket'}, {'name': 'untyped_dataset', 
'type': 'scalar_data', 'doc': 'a scalar dataset within this BasicBucket', 'default': None}, {'name': 'untyped_array_dataset', 'type': ('data', 'array_data'), 'doc': 'an array dataset within this BasicBucket', 'default': None},) def __init__(self, **kwargs): name, untyped_dataset, untyped_array_dataset = getargs('name', 'untyped_dataset', 'untyped_array_dataset', kwargs) super().__init__(name=name) self.untyped_dataset = untyped_dataset self.untyped_array_dataset = untyped_array_dataset class BuildQuantityMixin: """Base test class mixin to set up the BuildManager.""" def setUpManager(self, specs): spec_catalog = SpecCatalog() schema_file = 'test.yaml' for s in specs: spec_catalog.register_spec(s, schema_file) namespace = SpecNamespace( doc='a test namespace', name=CORE_NAMESPACE, schema=[{'source': schema_file}], version='0.1.0', catalog=spec_catalog ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) type_map.register_container_type(CORE_NAMESPACE, 'SimpleFoo', SimpleFoo) type_map.register_container_type(CORE_NAMESPACE, 'NotSimpleFoo', NotSimpleFoo) type_map.register_container_type(CORE_NAMESPACE, 'SimpleQux', SimpleQux) type_map.register_container_type(CORE_NAMESPACE, 'NotSimpleQux', NotSimpleQux) type_map.register_container_type(CORE_NAMESPACE, 'SimpleBucket', SimpleBucket) type_map.register_map(SimpleBucket, self.setUpBucketMapper()) self.manager = BuildManager(type_map) def _create_builder(self, container): """Helper function to get a basic builder for a container with no subgroups/datasets/links.""" if isinstance(container, Container): ret = GroupBuilder( name=container.name, attributes={'namespace': container.namespace, 'data_type': container.data_type, 'object_id': container.object_id} ) else: ret = DatasetBuilder( name=container.name, data=container.data, attributes={'namespace': container.namespace, 'data_type': container.data_type, 'object_id': container.object_id} ) return ret class TypeIncUntypedGroupMixin: def create_specs(self, quantity): # Type SimpleBucket contains: # - an untyped group "foo_holder" which contains [quantity] groups of data_type_inc SimpleFoo # - an untyped group "qux_holder" which contains [quantity] datasets of data_type_inc SimpleQux # - an untyped group "link_holder" which contains [quantity] links of target_type SimpleFoo foo_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='SimpleFoo', ) not_foo_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='NotSimpleFoo', ) qux_spec = DatasetSpec( doc='A test group specification with a data type', data_type_def='SimpleQux', ) not_qux_spec = DatasetSpec( doc='A test group specification with a data type', data_type_def='NotSimpleQux', ) foo_inc_spec = GroupSpec( doc='the SimpleFoos in this bucket', data_type_inc='SimpleFoo', quantity=quantity ) foo_holder_spec = GroupSpec( doc='An untyped subgroup for SimpleFoos', name='foo_holder', groups=[foo_inc_spec] ) qux_inc_spec = DatasetSpec( doc='the SimpleQuxs in this bucket', data_type_inc='SimpleQux', quantity=quantity ) qux_holder_spec = GroupSpec( doc='An untyped subgroup for SimpleQuxs', name='qux_holder', datasets=[qux_inc_spec] ) foo_link_spec = LinkSpec( doc='the links in this bucket', target_type='SimpleFoo', quantity=quantity ) link_holder_spec = GroupSpec( doc='An untyped subgroup for links', name='link_holder', links=[foo_link_spec] ) bucket_spec = GroupSpec( doc='A test group specification for a data type 
containing data type', name="test_bucket", data_type_def='SimpleBucket', groups=[foo_holder_spec, qux_holder_spec, link_holder_spec] ) return [foo_spec, not_foo_spec, qux_spec, not_qux_spec, bucket_spec] def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) self.unmap(spec.get_group('foo_holder')) self.map_spec('foos', spec.get_group('foo_holder').get_data_type('SimpleFoo')) self.unmap(spec.get_group('qux_holder')) self.map_spec('quxs', spec.get_group('qux_holder').get_data_type('SimpleQux')) self.unmap(spec.get_group('link_holder')) self.map_spec('links', spec.get_group('link_holder').links[0]) return BucketMapper def get_two_bucket_test(self): foos = [SimpleFoo('my_foo1'), SimpleFoo('my_foo2')] quxs = [SimpleQux('my_qux1', data=[1, 2, 3]), SimpleQux('my_qux2', data=[4, 5, 6])] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, links=foos ) foo1_builder = self._create_builder(bucket.foos['my_foo1']) foo2_builder = self._create_builder(bucket.foos['my_foo2']) qux1_builder = self._create_builder(bucket.quxs['my_qux1']) qux2_builder = self._create_builder(bucket.quxs['my_qux2']) foo_holder_builder = GroupBuilder( name='foo_holder', groups={'my_foo1': foo1_builder, 'my_foo2': foo2_builder} ) qux_holder_builder = GroupBuilder( name='qux_holder', datasets={'my_qux1': qux1_builder, 'my_qux2': qux2_builder} ) foo1_link_builder = LinkBuilder(builder=foo1_builder) foo2_link_builder = LinkBuilder(builder=foo2_builder) link_holder_builder = GroupBuilder( name='link_holder', links={'my_foo1': foo1_link_builder, 'my_foo2': foo2_link_builder} ) bucket_builder = GroupBuilder( name='test_bucket', groups={'foos': foo_holder_builder, 'quxs': qux_holder_builder, 'links': link_holder_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder def get_one_bucket_test(self): foos = [SimpleFoo('my_foo1')] quxs = [SimpleQux('my_qux1', data=[1, 2, 3])] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, links=foos ) foo1_builder = GroupBuilder( name='my_foo1', attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleFoo', 'object_id': bucket.foos['my_foo1'].object_id} ) foo_holder_builder = GroupBuilder( name='foo_holder', groups={'my_foo1': foo1_builder} ) qux1_builder = DatasetBuilder( name='my_qux1', data=[1, 2, 3], attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleQux', 'object_id': bucket.quxs['my_qux1'].object_id} ) qux_holder_builder = GroupBuilder( name='qux_holder', datasets={'my_qux1': qux1_builder} ) foo1_link_builder = LinkBuilder(builder=foo1_builder) link_holder_builder = GroupBuilder( name='link_holder', links={'my_foo1': foo1_link_builder} ) bucket_builder = GroupBuilder( name='test_bucket', groups={'foos': foo_holder_builder, 'quxs': qux_holder_builder, 'links': link_holder_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder def get_zero_bucket_test(self): bucket = SimpleBucket( name='test_bucket' ) foo_holder_builder = GroupBuilder( name='foo_holder', groups={} ) qux_holder_builder = GroupBuilder( name='qux_holder', datasets={} ) link_holder_builder = GroupBuilder( name='link_holder', links={} ) bucket_builder = GroupBuilder( name='test_bucket', groups={'foos': foo_holder_builder, 'quxs': qux_holder_builder, 'links': link_holder_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': 
bucket.object_id} ) return bucket, bucket_builder def get_mismatch_bucket_test(self): foos = [NotSimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] quxs = [NotSimpleQux('my_qux1', data=[1, 2, 3]), NotSimpleQux('my_qux2', data=[4, 5, 6])] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, links=foos ) foo_holder_builder = GroupBuilder( name='foo_holder', groups={} ) qux_holder_builder = GroupBuilder( name='qux_holder', datasets={} ) link_holder_builder = GroupBuilder( name='link_holder', links={} ) bucket_builder = GroupBuilder( name='test_bucket', groups={'foos': foo_holder_builder, 'quxs': qux_holder_builder, 'links': link_holder_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder class TypeDefMixin: def create_specs(self, quantity): # Type SimpleBucket contains: # - contains [quantity] groups of data_type_def SimpleFoo # - contains [quantity] datasets of data_type_def SimpleQux # NOTE: links do not have data_type_def, so leave them out of these tests # NOTE: nested type definitions are strongly discouraged now foo_spec = GroupSpec( doc='the SimpleFoos in this bucket', data_type_def='SimpleFoo', quantity=quantity ) qux_spec = DatasetSpec( doc='the SimpleQuxs in this bucket', data_type_def='SimpleQux', quantity=quantity ) not_foo_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='NotSimpleFoo', ) not_qux_spec = DatasetSpec( doc='A test group specification with a data type', data_type_def='NotSimpleQux', ) bucket_spec = GroupSpec( doc='A test group specification for a data type containing data type', name="test_bucket", data_type_def='SimpleBucket', groups=[foo_spec], datasets=[qux_spec] ) return [foo_spec, not_foo_spec, qux_spec, not_qux_spec, bucket_spec] def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) self.map_spec('foos', spec.get_data_type('SimpleFoo')) self.map_spec('quxs', spec.get_data_type('SimpleQux')) return BucketMapper def get_two_bucket_test(self): foos = [SimpleFoo('my_foo1'), SimpleFoo('my_foo2')] quxs = [SimpleQux('my_qux1', data=[1, 2, 3]), SimpleQux('my_qux2', data=[4, 5, 6])] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, ) foo1_builder = self._create_builder(bucket.foos['my_foo1']) foo2_builder = self._create_builder(bucket.foos['my_foo2']) qux1_builder = self._create_builder(bucket.quxs['my_qux1']) qux2_builder = self._create_builder(bucket.quxs['my_qux2']) bucket_builder = GroupBuilder( name='test_bucket', groups={'my_foo1': foo1_builder, 'my_foo2': foo2_builder}, datasets={'my_qux1': qux1_builder, 'my_qux2': qux2_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder def get_one_bucket_test(self): foos = [SimpleFoo('my_foo1')] quxs = [SimpleQux('my_qux1', data=[1, 2, 3])] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, ) foo1_builder = self._create_builder(bucket.foos['my_foo1']) qux1_builder = self._create_builder(bucket.quxs['my_qux1']) bucket_builder = GroupBuilder( name='test_bucket', groups={'my_foo1': foo1_builder}, datasets={'my_qux1': qux1_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder def get_zero_bucket_test(self): bucket = SimpleBucket( name='test_bucket' ) bucket_builder = GroupBuilder( name='test_bucket', attributes={'namespace': CORE_NAMESPACE, 
'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder def get_mismatch_bucket_test(self): foos = [NotSimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] quxs = [NotSimpleQux('my_qux1', data=[1, 2, 3]), NotSimpleQux('my_qux2', data=[4, 5, 6])] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, ) bucket_builder = GroupBuilder( name='test_bucket', attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder class TypeIncMixin: def create_specs(self, quantity): # Type SimpleBucket contains: # - [quantity] groups of data_type_inc SimpleFoo # - [quantity] datasets of data_type_inc SimpleQux # - [quantity] links of target_type SimpleFoo foo_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='SimpleFoo', ) not_foo_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='NotSimpleFoo', ) qux_spec = DatasetSpec( doc='A test group specification with a data type', data_type_def='SimpleQux', ) not_qux_spec = DatasetSpec( doc='A test group specification with a data type', data_type_def='NotSimpleQux', ) foo_inc_spec = GroupSpec( doc='the SimpleFoos in this bucket', data_type_inc='SimpleFoo', quantity=quantity ) qux_inc_spec = DatasetSpec( doc='the SimpleQuxs in this bucket', data_type_inc='SimpleQux', quantity=quantity ) foo_link_spec = LinkSpec( doc='the links in this bucket', target_type='SimpleFoo', quantity=quantity ) bucket_spec = GroupSpec( doc='A test group specification for a data type containing data type', name="test_bucket", data_type_def='SimpleBucket', groups=[foo_inc_spec], datasets=[qux_inc_spec], links=[foo_link_spec] ) return [foo_spec, not_foo_spec, qux_spec, not_qux_spec, bucket_spec] def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) self.map_spec('foos', spec.get_data_type('SimpleFoo')) self.map_spec('quxs', spec.get_data_type('SimpleQux')) self.map_spec('links', spec.links[0]) return BucketMapper def get_two_bucket_test(self): foos = [SimpleFoo('my_foo1'), SimpleFoo('my_foo2')] quxs = [SimpleQux('my_qux1', data=[1, 2, 3]), SimpleQux('my_qux2', data=[4, 5, 6])] # NOTE: unlike in the other tests, links cannot map to the same foos in bucket because of a name clash links = [SimpleFoo('my_foo3'), SimpleFoo('my_foo4')] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, links=links ) foo1_builder = self._create_builder(bucket.foos['my_foo1']) foo2_builder = self._create_builder(bucket.foos['my_foo2']) foo3_builder = self._create_builder(bucket.links['my_foo3']) foo4_builder = self._create_builder(bucket.links['my_foo4']) qux1_builder = self._create_builder(bucket.quxs['my_qux1']) qux2_builder = self._create_builder(bucket.quxs['my_qux2']) foo3_link_builder = LinkBuilder(builder=foo3_builder) foo4_link_builder = LinkBuilder(builder=foo4_builder) bucket_builder = GroupBuilder( name='test_bucket', groups={'my_foo1': foo1_builder, 'my_foo2': foo2_builder}, datasets={'my_qux1': qux1_builder, 'my_qux2': qux2_builder}, links={'my_foo3': foo3_link_builder, 'my_foo4': foo4_link_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder def get_one_bucket_test(self): foos = [SimpleFoo('my_foo1')] quxs = [SimpleQux('my_qux1', data=[1, 2, 3])] # NOTE: unlike in the other tests, links cannot map to the same foos in bucket because of a name clash links = 
[SimpleFoo('my_foo3')] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, links=links ) foo1_builder = self._create_builder(bucket.foos['my_foo1']) foo3_builder = self._create_builder(bucket.links['my_foo3']) qux1_builder = self._create_builder(bucket.quxs['my_qux1']) foo3_link_builder = LinkBuilder(builder=foo3_builder) bucket_builder = GroupBuilder( name='test_bucket', groups={'my_foo1': foo1_builder}, datasets={'my_qux1': qux1_builder}, links={'my_foo1': foo3_link_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder def get_zero_bucket_test(self): bucket = SimpleBucket( name='test_bucket' ) bucket_builder = GroupBuilder( name='test_bucket', attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder def get_mismatch_bucket_test(self): foos = [NotSimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] quxs = [NotSimpleQux('my_qux1', data=[1, 2, 3]), NotSimpleQux('my_qux2', data=[4, 5, 6])] links = [NotSimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, links=links ) bucket_builder = GroupBuilder( name='test_bucket', attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder class ZeroOrManyMixin: def setUp(self): specs = self.create_specs(ZERO_OR_MANY) self.setUpManager(specs) def test_build_two(self): """Test building a container which contains multiple containers as the spec allows.""" bucket, bucket_builder = self.get_two_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) def test_build_one(self): """Test building a container which contains one container as the spec allows.""" bucket, bucket_builder = self.get_one_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) def test_build_zero(self): """Test building a container which contains no containers as the spec allows.""" bucket, bucket_builder = self.get_zero_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) def test_build_mismatch(self): """Test building a container which contains no containers that match the spec as the spec allows.""" bucket, bucket_builder = self.get_mismatch_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) class OneOrManyMixin: def setUp(self): specs = self.create_specs(ONE_OR_MANY) self.setUpManager(specs) def test_build_two(self): """Test building a container which contains multiple containers as the spec allows.""" bucket, bucket_builder = self.get_two_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) def test_build_one(self): """Test building a container which contains one container as the spec allows.""" bucket, bucket_builder = self.get_one_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) def test_build_zero(self): """Test building a container which contains no containers as the spec allows.""" bucket, bucket_builder = self.get_zero_bucket_test() msg = r"SimpleBucket 'test_bucket' is missing required value for attribute '.*'\." 
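# NOTE (illustrative, not part of the original tests): in these quantity mixins the spec constants
# map to ZERO_OR_MANY='*', ONE_OR_MANY='+', ZERO_OR_ONE='?', and DEF_QUANTITY=1. Building fewer
# instances than a required quantity is reported via MissingRequiredBuildWarning, while building a
# count outside an allowed finite quantity is reported via IncorrectQuantityBuildWarning. A minimal
# sketch of catching these warnings without the unittest helpers, assuming the same manager and
# bucket objects used in these mixins:
#
#     import warnings
#     with warnings.catch_warnings(record=True) as caught:
#         warnings.simplefilter("always")
#         builder = manager.build(bucket)
#     assert any(issubclass(w.category, MissingRequiredBuildWarning) for w in caught)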
with self.assertWarnsRegex(MissingRequiredBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_mismatch(self): """Test building a container whose contents do not match the spec; the '+' spec requires at least one matching container, so a warning is expected.""" bucket, bucket_builder = self.get_mismatch_bucket_test() msg = r"SimpleBucket 'test_bucket' is missing required value for attribute '.*'\." with self.assertWarnsRegex(MissingRequiredBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
class OneMixin: def setUp(self): specs = self.create_specs(DEF_QUANTITY) self.setUpManager(specs)
def test_build_two(self): """Test building a container which contains two containers when the spec allows exactly one, so a warning is expected.""" bucket, bucket_builder = self.get_two_bucket_test() msg = r"SimpleBucket 'test_bucket' has 2 values for attribute '.*' but spec allows 1\." with self.assertWarnsRegex(IncorrectQuantityBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_one(self): """Test building a container which contains one container as the spec allows.""" bucket, bucket_builder = self.get_one_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_zero(self): """Test building a container which contains no containers when the spec requires exactly one, so a warning is expected.""" bucket, bucket_builder = self.get_zero_bucket_test() msg = r"SimpleBucket 'test_bucket' is missing required value for attribute '.*'\." with self.assertWarnsRegex(MissingRequiredBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_mismatch(self): """Test building a container whose contents do not match the spec; the spec requires exactly one matching container, so a warning is expected.""" bucket, bucket_builder = self.get_mismatch_bucket_test() msg = r"SimpleBucket 'test_bucket' is missing required value for attribute '.*'\." with self.assertWarnsRegex(MissingRequiredBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
class TwoMixin: def setUp(self): specs = self.create_specs(2) self.setUpManager(specs)
def test_build_two(self): """Test building a container which contains two containers as the spec requires.""" bucket, bucket_builder = self.get_two_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_one(self): """Test building a container which contains one container when the spec requires two, so a warning is expected.""" bucket, bucket_builder = self.get_one_bucket_test() msg = r"SimpleBucket 'test_bucket' has 1 values for attribute '.*' but spec allows 2\." with self.assertWarnsRegex(IncorrectQuantityBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_zero(self): """Test building a container which contains no containers when the spec requires two, so a warning is expected.""" bucket, bucket_builder = self.get_zero_bucket_test() msg = r"SimpleBucket 'test_bucket' is missing required value for attribute '.*'\." with self.assertWarnsRegex(MissingRequiredBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_mismatch(self): """Test building a container whose contents do not match the spec; the spec requires two matching containers, so a warning is expected.""" bucket, bucket_builder = self.get_mismatch_bucket_test() msg = r"SimpleBucket 'test_bucket' is missing required value for attribute '.*'\."
with self.assertWarnsRegex(MissingRequiredBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
class ZeroOrOneMixin: def setUp(self): specs = self.create_specs(ZERO_OR_ONE) self.setUpManager(specs)
def test_build_two(self): """Test building a container which contains two containers when the spec allows at most one, so a warning is expected.""" bucket, bucket_builder = self.get_two_bucket_test() msg = r"SimpleBucket 'test_bucket' has 2 values for attribute '.*' but spec allows '\?'\." with self.assertWarnsRegex(IncorrectQuantityBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_one(self): """Test building a container which contains one container as the spec allows.""" bucket, bucket_builder = self.get_one_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_zero(self): """Test building a container which contains no containers as the spec allows.""" bucket, bucket_builder = self.get_zero_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
def test_build_mismatch(self): """Test building a container whose contents do not match the spec; the spec allows zero matching containers, so no warning is expected.""" bucket, bucket_builder = self.get_mismatch_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder)
# Untyped group with included groups / included datasets / links with quantity {'*', '+', 1, 2, '?'}
class TestBuildZeroOrManyTypeIncUntypedGroup(ZeroOrManyMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity '*' """ pass
class TestBuildOneOrManyTypeIncUntypedGroup(OneOrManyMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity '+' """ pass
class TestBuildOneTypeIncUntypedGroup(OneMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity 1 """ pass
class TestBuildTwoTypeIncUntypedGroup(TwoMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity 2 """ pass
class TestBuildZeroOrOneTypeIncUntypedGroup(ZeroOrOneMixin, TypeIncUntypedGroupMixin, BuildQuantityMixin, TestCase): """Test building a group that has an untyped subgroup with a data type inc subgroup/dataset/link with quantity '?'
""" pass # Nested type definition of group/dataset with quantity {'*', '+', 1, 2, '?'} class TestBuildZeroOrManyTypeDef(ZeroOrManyMixin, TypeDefMixin, BuildQuantityMixin, TestCase): """Test building a group that has a nested type def with quantity '*' """ pass class TestBuildOneOrManyTypeDef(OneOrManyMixin, TypeDefMixin, BuildQuantityMixin, TestCase): """Test building a group that has a nested type def with quantity '+' """ pass class TestBuildOneTypeDef(OneMixin, TypeDefMixin, BuildQuantityMixin, TestCase): """Test building a group that has a nested type def with quantity 1 """ pass class TestBuildTwoTypeDef(TwoMixin, TypeDefMixin, BuildQuantityMixin, TestCase): """Test building a group that has a nested type def with quantity 2 """ pass class TestBuildZeroOrOneTypeDef(ZeroOrOneMixin, TypeDefMixin, BuildQuantityMixin, TestCase): """Test building a group that has a nested type def with quantity '?' """ pass # Included groups / included datasets / links with quantity {'*', '+', 1, 2, '?'} class TestBuildZeroOrManyTypeInc(ZeroOrManyMixin, TypeIncMixin, BuildQuantityMixin, TestCase): """Test building a group that has a data type inc subgroup/dataset/link with quantity '*' """ pass class TestBuildOneOrManyTypeInc(OneOrManyMixin, TypeIncMixin, BuildQuantityMixin, TestCase): """Test building a group that has a data type inc subgroup/dataset/link with quantity '+' """ pass class TestBuildOneTypeInc(OneMixin, TypeIncMixin, BuildQuantityMixin, TestCase): """Test building a group that has a data type inc subgroup/dataset/link with quantity 1 """ pass class TestBuildTwoTypeInc(TwoMixin, TypeIncMixin, BuildQuantityMixin, TestCase): """Test building a group that has a data type inc subgroup/dataset/link with quantity 2 """ pass class TestBuildZeroOrOneTypeInc(ZeroOrOneMixin, TypeIncMixin, BuildQuantityMixin, TestCase): """Test building a group that has a data type inc subgroup/dataset/link with quantity '?' """ pass # Untyped group/dataset with quantity {1, '?'} class UntypedMixin: def setUpManager(self, specs): spec_catalog = SpecCatalog() schema_file = 'test.yaml' for s in specs: spec_catalog.register_spec(s, schema_file) namespace = SpecNamespace( doc='a test namespace', name=CORE_NAMESPACE, schema=[{'source': schema_file}], version='0.1.0', catalog=spec_catalog ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) type_map.register_container_type(CORE_NAMESPACE, 'BasicBucket', BasicBucket) self.manager = BuildManager(type_map) def create_specs(self, quantity): # Type BasicBucket contains: # - [quantity] untyped group # - [quantity] untyped dataset # - [quantity] untyped array dataset # quantity can be only '?' 
or 1 untyped_group_spec = GroupSpec( doc='A test group specification with no data type', name='untyped_group', quantity=quantity, ) untyped_dataset_spec = DatasetSpec( doc='A test dataset specification with no data type', name='untyped_dataset', dtype='int', quantity=quantity, ) untyped_array_dataset_spec = DatasetSpec( doc='A test dataset specification with no data type', name='untyped_array_dataset', dtype='int', dims=[None], shape=[None], quantity=quantity, ) basic_bucket_spec = GroupSpec( doc='A test group specification for a data type containing data type', name="test_bucket", data_type_def='BasicBucket', groups=[untyped_group_spec], datasets=[untyped_dataset_spec, untyped_array_dataset_spec], ) return [basic_bucket_spec] class TestBuildOneUntyped(UntypedMixin, TestCase): """Test building a group that has an untyped subgroup/dataset with quantity 1. """ def setUp(self): specs = self.create_specs(DEF_QUANTITY) self.setUpManager(specs) def test_build_data(self): """Test building a container which contains an untyped empty subgroup and an untyped non-empty dataset.""" bucket = BasicBucket(name='test_bucket', untyped_dataset=3, untyped_array_dataset=[3]) # a required untyped empty group builder will be created by default untyped_group_builder = GroupBuilder(name='untyped_group') untyped_dataset_builder = DatasetBuilder(name='untyped_dataset', data=3) untyped_array_dataset_builder = DatasetBuilder(name='untyped_array_dataset', data=[3]) bucket_builder = GroupBuilder( name='test_bucket', groups={'untyped_group': untyped_group_builder}, datasets={'untyped_dataset': untyped_dataset_builder, 'untyped_array_dataset': untyped_array_dataset_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'BasicBucket', 'object_id': bucket.object_id} ) builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) def test_build_empty_data(self): """Test building a container which contains an untyped empty subgroup and an untyped empty dataset.""" bucket = BasicBucket(name='test_bucket') # a required untyped empty group builder will be created by default untyped_group_builder = GroupBuilder(name='untyped_group') # a required untyped empty dataset builder will NOT be created by default bucket_builder = GroupBuilder( name='test_bucket', groups={'untyped_group': untyped_group_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'BasicBucket', 'object_id': bucket.object_id} ) msg = "BasicBucket 'test_bucket' is missing required value for attribute 'untyped_dataset'." # also raises "BasicBucket 'test_bucket' is missing required value for attribute 'untyped_array_dataset'." with self.assertWarnsWith(MissingRequiredBuildWarning, msg): builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) class TestBuildZeroOrOneUntyped(UntypedMixin, TestCase): """Test building a group that has an untyped subgroup/dataset with quantity '?'. 
""" def setUp(self): specs = self.create_specs(ZERO_OR_ONE) self.setUpManager(specs) def test_build_data(self): """Test building a container which contains an untyped empty subgroup and an untyped non-empty dataset.""" bucket = BasicBucket(name='test_bucket', untyped_dataset=3, untyped_array_dataset=[3]) # an optional untyped empty group builder will NOT be created by default untyped_dataset_builder = DatasetBuilder(name='untyped_dataset', data=3) untyped_array_dataset_builder = DatasetBuilder(name='untyped_array_dataset', data=[3]) bucket_builder = GroupBuilder( name='test_bucket', datasets={'untyped_dataset': untyped_dataset_builder, 'untyped_array_dataset': untyped_array_dataset_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'BasicBucket', 'object_id': bucket.object_id} ) builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) def test_build_empty_data(self): """Test building a container which contains an untyped empty subgroup and an untyped empty dataset.""" bucket = BasicBucket(name='test_bucket') # an optional untyped empty group builder will NOT be created by default # an optional untyped empty dataset builder will NOT be created by default bucket_builder = GroupBuilder( name='test_bucket', attributes={'namespace': CORE_NAMESPACE, 'data_type': 'BasicBucket', 'object_id': bucket.object_id} ) builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) # Multiple allowed types class TestBuildMultiTypeInc(BuildQuantityMixin, TestCase): """Test build process when a groupspec allows multiple groups/datasets/links with different data types / targets. """ def setUp(self): specs = self.create_specs(ZERO_OR_MANY) self.setUpManager(specs) def create_specs(self, quantity): # Type SimpleBucket contains: # - [quantity] groups of data_type_inc SimpleFoo and [quantity] group of data_type_inc NotSimpleFoo # - [quantity] datasets of data_type_inc SimpleQux and [quantity] datasets of data_type_inc NotSimpleQux # - [quantity] links of target_type SimpleFoo and [quantity] links of target_type NotSimpleFoo foo_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='SimpleFoo', ) not_foo_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='NotSimpleFoo', ) qux_spec = DatasetSpec( doc='A test group specification with a data type', data_type_def='SimpleQux', ) not_qux_spec = DatasetSpec( doc='A test group specification with a data type', data_type_def='NotSimpleQux', ) foo_inc_spec = GroupSpec( doc='the SimpleFoos in this bucket', data_type_inc='SimpleFoo', quantity=quantity ) not_foo_inc_spec = GroupSpec( doc='the SimpleFoos in this bucket', data_type_inc='NotSimpleFoo', quantity=quantity ) qux_inc_spec = DatasetSpec( doc='the SimpleQuxs in this bucket', data_type_inc='SimpleQux', quantity=quantity ) not_qux_inc_spec = DatasetSpec( doc='the SimpleQuxs in this bucket', data_type_inc='NotSimpleQux', quantity=quantity ) foo_link_spec = LinkSpec( doc='the links in this bucket', target_type='SimpleFoo', quantity=quantity ) not_foo_link_spec = LinkSpec( doc='the links in this bucket', target_type='NotSimpleFoo', quantity=quantity ) bucket_spec = GroupSpec( doc='A test group specification for a data type containing data type', name="test_bucket", data_type_def='SimpleBucket', groups=[foo_inc_spec, not_foo_inc_spec], datasets=[qux_inc_spec, not_qux_inc_spec], links=[foo_link_spec, not_foo_link_spec] ) return [foo_spec, not_foo_spec, qux_spec, not_qux_spec, bucket_spec] def 
setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) self.map_spec('foos', spec.get_data_type('SimpleFoo')) self.map_spec('foos', spec.get_data_type('NotSimpleFoo')) self.map_spec('quxs', spec.get_data_type('SimpleQux')) self.map_spec('quxs', spec.get_data_type('NotSimpleQux')) self.map_spec('links', spec.links[0]) self.map_spec('links', spec.links[1]) return BucketMapper def get_two_bucket_test(self): foos = [SimpleFoo('my_foo1'), NotSimpleFoo('my_foo2')] quxs = [SimpleQux('my_qux1', data=[1, 2, 3]), NotSimpleQux('my_qux2', data=[4, 5, 6])] # NOTE: unlike in the other tests, links cannot map to the same foos in bucket because of a name clash links = [SimpleFoo('my_foo3'), NotSimpleFoo('my_foo4')] bucket = SimpleBucket( name='test_bucket', foos=foos, quxs=quxs, links=links ) foo1_builder = self._create_builder(bucket.foos['my_foo1']) foo2_builder = self._create_builder(bucket.foos['my_foo2']) foo3_builder = self._create_builder(bucket.links['my_foo3']) foo4_builder = self._create_builder(bucket.links['my_foo4']) qux1_builder = self._create_builder(bucket.quxs['my_qux1']) qux2_builder = self._create_builder(bucket.quxs['my_qux2']) foo3_link_builder = LinkBuilder(builder=foo3_builder) foo4_link_builder = LinkBuilder(builder=foo4_builder) bucket_builder = GroupBuilder( name='test_bucket', groups={'my_foo1': foo1_builder, 'my_foo2': foo2_builder}, datasets={'my_qux1': qux1_builder, 'my_qux2': qux2_builder}, links={'my_foo3': foo3_link_builder, 'my_foo4': foo4_link_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'SimpleBucket', 'object_id': bucket.object_id} ) return bucket, bucket_builder def test_build_two(self): """Test building a container which contains multiple containers of different types as the spec allows.""" bucket, bucket_builder = self.get_two_bucket_test() builder = self.manager.build(bucket) self.assertDictEqual(builder, bucket_builder) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/build_tests/test_builder.py0000644000655200065520000003351200000000000022266 0ustar00circlecicirclecifrom hdmf.build import GroupBuilder, DatasetBuilder, LinkBuilder, ReferenceBuilder, RegionBuilder from hdmf.testing import TestCase class TestGroupBuilder(TestCase): def test_constructor(self): gb1 = GroupBuilder('gb1', source='source') gb2 = GroupBuilder('gb2', parent=gb1) self.assertIs(gb1.name, 'gb1') self.assertIsNone(gb1.parent) self.assertEqual(gb1.source, 'source') self.assertIs(gb2.parent, gb1) def test_repr(self): gb1 = GroupBuilder('gb1') expected = "gb1 GroupBuilder {'attributes': {}, 'groups': {}, 'datasets': {}, 'links': {}}" self.assertEqual(gb1.__repr__(), expected) def test_set_source(self): """Test that setting source sets the children builder source.""" gb1 = GroupBuilder('gb1') db = DatasetBuilder('db', list(range(10))) lb = LinkBuilder(gb1, 'lb') gb2 = GroupBuilder('gb1', {'gb1': gb1}, {'db': db}, {}, {'lb': lb}) gb2.source = 'source' self.assertEqual(gb2.source, 'source') self.assertEqual(gb1.source, 'source') self.assertEqual(db.source, 'source') self.assertEqual(lb.source, 'source') def test_set_source_no_reset(self): """Test that setting source does not set the children builder source if children already have a source.""" gb1 = GroupBuilder('gb1', source='original') db = DatasetBuilder('db', list(range(10)), source='original') lb = LinkBuilder(gb1, 'lb', source='original') gb2 = GroupBuilder('gb1', {'gb1': gb1}, {'db': db}, {}, {'lb': lb}) 
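# NOTE (illustrative, not part of the original tests): the positional arguments to GroupBuilder
# after the name are, in order, groups, datasets, attributes, and links, so the call above is a
# sketch-equivalent of the keyword form:
#
#     gb2 = GroupBuilder(name='gb1', groups={'gb1': gb1}, datasets={'db': db},
#                        attributes={}, links={'lb': lb})
#
# Assigning `source` on the parent below only propagates to children whose source is still unset,
# which is what this test verifies.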
gb2.source = 'source' self.assertEqual(gb1.source, 'original') self.assertEqual(db.source, 'original') self.assertEqual(lb.source, 'original') def test_constructor_dset_none(self): gb1 = GroupBuilder('gb1', datasets={'empty': None}) self.assertEqual(len(gb1.datasets), 0) def test_set_location(self): gb1 = GroupBuilder('gb1') gb1.location = 'location' self.assertEqual(gb1.location, 'location') def test_overwrite_location(self): gb1 = GroupBuilder('gb1') gb1.location = 'location' gb1.location = 'new location' self.assertEqual(gb1.location, 'new location') class TestGroupBuilderSetters(TestCase): def test_set_attribute(self): gb = GroupBuilder('gb') gb.set_attribute('key', 'value') self.assertIn('key', gb.obj_type) self.assertIn('key', gb.attributes) self.assertEqual(gb['key'], 'value') def test_set_group(self): gb1 = GroupBuilder('gb1') gb2 = GroupBuilder('gb2') gb1.set_group(gb2) self.assertIs(gb2.parent, gb1) self.assertIn('gb2', gb1.obj_type) self.assertIn('gb2', gb1.groups) self.assertIs(gb1['gb2'], gb2) def test_set_dataset(self): gb = GroupBuilder('gb') db = DatasetBuilder('db', list(range(10))) gb.set_dataset(db) self.assertIs(db.parent, gb) self.assertIn('db', gb.obj_type) self.assertIn('db', gb.datasets) self.assertIs(gb['db'], db) def test_set_link(self): gb1 = GroupBuilder('gb1') gb2 = GroupBuilder('gb2') lb = LinkBuilder(gb2) gb1.set_link(lb) self.assertIs(lb.parent, gb1) self.assertIn('gb2', gb1.obj_type) self.assertIn('gb2', gb1.links) self.assertIs(gb1['gb2'], lb) def test_setitem_disabled(self): """Test __setitem__ is disabled""" gb = GroupBuilder('gb') with self.assertRaises(NotImplementedError): gb['key'] = 'value' def test_set_exists_wrong_type(self): gb1 = GroupBuilder('gb1') gb2 = GroupBuilder('gb2') db = DatasetBuilder('gb2') gb1.set_group(gb2) msg = "'gb2' already exists in gb1.groups, cannot set in datasets." 
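# NOTE (illustrative, not part of the original tests): a GroupBuilder keeps a single shared key
# space across its groups, datasets, attributes, and links (tracked via `obj_type`), so a dataset
# cannot reuse the name of an existing subgroup. Minimal sketch of the failure mode:
#
#     gb = GroupBuilder('gb')
#     gb.set_group(GroupBuilder('child'))
#     gb.set_dataset(DatasetBuilder('child'))  # raises ValueError: 'child' is already used as a group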
with self.assertRaisesWith(ValueError, msg): gb1.set_dataset(db) class TestGroupBuilderGetters(TestCase): def setUp(self): self.subgroup1 = GroupBuilder('subgroup1') self.dataset1 = DatasetBuilder('dataset1', list(range(10))) self.link1 = LinkBuilder(self.subgroup1, 'link1') self.int_attr = 1 self.str_attr = "my_str" self.group1 = GroupBuilder('group1', {'subgroup1': self.subgroup1}) self.gb = GroupBuilder( name='gb', groups={'group1': self.group1}, datasets={'dataset1': self.dataset1}, attributes={'int_attr': self.int_attr, 'str_attr': self.str_attr}, links={'link1': self.link1} ) def test_path(self): self.assertEqual(self.subgroup1.path, 'gb/group1/subgroup1') self.assertEqual(self.dataset1.path, 'gb/dataset1') self.assertEqual(self.link1.path, 'gb/link1') self.assertEqual(self.group1.path, 'gb/group1') self.assertEqual(self.gb.path, 'gb') def test_getitem_group(self): """Test __getitem__ for groups""" self.assertIs(self.gb['group1'], self.group1) def test_getitem_group_deeper(self): """Test __getitem__ for groups deeper in hierarchy""" self.assertIs(self.gb['group1/subgroup1'], self.subgroup1) def test_getitem_dataset(self): """Test __getitem__ for datasets""" self.assertIs(self.gb['dataset1'], self.dataset1) def test_getitem_attr(self): """Test __getitem__ for attributes""" self.assertEqual(self.gb['int_attr'], self.int_attr) self.assertEqual(self.gb['str_attr'], self.str_attr) def test_getitem_invalid_key(self): """Test __getitem__ for invalid key""" with self.assertRaises(KeyError): self.gb['invalid_key'] def test_getitem_invalid_key_deeper(self): """Test __getitem__ for invalid key""" with self.assertRaises(KeyError): self.gb['group/invalid_key'] def test_getitem_link(self): """Test __getitem__ for links""" self.assertIs(self.gb['link1'], self.link1) def test_get_group(self): """Test get() for groups""" self.assertIs(self.gb.get('group1'), self.group1) def test_get_group_deeper(self): """Test get() for groups deeper in hierarchy""" self.assertIs(self.gb.get('group1/subgroup1'), self.subgroup1) def test_get_dataset(self): """Test get() for datasets""" self.assertIs(self.gb.get('dataset1'), self.dataset1) def test_get_attr(self): """Test get() for attributes""" self.assertEqual(self.gb.get('int_attr'), self.int_attr) self.assertEqual(self.gb.get('str_attr'), self.str_attr) def test_get_link(self): """Test get() for links""" self.assertIs(self.gb.get('link1'), self.link1) def test_get_invalid_key(self): """Test get() for invalid key""" self.assertIs(self.gb.get('invalid_key'), None) def test_items(self): """Test items()""" items = ( ('group1', self.group1), ('dataset1', self.dataset1), ('int_attr', self.int_attr), ('str_attr', self.str_attr), ('link1', self.link1), ) # self.assertSetEqual(items, set(self.gb.items())) try: self.assertCountEqual(items, self.gb.items()) except AttributeError: self.assertItemsEqual(items, self.gb.items()) def test_keys(self): """Test keys()""" keys = ( 'group1', 'dataset1', 'int_attr', 'str_attr', 'link1', ) try: self.assertCountEqual(keys, self.gb.keys()) except AttributeError: self.assertItemsEqual(keys, self.gb.keys()) def test_values(self): """Test values()""" values = ( self.group1, self.dataset1, self.int_attr, self.str_attr, self.link1, ) try: self.assertCountEqual(values, self.gb.values()) except AttributeError: self.assertItemsEqual(values, self.gb.values()) class TestGroupBuilderIsEmpty(TestCase): def test_is_empty_true(self): """Test empty when group has nothing in it""" gb = GroupBuilder('gb') self.assertTrue(gb.is_empty()) def 
test_is_empty_true_group_empty(self): """Test is_empty() when group has an empty subgroup""" gb1 = GroupBuilder('my_subgroup') gb2 = GroupBuilder('gb', {'my_subgroup': gb1}) self.assertTrue(gb2.is_empty()) def test_is_empty_false_dataset(self): """Test is_empty() when group has a dataset""" gb = GroupBuilder('gb', datasets={'my_dataset': DatasetBuilder('my_dataset')}) self.assertFalse(gb.is_empty()) def test_is_empty_false_group_dataset(self): """Test is_empty() when group has a subgroup with a dataset""" gb1 = GroupBuilder('my_subgroup', datasets={'my_dataset': DatasetBuilder('my_dataset')}) gb2 = GroupBuilder('gb', {'my_subgroup': gb1}) self.assertFalse(gb2.is_empty()) def test_is_empty_false_attribute(self): """Test is_empty() when group has an attribute""" gb = GroupBuilder('gb', attributes={'my_attr': 'attr_value'}) self.assertFalse(gb.is_empty()) def test_is_empty_false_group_attribute(self): """Test is_empty() when group has subgroup with an attribute""" gb1 = GroupBuilder('my_subgroup', attributes={'my_attr': 'attr_value'}) gb2 = GroupBuilder('gb', {'my_subgroup': gb1}) self.assertFalse(gb2.is_empty()) def test_is_empty_false_link(self): """Test is_empty() when group has a link""" gb1 = GroupBuilder('target') gb2 = GroupBuilder('gb', links={'my_link': LinkBuilder(gb1)}) self.assertFalse(gb2.is_empty()) def test_is_empty_false_group_link(self): """Test is_empty() when group has subgroup with a link""" gb1 = GroupBuilder('target') gb2 = GroupBuilder('my_subgroup', links={'my_link': LinkBuilder(gb1)}) gb3 = GroupBuilder('gb', {'my_subgroup': gb2}) self.assertFalse(gb3.is_empty()) class TestDatasetBuilder(TestCase): def test_constructor(self): gb1 = GroupBuilder('gb1') db1 = DatasetBuilder( name='db1', data=[1, 2, 3], dtype=int, attributes={'attr1': 10}, maxshape=10, chunks=True, parent=gb1, source='source', ) self.assertEqual(db1.name, 'db1') self.assertListEqual(db1.data, [1, 2, 3]) self.assertEqual(db1.dtype, int) self.assertDictEqual(db1.attributes, {'attr1': 10}) self.assertEqual(db1.maxshape, 10) self.assertTrue(db1.chunks) self.assertIs(db1.parent, gb1) self.assertEqual(db1.source, 'source') def test_constructor_data_builder_no_dtype(self): db1 = DatasetBuilder(name='db1', dtype=int) db2 = DatasetBuilder(name='db2', data=db1) self.assertEqual(db2.dtype, DatasetBuilder.OBJECT_REF_TYPE) def test_constructor_data_builder_dtype(self): db1 = DatasetBuilder(name='db1', dtype=int) db2 = DatasetBuilder(name='db2', data=db1, dtype=float) self.assertEqual(db2.dtype, float) def test_set_data(self): db1 = DatasetBuilder(name='db1') db1.data = [4, 5, 6] self.assertEqual(db1.data, [4, 5, 6]) def test_set_dtype(self): db1 = DatasetBuilder(name='db1') db1.dtype = float self.assertEqual(db1.dtype, float) def test_overwrite_data(self): db1 = DatasetBuilder(name='db1', data=[1, 2, 3]) msg = "Cannot overwrite data." with self.assertRaisesWith(AttributeError, msg): db1.data = [4, 5, 6] def test_overwrite_dtype(self): db1 = DatasetBuilder(name='db1', data=[1, 2, 3], dtype=int) msg = "Cannot overwrite dtype." with self.assertRaisesWith(AttributeError, msg): db1.dtype = float def test_overwrite_source(self): db1 = DatasetBuilder(name='db1', data=[1, 2, 3], source='source') msg = 'Cannot overwrite source.' with self.assertRaisesWith(AttributeError, msg): db1.source = 'new source' def test_overwrite_parent(self): gb1 = GroupBuilder('gb1') db1 = DatasetBuilder(name='db1', data=[1, 2, 3], parent=gb1) msg = 'Cannot overwrite parent.' 
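# NOTE (illustrative, not part of the original tests): as the surrounding tests show,
# DatasetBuilder treats data, dtype, source, and parent as write-once: each may be assigned after
# construction only while it is still unset, and reassignment raises AttributeError. Sketch:
#
#     db = DatasetBuilder(name='db')
#     db.dtype = float   # allowed, dtype was not set at construction
#     db.dtype = int     # raises AttributeError("Cannot overwrite dtype.")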
with self.assertRaisesWith(AttributeError, msg): db1.parent = gb1 def test_repr(self): gb1 = GroupBuilder('gb1') db1 = DatasetBuilder( name='db1', data=[1, 2, 3], dtype=int, attributes={'attr2': 10}, maxshape=10, chunks=True, parent=gb1, source='source', ) expected = "gb1/db1 DatasetBuilder {'attributes': {'attr2': 10}, 'data': [1, 2, 3]}" self.assertEqual(db1.__repr__(), expected) class TestLinkBuilder(TestCase): def test_constructor(self): gb = GroupBuilder('gb1') db = DatasetBuilder('db1', [1, 2, 3]) lb = LinkBuilder(db, 'link_name', gb, 'link_source') self.assertIs(lb.builder, db) self.assertEqual(lb.name, 'link_name') self.assertIs(lb.parent, gb) self.assertEqual(lb.source, 'link_source') def test_constructor_no_name(self): db = DatasetBuilder('db1', [1, 2, 3]) lb = LinkBuilder(db) self.assertIs(lb.builder, db) self.assertEqual(lb.name, 'db1') class TestReferenceBuilder(TestCase): def test_constructor(self): db = DatasetBuilder('db1', [1, 2, 3]) rb = ReferenceBuilder(db) self.assertIs(rb.builder, db) class TestRegionBuilder(TestCase): def test_constructor(self): db = DatasetBuilder('db1', [1, 2, 3]) rb = RegionBuilder(slice(1, 3), db) self.assertEqual(rb.region, slice(1, 3)) self.assertIs(rb.builder, db) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/build_tests/test_classgenerator.py0000644000655200065520000012142300000000000023653 0ustar00circlecicircleciimport numpy as np import os import shutil import tempfile from hdmf.build import TypeMap, CustomClassGenerator from hdmf.build.classgenerator import ClassGenerator, MCIClassGenerator from hdmf.container import Container, Data, MultiContainerInterface, AbstractContainer from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, LinkSpec from hdmf.testing import TestCase from hdmf.utils import get_docval from .test_io_map import Bar from tests.unit.utils import CORE_NAMESPACE, create_test_type_map, create_load_namespace_yaml class TestClassGenerator(TestCase): def test_register_generator(self): """Test TypeMap.register_generator and ClassGenerator.register_generator.""" class MyClassGenerator(CustomClassGenerator): @classmethod def apply_generator_to_field(cls, field_spec, bases, type_map): return True @classmethod def process_field_spec(cls, classdict, docval_args, parent_cls, attr_name, not_inherited_fields, type_map, spec): # append attr_name to classdict['__custom_fields__'] list classdict.setdefault('process_field_spec', list()).append(attr_name) @classmethod def post_process(cls, classdict, bases, docval_args, spec): classdict['post_process'] = True spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', attributes=[ AttributeSpec(name='attr1', doc='a string attribute', dtype='text') ] ) spec_catalog = SpecCatalog() spec_catalog.register_spec(spec, 'test.yaml') namespace = SpecNamespace( doc='a test namespace', name=CORE_NAMESPACE, schema=[{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) type_map.register_generator(MyClassGenerator) cls = type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) self.assertEqual(cls.process_field_spec, ['attr1']) self.assertTrue(cls.post_process) def test_bad_generator(self): """Test that register_generator raises an error if the generator is not an instance of CustomClassGenerator.""" class 
NotACustomClassGenerator: pass type_map = TypeMap() msg = 'Generator <.*> must be a subclass of CustomClassGenerator.' with self.assertRaisesRegex(ValueError, msg): type_map.register_generator(NotACustomClassGenerator) def test_no_generators(self): """Test that a ClassGenerator without registered generators does nothing.""" cg = ClassGenerator() spec = GroupSpec(doc='A test group spec with a data type', data_type_def='Baz') cls = cg.generate_class(data_type='Baz', spec=spec, parent_cls=Container, attr_names={}, type_map=TypeMap()) self.assertEqual(cls.__mro__, (cls, Container, AbstractContainer, object)) self.assertTrue(hasattr(cls, '__init__')) class TestDynamicContainer(TestCase): def setUp(self): self.bar_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Bar', datasets=[ DatasetSpec( doc='a dataset', dtype='int', name='data', attributes=[AttributeSpec(name='attr2', doc='an integer attribute', dtype='int')] ) ], attributes=[AttributeSpec(name='attr1', doc='a string attribute', dtype='text')]) specs = [self.bar_spec] containers = {'Bar': Bar} self.type_map = create_test_type_map(specs, containers) self.spec_catalog = self.type_map.namespace_catalog.get_namespace(CORE_NAMESPACE).catalog def test_dynamic_container_creation(self): baz_spec = GroupSpec('A test extension with no Container class', data_type_def='Baz', data_type_inc=self.bar_spec, attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4'} received_args = set() for x in get_docval(cls.__init__): if x['name'] != 'foo': received_args.add(x['name']) with self.subTest(name=x['name']): self.assertNotIn('default', x) self.assertSetEqual(expected_args, received_args) self.assertEqual(cls.__name__, 'Baz') self.assertTrue(issubclass(cls, Bar)) def test_dynamic_container_default_name(self): baz_spec = GroupSpec('doc', default_name='bingo', data_type_def='Baz', attributes=[AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) inst = cls(attr4=10.) self.assertEqual(inst.name, 'bingo') def test_dynamic_container_creation_defaults(self): baz_spec = GroupSpec('A test extension with no Container class', data_type_def='Baz', data_type_inc=self.bar_spec, attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) expected_args = {'name', 'data', 'attr1', 'attr2', 'attr3', 'attr4', 'foo'} received_args = set(map(lambda x: x['name'], get_docval(cls.__init__))) self.assertSetEqual(expected_args, received_args) self.assertEqual(cls.__name__, 'Baz') self.assertTrue(issubclass(cls, Bar)) def test_dynamic_container_constructor(self): baz_spec = GroupSpec('A test extension with no Container class', data_type_def='Baz', data_type_inc=self.bar_spec, attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) # TODO: test that constructor works! 
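# NOTE (illustrative, not part of the original tests): the generated class has a docval-based
# __init__ whose arguments come from the merged Bar and Baz specs (name, data, attr1, attr2,
# attr3, attr4), so the positional call below could equivalently be written with keywords, e.g.
# (sketch): cls(name='My Baz', data=[1, 2, 3, 4], attr1='string attribute', attr2=1000,
# attr3=98.6, attr4=1.0).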
inst = cls('My Baz', [1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0) self.assertEqual(inst.name, 'My Baz') self.assertEqual(inst.data, [1, 2, 3, 4]) self.assertEqual(inst.attr1, 'string attribute') self.assertEqual(inst.attr2, 1000) self.assertEqual(inst.attr3, 98.6) self.assertEqual(inst.attr4, 1.0) def test_dynamic_container_constructor_name(self): # name is specified in spec and cannot be changed baz_spec = GroupSpec('A test extension with no Container class', data_type_def='Baz', data_type_inc=self.bar_spec, name='A fixed name', attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) with self.assertRaises(TypeError): inst = cls('My Baz', [1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0) inst = cls([1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0) self.assertEqual(inst.name, 'A fixed name') self.assertEqual(inst.data, [1, 2, 3, 4]) self.assertEqual(inst.attr1, 'string attribute') self.assertEqual(inst.attr2, 1000) self.assertEqual(inst.attr3, 98.6) self.assertEqual(inst.attr4, 1.0) def test_dynamic_container_constructor_name_default_name(self): # if both name and default_name are specified, name should be used with self.assertWarns(Warning): baz_spec = GroupSpec('A test extension with no Container class', data_type_def='Baz', data_type_inc=self.bar_spec, name='A fixed name', default_name='A default name', attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')]) self.spec_catalog.register_spec(baz_spec, 'extension.yaml') cls = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) inst = cls([1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0) self.assertEqual(inst.name, 'A fixed name') def test_dynamic_container_composition(self): baz_spec2 = GroupSpec('A composition inside', data_type_def='Baz2', data_type_inc=self.bar_spec, attributes=[ AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')]) baz_spec1 = GroupSpec('A composition test outside', data_type_def='Baz1', data_type_inc=self.bar_spec, attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')], groups=[GroupSpec('A composition inside', data_type_inc='Baz2')]) self.spec_catalog.register_spec(baz_spec1, 'extension.yaml') self.spec_catalog.register_spec(baz_spec2, 'extension.yaml') Baz2 = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE) Baz1 = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) Baz1('My Baz', [1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0, baz2=Baz2('My Baz', [1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0)) Bar = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) bar = Bar('My Bar', [1, 2, 3, 4], 'string attribute', 1000) with self.assertRaises(TypeError): Baz1('My Baz', [1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0, baz2=bar) def test_dynamic_container_composition_reverse_order(self): baz_spec2 = GroupSpec('A composition inside', data_type_def='Baz2', data_type_inc=self.bar_spec, attributes=[ AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')]) baz_spec1 = GroupSpec('A composition test outside', data_type_def='Baz1', data_type_inc=self.bar_spec, 
attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')], groups=[GroupSpec('A composition inside', data_type_inc='Baz2')]) self.spec_catalog.register_spec(baz_spec1, 'extension.yaml') self.spec_catalog.register_spec(baz_spec2, 'extension.yaml') Baz1 = self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) Baz2 = self.type_map.get_dt_container_cls('Baz2', CORE_NAMESPACE) Baz1('My Baz', [1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0, baz2=Baz2('My Baz', [1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0)) Bar = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) bar = Bar('My Bar', [1, 2, 3, 4], 'string attribute', 1000) with self.assertRaises(TypeError): Baz1('My Baz', [1, 2, 3, 4], 'string attribute', 1000, attr3=98.6, attr4=1.0, baz2=bar) def test_dynamic_container_composition_missing_type(self): baz_spec1 = GroupSpec('A composition test outside', data_type_def='Baz1', data_type_inc=self.bar_spec, attributes=[AttributeSpec('attr3', 'a float attribute', 'float'), AttributeSpec('attr4', 'another float attribute', 'float')], groups=[GroupSpec('A composition inside', data_type_inc='Baz2')]) self.spec_catalog.register_spec(baz_spec1, 'extension.yaml') msg = "No specification for 'Baz2' in namespace 'test_core'" with self.assertRaisesWith(ValueError, msg): self.type_map.get_dt_container_cls('Baz1', CORE_NAMESPACE) def test_dynamic_container_fixed_name(self): """Test that dynamic class generation for an extended type with a fixed name works.""" baz_spec = GroupSpec('A test extension with no Container class', data_type_def='Baz', data_type_inc=self.bar_spec, name='Baz') self.spec_catalog.register_spec(baz_spec, 'extension.yaml') Baz = self.type_map.get_dt_container_cls('Baz', CORE_NAMESPACE) obj = Baz([1, 2, 3, 4], 'string attribute', attr2=1000) self.assertEqual(obj.name, 'Baz') def test_multi_container_spec(self): multi_spec = GroupSpec( doc='A test extension that contains a multi', data_type_def='Multi', groups=[ GroupSpec(data_type_inc=self.bar_spec, doc='test multi', quantity='*') ], attributes=[ AttributeSpec(name='attr3', doc='a float attribute', dtype='float') ] ) self.spec_catalog.register_spec(multi_spec, 'extension.yaml') Bar = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) Multi = self.type_map.get_dt_container_cls('Multi', CORE_NAMESPACE) assert issubclass(Multi, MultiContainerInterface) assert Multi.__clsconf__ == [ dict( attr='bars', type=Bar, add='add_bars', get='get_bars', create='create_bars' ) ] multi = Multi( name='my_multi', bars=[Bar('my_bar', list(range(10)), 'value1', 10)], attr3=5. ) assert multi.bars['my_bar'] == Bar('my_bar', list(range(10)), 'value1', 10) assert multi.attr3 == 5. 
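# --- Illustrative sketch (not part of the original test suite) -------------------------------
# test_multi_container_spec above shows that a GroupSpec holding a repeatable (quantity='*')
# included type is generated as a MultiContainerInterface subclass with add_/get_/create_
# helpers. The helper below restates that flow end to end using utilities already imported in
# this module (GroupSpec, DatasetSpec, AttributeSpec, Bar, create_test_type_map, CORE_NAMESPACE).
# The generated helper names (add_bars/get_bars) follow the __clsconf__ asserted in the test
# above; treat this as a sketch, not additional test coverage (the name deliberately does not
# start with "test_").
def _sketch_generate_multi_container_class():
    bar_spec = GroupSpec(
        doc='A test group specification with a data type',
        data_type_def='Bar',
        datasets=[DatasetSpec(
            doc='a dataset', dtype='int', name='data',
            attributes=[AttributeSpec(name='attr2', doc='an integer attribute', dtype='int')])],
        attributes=[AttributeSpec(name='attr1', doc='a string attribute', dtype='text')],
    )
    multi_spec = GroupSpec(
        doc='A type that holds any number of Bars',
        data_type_def='Multi',
        groups=[GroupSpec(data_type_inc='Bar', doc='contained Bars', quantity='*')],
    )
    type_map = create_test_type_map([bar_spec, multi_spec], {'Bar': Bar})
    Multi = type_map.get_dt_container_cls('Multi', CORE_NAMESPACE)  # dynamically generated MCI class
    multi = Multi(name='my_multi')
    multi.add_bars(Bar('my_bar', list(range(10)), 'value1', 10))  # generated add helper
    return multi.get_bars('my_bar')  # generated get helper returns the contained Bar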
class TestGetClassSeparateNamespace(TestCase): def setUp(self): self.test_dir = tempfile.mkdtemp() if os.path.exists(self.test_dir): # start clean self.tearDown() os.mkdir(self.test_dir) self.bar_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Bar', datasets=[ DatasetSpec(name='data', doc='a dataset', dtype='int') ], attributes=[ AttributeSpec(name='attr1', doc='a string attribute', dtype='text'), AttributeSpec(name='attr2', doc='an integer attribute', dtype='int') ] ) self.type_map = TypeMap() create_load_namespace_yaml( namespace_name=CORE_NAMESPACE, specs=[self.bar_spec], output_dir=self.test_dir, incl_types=dict(), type_map=self.type_map ) def tearDown(self): shutil.rmtree(self.test_dir) def test_get_class_separate_ns(self): """Test that get_class correctly sets the name and type hierarchy across namespaces.""" self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) baz_spec = GroupSpec( doc='A test extension', data_type_def='Baz', data_type_inc='Bar', ) create_load_namespace_yaml( namespace_name='ndx-test', specs=[baz_spec], output_dir=self.test_dir, incl_types={CORE_NAMESPACE: ['Bar']}, type_map=self.type_map ) cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') self.assertEqual(cls.__name__, 'Baz') self.assertTrue(issubclass(cls, Bar)) def _build_separate_namespaces(self): # create an empty extension to test ClassGenerator._get_container_type resolution # the Bar class has not been mapped yet to the bar spec qux_spec = DatasetSpec( doc='A test extension', data_type_def='Qux' ) spam_spec = DatasetSpec( doc='A test extension', data_type_def='Spam' ) create_load_namespace_yaml( namespace_name='ndx-qux', specs=[qux_spec, spam_spec], output_dir=self.test_dir, incl_types={}, type_map=self.type_map ) # resolve Spam first so that ndx-qux is resolved first self.type_map.get_dt_container_cls('Spam', 'ndx-qux') baz_spec = GroupSpec( doc='A test extension', data_type_def='Baz', data_type_inc='Bar', groups=[ GroupSpec(data_type_inc='Qux', doc='a qux', quantity='?'), GroupSpec(data_type_inc='Bar', doc='a bar', quantity='?') ] ) create_load_namespace_yaml( namespace_name='ndx-test', specs=[baz_spec], output_dir=self.test_dir, incl_types={ CORE_NAMESPACE: ['Bar'], 'ndx-qux': ['Qux'] }, type_map=self.type_map ) def _check_classes(self, baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2): self.assertEqual(qux_cls.__name__, 'Qux') self.assertEqual(baz_cls.__name__, 'Baz') self.assertEqual(bar_cls.__name__, 'Bar') self.assertIs(bar_cls, bar_cls2) # same class, two different namespaces self.assertIs(qux_cls, qux_cls2) self.assertTrue(issubclass(qux_cls, Data)) self.assertTrue(issubclass(baz_cls, bar_cls)) self.assertTrue(issubclass(bar_cls, Container)) qux_inst = qux_cls(name='qux_name', data=[1]) bar_inst = bar_cls(name='bar_name', data=100, attr1='a string', attr2=10) baz_inst = baz_cls(name='baz_name', qux=qux_inst, bar=bar_inst, data=100, attr1='a string', attr2=10) self.assertIs(baz_inst.qux, qux_inst) def test_get_class_include_from_separate_ns_1(self): """Test that get_class correctly sets the name and includes types correctly across namespaces. 
This is one of multiple tests carried out to ensure that order of which get_dt_container_cls is called does not impact the results first use EXTENSION namespace, then use ORIGINAL namespace """ self._build_separate_namespaces() baz_cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') # Qux and Bar are not yet resolved bar_cls = self.type_map.get_dt_container_cls('Bar', 'ndx-test') bar_cls2 = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) qux_cls = self.type_map.get_dt_container_cls('Qux', 'ndx-test') qux_cls2 = self.type_map.get_dt_container_cls('Qux', 'ndx-qux') self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2) def test_get_class_include_from_separate_ns_2(self): """Test that get_class correctly sets the name and includes types correctly across namespaces. This is one of multiple tests carried out to ensure that order of which get_dt_container_cls is called does not impact the results first use ORIGINAL namespace, then use EXTENSION namespace """ self._build_separate_namespaces() baz_cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') # Qux and Bar are not yet resolved bar_cls2 = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) bar_cls = self.type_map.get_dt_container_cls('Bar', 'ndx-test') qux_cls = self.type_map.get_dt_container_cls('Qux', 'ndx-test') qux_cls2 = self.type_map.get_dt_container_cls('Qux', 'ndx-qux') self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2) def test_get_class_include_from_separate_ns_3(self): """Test that get_class correctly sets the name and includes types correctly across namespaces. This is one of multiple tests carried out to ensure that order of which get_dt_container_cls is called does not impact the results first use EXTENSION namespace, then use EXTENSION namespace """ self._build_separate_namespaces() baz_cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') # Qux and Bar are not yet resolved bar_cls = self.type_map.get_dt_container_cls('Bar', 'ndx-test') bar_cls2 = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) qux_cls2 = self.type_map.get_dt_container_cls('Qux', 'ndx-qux') qux_cls = self.type_map.get_dt_container_cls('Qux', 'ndx-test') self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2) def test_get_class_include_from_separate_ns_4(self): """Test that get_class correctly sets the name and includes types correctly across namespaces. 
This is one of multiple tests carried out to ensure that order of which get_dt_container_cls is called does not impact the results first use ORIGINAL namespace, then use EXTENSION namespace """ self._build_separate_namespaces() baz_cls = self.type_map.get_dt_container_cls('Baz', 'ndx-test') # Qux and Bar are not yet resolved bar_cls2 = self.type_map.get_dt_container_cls('Bar', CORE_NAMESPACE) bar_cls = self.type_map.get_dt_container_cls('Bar', 'ndx-test') qux_cls2 = self.type_map.get_dt_container_cls('Qux', 'ndx-qux') qux_cls = self.type_map.get_dt_container_cls('Qux', 'ndx-test') self._check_classes(baz_cls, bar_cls, bar_cls2, qux_cls, qux_cls2) class EmptyBar(Container): pass class TestBaseProcessFieldSpec(TestCase): def setUp(self): self.bar_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='EmptyBar' ) self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'EmptyBar', EmptyBar) def test_update_docval(self): """Test update_docval_args for a variety of data types and mapping configurations.""" spec = GroupSpec( doc="A test group specification with a data type", data_type_def="Baz", groups=[ GroupSpec(doc="a group", data_type_inc="EmptyBar", quantity="?") ], datasets=[ DatasetSpec( doc="a dataset", dtype="int", name="data", attributes=[ AttributeSpec(name="attr2", doc="an integer attribute", dtype="int") ], ) ], attributes=[ AttributeSpec(name="attr1", doc="a string attribute", dtype="text"), AttributeSpec(name="attr3", doc="a numeric attribute", dtype="numeric"), AttributeSpec(name="attr4", doc="a float attribute", dtype="float"), ], ) expected = [ {'name': 'data', 'type': (int, np.int32, np.int64), 'doc': 'a dataset'}, {'name': 'attr1', 'type': str, 'doc': 'a string attribute'}, {'name': 'attr2', 'type': (int, np.int32, np.int64), 'doc': 'an integer attribute'}, {'name': 'attr3', 'doc': 'a numeric attribute', 'type': (float, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, int, np.uint8, np.uint16, np.uint32, np.uint64)}, {'name': 'attr4', 'doc': 'a float attribute', 'type': (float, np.float32, np.float64)}, {'name': 'bar', 'type': EmptyBar, 'doc': 'a group', 'default': None}, ] not_inherited_fields = { 'data': spec.get_dataset('data'), 'attr1': spec.get_attribute('attr1'), 'attr2': spec.get_dataset('data').get_attribute('attr2'), 'attr3': spec.get_attribute('attr3'), 'attr4': spec.get_attribute('attr4'), 'bar': spec.groups[0] } docval_args = list() for i, attr_name in enumerate(not_inherited_fields): with self.subTest(attr_name=attr_name): CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class attr_name=attr_name, not_inherited_fields=not_inherited_fields, type_map=self.type_map, spec=spec ) self.assertListEqual(docval_args, expected[:(i+1)]) # compare with the first i elements of expected def test_update_docval_attr_shape(self): """Test that update_docval_args for an attribute with shape sets the type and shape keys.""" spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', attributes=[ AttributeSpec(name='attr1', doc='a string attribute', dtype='text', 
shape=[None]) ] ) not_inherited_fields = {'attr1': spec.get_attribute('attr1')} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, type_map=TypeMap(), spec=spec ) expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [None]}] self.assertListEqual(docval_args, expected) def test_update_docval_dset_shape(self): """Test that update_docval_args for a dataset with shape sets the type and shape keys.""" spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', datasets=[ DatasetSpec(name='dset1', doc='a string dataset', dtype='text', shape=[None]) ] ) not_inherited_fields = {'dset1': spec.get_dataset('dset1')} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class attr_name='dset1', not_inherited_fields=not_inherited_fields, type_map=TypeMap(), spec=spec ) expected = [{'name': 'dset1', 'type': ('array_data', 'data'), 'doc': 'a string dataset', 'shape': [None]}] self.assertListEqual(docval_args, expected) def test_update_docval_default_value(self): """Test that update_docval_args for an optional field with default value sets the default key.""" spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', attributes=[ AttributeSpec(name='attr1', doc='a string attribute', dtype='text', required=False, default_value='value') ] ) not_inherited_fields = {'attr1': spec.get_attribute('attr1')} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, type_map=TypeMap(), spec=spec ) expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': 'value'}] self.assertListEqual(docval_args, expected) def test_update_docval_default_value_none(self): """Test that update_docval_args for an optional field sets default: None.""" spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', attributes=[ AttributeSpec(name='attr1', doc='a string attribute', dtype='text', required=False) ] ) not_inherited_fields = {'attr1': spec.get_attribute('attr1')} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, type_map=TypeMap(), spec=spec ) expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}] self.assertListEqual(docval_args, expected) def test_update_docval_default_value_none_required_parent(self): """Test that update_docval_args for an optional field with a required parent sets default: None.""" spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', groups=[ GroupSpec( name='group1', doc='required untyped group', attributes=[ AttributeSpec(name='attr1', doc='a string attribute', dtype='text', required=False) ] ) ] ) not_inherited_fields = {'attr1': spec.get_group('group1').get_attribute('attr1')} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, type_map=TypeMap(), spec=spec ) expected = [{'name': 'attr1', 
'type': str, 'doc': 'a string attribute', 'default': None}] self.assertListEqual(docval_args, expected) def test_update_docval_required_field_optional_parent(self): """Test that update_docval_args for a required field with an optional parent sets default: None.""" spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', groups=[ GroupSpec( name='group1', doc='required untyped group', attributes=[ AttributeSpec(name='attr1', doc='a string attribute', dtype='text') ], quantity='?' ) ] ) not_inherited_fields = {'attr1': spec.get_group('group1').get_attribute('attr1')} docval_args = list() CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, type_map=TypeMap(), spec=spec ) expected = [{'name': 'attr1', 'type': str, 'doc': 'a string attribute', 'default': None}] self.assertListEqual(docval_args, expected) def test_process_field_spec_overwrite(self): """Test that docval generation overwrites previous docval args.""" spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', attributes=[ AttributeSpec(name='attr1', doc='a string attribute', dtype='text', shape=[None]) ] ) not_inherited_fields = {'attr1': spec.get_attribute('attr1')} docval_args = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [[None], [None, None]]}, # this dict will be overwritten below {'name': 'attr2', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [[None], [None, None]]}] CustomClassGenerator.process_field_spec( classdict={}, docval_args=docval_args, parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr1', not_inherited_fields=not_inherited_fields, type_map=TypeMap(), spec=spec ) expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [None]}, {'name': 'attr2', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [[None], [None, None]]}] self.assertListEqual(docval_args, expected) def test_process_field_spec_link(self): """Test that processing a link spec does not set child=True in __fields__.""" classdict = {} not_inherited_fields = {'attr3': LinkSpec(name='attr3', target_type='EmptyBar', doc='a link')} CustomClassGenerator.process_field_spec( classdict=classdict, docval_args=[], parent_cls=EmptyBar, # <-- arbitrary class attr_name='attr3', not_inherited_fields=not_inherited_fields, type_map=self.type_map, spec=GroupSpec('dummy', 'doc') ) expected = {'__fields__': [{'name': 'attr3', 'doc': 'a link'}]} self.assertDictEqual(classdict, expected) def test_post_process_fixed_name(self): """Test that docval generation for a class with a fixed name does not contain a docval arg for name.""" spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', name='MyBaz', # <-- fixed name attributes=[ AttributeSpec( name='attr1', doc='a string attribute', dtype='text', shape=[None] ) ] ) classdict = {} bases = [Container] docval_args = [{'name': 'name', 'type': str, 'doc': 'name'}, {'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [None]}] CustomClassGenerator.post_process(classdict, bases, docval_args, spec) expected = [{'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [None]}] self.assertListEqual(docval_args, expected) def test_post_process_default_name(self): """Test that docval generation for a class with a 
default name has the default value for name set.""" spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', default_name='MyBaz', # <-- default name attributes=[ AttributeSpec( name='attr1', doc='a string attribute', dtype='text', shape=[None] ) ] ) classdict = {} bases = [Container] docval_args = [{'name': 'name', 'type': str, 'doc': 'name'}, {'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [None]}] CustomClassGenerator.post_process(classdict, bases, docval_args, spec) expected = [{'name': 'name', 'type': str, 'doc': 'name', 'default': 'MyBaz'}, {'name': 'attr1', 'type': ('array_data', 'data'), 'doc': 'a string attribute', 'shape': [None]}] self.assertListEqual(docval_args, expected) class TestMCIProcessFieldSpec(TestCase): def setUp(self): bar_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='EmptyBar' ) specs = [bar_spec] container_classes = {'EmptyBar': EmptyBar} self.type_map = create_test_type_map(specs, container_classes) def test_update_docval(self): spec = GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='*') classdict = dict() docval_args = [] not_inherited_fields = {'empty_bars': spec} MCIClassGenerator.process_field_spec( classdict=classdict, docval_args=docval_args, parent_cls=Container, attr_name='empty_bars', not_inherited_fields=not_inherited_fields, type_map=self.type_map, spec=spec ) expected = [ dict( attr='empty_bars', type=EmptyBar, add='add_empty_bars', get='get_empty_bars', create='create_empty_bars' ) ] self.assertEqual(classdict['__clsconf__'], expected) def test_update_init_zero_or_more(self): spec = GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='*') classdict = dict() docval_args = [] not_inherited_fields = {'empty_bars': spec} MCIClassGenerator.process_field_spec( classdict=classdict, docval_args=docval_args, parent_cls=Container, attr_name='empty_bars', not_inherited_fields=not_inherited_fields, type_map=self.type_map, spec=spec ) expected = [{'name': 'empty_bars', 'type': (list, tuple, dict, EmptyBar), 'doc': 'test multi', 'default': None}] self.assertListEqual(docval_args, expected) def test_update_init_one_or_more(self): spec = GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='+') classdict = dict() docval_args = [] not_inherited_fields = {'empty_bars': spec} MCIClassGenerator.process_field_spec( classdict=classdict, docval_args=docval_args, parent_cls=Container, attr_name='empty_bars', not_inherited_fields=not_inherited_fields, type_map=self.type_map, spec=spec ) expected = [{'name': 'empty_bars', 'type': (list, tuple, dict, EmptyBar), 'doc': 'test multi'}] self.assertListEqual(docval_args, expected) def test_post_process(self): multi_spec = GroupSpec( doc='A test extension that contains a multi', data_type_def='Multi', groups=[ GroupSpec(data_type_inc='EmptyBar', doc='test multi', quantity='*') ], ) classdict = dict( __clsconf__=[ dict( attr='empty_bars', type=EmptyBar, add='add_empty_bars', get='get_empty_bars', create='create_empty_bars' ) ] ) bases = [Container] docval_args = [] MCIClassGenerator.post_process(classdict, bases, docval_args, multi_spec) self.assertEqual(bases, [MultiContainerInterface, Container]) def test_post_process_already_multi(self): class Multi1(MultiContainerInterface): pass multi_spec = GroupSpec( doc='A test extension that contains a multi and extends a multi', data_type_def='Multi2', data_type_inc='Multi1', groups=[ GroupSpec(data_type_inc='EmptyBar', doc='test multi', 
                      quantity='*')
            ],
        )
        classdict = dict(
            __clsconf__=[
                dict(
                    attr='empty_bars',
                    type=EmptyBar,
                    add='add_empty_bars',
                    get='get_empty_bars',
                    create='create_empty_bars'
                )
            ]
        )
        bases = [Multi1]
        docval_args = []
        MCIClassGenerator.post_process(classdict, bases, docval_args, multi_spec)
        self.assertEqual(bases, [Multi1])


# ---- file: hdmf-3.1.1/tests/unit/build_tests/test_convert_dtype.py ----

from datetime import datetime

import numpy as np

from hdmf.backends.hdf5 import H5DataIO
from hdmf.build import ObjectMapper
from hdmf.data_utils import DataChunkIterator
from hdmf.spec import DatasetSpec, RefSpec, DtypeSpec
from hdmf.testing import TestCase


class TestConvertDtype(TestCase):

    def test_value_none(self):
        spec = DatasetSpec('an example dataset', 'int', name='data')
        self.assertTupleEqual(ObjectMapper.convert_dtype(spec, None), (None, 'int'))

        spec = DatasetSpec('an example dataset', RefSpec(reftype='object', target_type='int'), name='data')
        self.assertTupleEqual(ObjectMapper.convert_dtype(spec, None), (None, 'object'))

    # do full matrix test of given value x and spec y, what does convert_dtype return?
    def test_convert_to_64bit_spec(self):
        """
        Test that if given any value for a spec with a 64-bit dtype, convert_dtype will convert to the spec type.
        Also test that if the given value is not the same as the spec, convert_dtype raises a warning.
        """
        spec_type = 'float64'
        value_types = ['double', 'float64']
        self._test_convert_alias(spec_type, value_types)

        spec_type = 'float64'
        value_types = ['float', 'float32', 'long', 'int64', 'int', 'int32', 'int16', 'short', 'int8', 'uint64',
                       'uint', 'uint32', 'uint16', 'uint8', 'bool']
        self._test_convert_higher_precision_helper(spec_type, value_types)

        spec_type = 'int64'
        value_types = ['long', 'int64']
        self._test_convert_alias(spec_type, value_types)

        spec_type = 'int64'
        value_types = ['double', 'float64', 'float', 'float32', 'int', 'int32', 'int16', 'short', 'int8', 'uint64',
                       'uint', 'uint32', 'uint16', 'uint8', 'bool']
        self._test_convert_higher_precision_helper(spec_type, value_types)

        spec_type = 'uint64'
        value_types = ['uint64']
        self._test_convert_alias(spec_type, value_types)

        spec_type = 'uint64'
        value_types = ['double', 'float64', 'float', 'float32', 'long', 'int64', 'int', 'int32', 'int16', 'short',
                       'int8', 'uint', 'uint32', 'uint16', 'uint8', 'bool']
        self._test_convert_higher_precision_helper(spec_type, value_types)

    def test_convert_to_float32_spec(self):
        """Test conversion of various types to float32.
        If given a value with precision > float32 and float base type, convert_dtype will keep the higher precision.
        If given a value with 64-bit precision and different base type, convert_dtype will convert to float64.
        If given a value that is float32, convert_dtype will convert to float32.
        If given a value with precision <= float32, convert_dtype will convert to float32 and raise a warning.
""" spec_type = 'float32' value_types = ['double', 'float64'] self._test_keep_higher_precision_helper(spec_type, value_types) value_types = ['long', 'int64', 'uint64'] expected_type = 'float64' self._test_change_basetype_helper(spec_type, value_types, expected_type) value_types = ['float', 'float32'] self._test_convert_alias(spec_type, value_types) value_types = ['int', 'int32', 'int16', 'short', 'int8', 'uint', 'uint32', 'uint16', 'uint8', 'bool'] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_int32_spec(self): """Test conversion of various types to int32. If given a value with precision > int32 and int base type, convert_dtype will keep the higher precision. If given a value with 64-bit precision and different base type, convert_dtype will convert to int64. If given a value that is int32, convert_dtype will convert to int32. If given a value with precision <= int32, convert_dtype will convert to int32 and raise a warning. """ spec_type = 'int32' value_types = ['int64', 'long'] self._test_keep_higher_precision_helper(spec_type, value_types) value_types = ['double', 'float64', 'uint64'] expected_type = 'int64' self._test_change_basetype_helper(spec_type, value_types, expected_type) value_types = ['int', 'int32'] self._test_convert_alias(spec_type, value_types) value_types = ['float', 'float32', 'int16', 'short', 'int8', 'uint', 'uint32', 'uint16', 'uint8', 'bool'] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_uint32_spec(self): """Test conversion of various types to uint32. If given a value with precision > uint32 and uint base type, convert_dtype will keep the higher precision. If given a value with 64-bit precision and different base type, convert_dtype will convert to uint64. If given a value that is uint32, convert_dtype will convert to uint32. If given a value with precision <= uint32, convert_dtype will convert to uint32 and raise a warning. """ spec_type = 'uint32' value_types = ['uint64'] self._test_keep_higher_precision_helper(spec_type, value_types) value_types = ['double', 'float64', 'long', 'int64'] expected_type = 'uint64' self._test_change_basetype_helper(spec_type, value_types, expected_type) value_types = ['uint', 'uint32'] self._test_convert_alias(spec_type, value_types) value_types = ['float', 'float32', 'int', 'int32', 'int16', 'short', 'int8', 'uint16', 'uint8', 'bool'] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_int16_spec(self): """Test conversion of various types to int16. If given a value with precision > int16 and int base type, convert_dtype will keep the higher precision. If given a value with 64-bit precision and different base type, convert_dtype will convert to int64. If given a value with 32-bit precision and different base type, convert_dtype will convert to int32. If given a value that is int16, convert_dtype will convert to int16. If given a value with precision <= int16, convert_dtype will convert to int16 and raise a warning. 
""" spec_type = 'int16' value_types = ['long', 'int64', 'int', 'int32'] self._test_keep_higher_precision_helper(spec_type, value_types) value_types = ['double', 'float64', 'uint64'] expected_type = 'int64' self._test_change_basetype_helper(spec_type, value_types, expected_type) value_types = ['float', 'float32', 'uint', 'uint32'] expected_type = 'int32' self._test_change_basetype_helper(spec_type, value_types, expected_type) value_types = ['int16', 'short'] self._test_convert_alias(spec_type, value_types) value_types = ['int8', 'uint16', 'uint8', 'bool'] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_uint16_spec(self): """Test conversion of various types to uint16. If given a value with precision > uint16 and uint base type, convert_dtype will keep the higher precision. If given a value with 64-bit precision and different base type, convert_dtype will convert to uint64. If given a value with 32-bit precision and different base type, convert_dtype will convert to uint32. If given a value that is uint16, convert_dtype will convert to uint16. If given a value with precision <= uint16, convert_dtype will convert to uint16 and raise a warning. """ spec_type = 'uint16' value_types = ['uint64', 'uint', 'uint32'] self._test_keep_higher_precision_helper(spec_type, value_types) value_types = ['double', 'float64', 'long', 'int64'] expected_type = 'uint64' self._test_change_basetype_helper(spec_type, value_types, expected_type) value_types = ['float', 'float32', 'int', 'int32'] expected_type = 'uint32' self._test_change_basetype_helper(spec_type, value_types, expected_type) value_types = ['uint16'] self._test_convert_alias(spec_type, value_types) value_types = ['int16', 'short', 'int8', 'uint8', 'bool'] self._test_convert_higher_precision_helper(spec_type, value_types) def test_convert_to_bool_spec(self): """Test conversion of various types to bool. If given a value with type bool, convert_dtype will convert to bool. If given a value with type int8/uint8, convert_dtype will convert to bool and raise a warning. Otherwise, convert_dtype will raise an error. 
""" spec_type = 'bool' value_types = ['bool'] self._test_convert_alias(spec_type, value_types) value_types = ['uint8', 'int8'] self._test_convert_higher_precision_helper(spec_type, value_types) value_types = ['double', 'float64', 'float', 'float32', 'long', 'int64', 'int', 'int32', 'int16', 'short', 'uint64', 'uint', 'uint32', 'uint16'] self._test_convert_mismatch_helper(spec_type, value_types) def _get_type(self, type_str): return ObjectMapper._ObjectMapper__dtypes[type_str] # apply ObjectMapper mapping string to dtype def _test_convert_alias(self, spec_type, value_types): data = 1 spec = DatasetSpec('an example dataset', spec_type, name='data') match = (self._get_type(spec_type)(data), self._get_type(spec_type)) for dtype in value_types: value = self._get_type(dtype)(data) # convert data to given dtype with self.subTest(dtype=dtype): ret = ObjectMapper.convert_dtype(spec, value) self.assertTupleEqual(ret, match) self.assertIs(ret[0].dtype.type, match[1]) def _test_convert_higher_precision_helper(self, spec_type, value_types): data = 1 spec = DatasetSpec('an example dataset', spec_type, name='data') match = (self._get_type(spec_type)(data), self._get_type(spec_type)) for dtype in value_types: value = self._get_type(dtype)(data) # convert data to given dtype with self.subTest(dtype=dtype): s = np.dtype(self._get_type(spec_type)) g = np.dtype(self._get_type(dtype)) msg = ("Spec 'data': Value with data type %s is being converted to data type %s as specified." % (g.name, s.name)) with self.assertWarnsWith(UserWarning, msg): ret = ObjectMapper.convert_dtype(spec, value) self.assertTupleEqual(ret, match) self.assertIs(ret[0].dtype.type, match[1]) def _test_keep_higher_precision_helper(self, spec_type, value_types): data = 1 spec = DatasetSpec('an example dataset', spec_type, name='data') for dtype in value_types: value = self._get_type(dtype)(data) match = (value, self._get_type(dtype)) with self.subTest(dtype=dtype): ret = ObjectMapper.convert_dtype(spec, value) self.assertTupleEqual(ret, match) self.assertIs(ret[0].dtype.type, match[1]) def _test_change_basetype_helper(self, spec_type, value_types, exp_type): data = 1 spec = DatasetSpec('an example dataset', spec_type, name='data') match = (self._get_type(exp_type)(data), self._get_type(exp_type)) for dtype in value_types: value = self._get_type(dtype)(data) # convert data to given dtype with self.subTest(dtype=dtype): s = np.dtype(self._get_type(spec_type)) e = np.dtype(self._get_type(exp_type)) g = np.dtype(self._get_type(dtype)) msg = ("Spec 'data': Value with data type %s is being converted to data type %s " "(min specification: %s)." 
% (g.name, e.name, s.name)) with self.assertWarnsWith(UserWarning, msg): ret = ObjectMapper.convert_dtype(spec, value) self.assertTupleEqual(ret, match) self.assertIs(ret[0].dtype.type, match[1]) def _test_convert_mismatch_helper(self, spec_type, value_types): data = 1 spec = DatasetSpec('an example dataset', spec_type, name='data') for dtype in value_types: value = self._get_type(dtype)(data) # convert data to given dtype with self.subTest(dtype=dtype): s = np.dtype(self._get_type(spec_type)) g = np.dtype(self._get_type(dtype)) msg = "expected %s, received %s - must supply %s" % (s.name, g.name, s.name) with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) def test_dci_input(self): spec = DatasetSpec('an example dataset', 'int64', name='data') value = DataChunkIterator(np.array([1, 2, 3], dtype=np.int32)) msg = "Spec 'data': Value with data type int32 is being converted to data type int64 as specified." with self.assertWarnsWith(UserWarning, msg): ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) self.assertEqual(ret_dtype, np.int64) spec = DatasetSpec('an example dataset', 'int16', name='data') value = DataChunkIterator(np.array([1, 2, 3], dtype=np.int32)) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) self.assertEqual(ret_dtype, np.int32) # increase precision def test_text_spec(self): text_spec_types = ['text', 'utf', 'utf8', 'utf-8'] for spec_type in text_spec_types: with self.subTest(spec_type=spec_type): spec = DatasetSpec('an example dataset', spec_type, name='data') value = 'a' ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), str) self.assertEqual(ret_dtype, 'utf8') value = b'a' ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, 'a') self.assertIs(type(ret), str) self.assertEqual(ret_dtype, 'utf8') value = ['a', 'b'] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(ret, value) self.assertIs(type(ret[0]), str) self.assertEqual(ret_dtype, 'utf8') value = np.array(['a', 'b']) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, value) self.assertEqual(ret_dtype, 'utf8') value = np.array(['a', 'b'], dtype='S1') ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, np.array(['a', 'b'], dtype='U1')) self.assertEqual(ret_dtype, 'utf8') value = [] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(ret, value) self.assertEqual(ret_dtype, 'utf8') value = 1 msg = "Expected unicode or ascii string, got " with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) value = DataChunkIterator(np.array(['a', 'b'])) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) self.assertEqual(ret_dtype, 'utf8') value = DataChunkIterator(np.array(['a', 'b'], dtype='S1')) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) self.assertEqual(ret_dtype, 'utf8') def test_ascii_spec(self): ascii_spec_types = ['ascii', 'bytes'] for spec_type in ascii_spec_types: with self.subTest(spec_type=spec_type): spec = DatasetSpec('an example dataset', spec_type, name='data') value = 'a' ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, b'a') self.assertIs(type(ret), bytes) self.assertEqual(ret_dtype, 'ascii') value = b'a' ret, 
ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, b'a') self.assertIs(type(ret), bytes) self.assertEqual(ret_dtype, 'ascii') value = ['a', 'b'] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(ret, [b'a', b'b']) self.assertIs(type(ret[0]), bytes) self.assertEqual(ret_dtype, 'ascii') value = np.array(['a', 'b']) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, np.array(['a', 'b'], dtype='S1')) self.assertEqual(ret_dtype, 'ascii') value = np.array(['a', 'b'], dtype='S1') ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, value) self.assertEqual(ret_dtype, 'ascii') value = [] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(ret, value) self.assertEqual(ret_dtype, 'ascii') value = 1 msg = "Expected unicode or ascii string, got " with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) value = DataChunkIterator(np.array(['a', 'b'])) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) self.assertEqual(ret_dtype, 'ascii') value = DataChunkIterator(np.array(['a', 'b'], dtype='S1')) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) # no conversion self.assertIs(ret, value) self.assertEqual(ret_dtype, 'ascii') def test_no_spec(self): spec_type = None spec = DatasetSpec('an example dataset', spec_type, name='data') value = [1, 2, 3] ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(ret, value) self.assertIs(type(ret[0]), int) self.assertEqual(ret_dtype, int) value = np.uint64(4) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), np.uint64) self.assertEqual(ret_dtype, np.uint64) value = 'hello' ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), str) self.assertEqual(ret_dtype, 'utf8') value = b'hello' ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), bytes) self.assertEqual(ret_dtype, 'ascii') value = np.array(['aa', 'bb']) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, value) self.assertEqual(ret_dtype, 'utf8') value = np.array(['aa', 'bb'], dtype='S2') ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) np.testing.assert_array_equal(ret, value) self.assertEqual(ret_dtype, 'ascii') value = DataChunkIterator(data=[1, 2, 3]) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(ret.dtype.type, np.dtype(int).type) self.assertIs(type(ret.data[0]), int) self.assertEqual(ret_dtype, np.dtype(int).type) value = DataChunkIterator(data=['a', 'b']) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(ret.dtype.type, np.str_) self.assertIs(type(ret.data[0]), str) self.assertEqual(ret_dtype, 'utf8') value = H5DataIO(np.arange(30).reshape(5, 2, 3)) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(ret.data.dtype.type, np.dtype(int).type) self.assertEqual(ret_dtype, np.dtype(int).type) value = H5DataIO(['foo', 'bar']) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret.data[0]), str) self.assertEqual(ret_dtype, 'utf8') value = H5DataIO([b'foo', b'bar']) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) 
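        # with no dtype in the spec, the H5DataIO-wrapped bytes value is returned
        # as-is and the dtype is inferred as 'ascii' from the bytes elements (asserted below)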
self.assertEqual(ret, value) self.assertIs(type(ret.data[0]), bytes) self.assertEqual(ret_dtype, 'ascii') value = [] msg = "Cannot infer dtype of empty list or tuple. Please use numpy array with specified dtype." with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) def test_numeric_spec(self): spec_type = 'numeric' spec = DatasetSpec('an example dataset', spec_type, name='data') value = np.uint64(4) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), np.uint64) self.assertEqual(ret_dtype, np.uint64) value = DataChunkIterator(data=[1, 2, 3]) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(ret.dtype.type, np.dtype(int).type) self.assertIs(type(ret.data[0]), int) self.assertEqual(ret_dtype, np.dtype(int).type) value = ['a', 'b'] msg = "Cannot convert from to 'numeric' specification dtype." with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) value = np.array(['a', 'b']) msg = "Cannot convert from to 'numeric' specification dtype." with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) value = [] msg = "Cannot infer dtype of empty list or tuple. Please use numpy array with specified dtype." with self.assertRaisesWith(ValueError, msg): ObjectMapper.convert_dtype(spec, value) def test_bool_spec(self): spec_type = 'bool' spec = DatasetSpec('an example dataset', spec_type, name='data') value = np.bool_(True) ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), np.bool_) self.assertEqual(ret_dtype, np.bool_) value = True ret, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, value) self.assertIs(type(ret), np.bool_) self.assertEqual(ret_dtype, np.bool_) def test_override_type_int_restrict_precision(self): spec = DatasetSpec('an example dataset', 'int8', name='data') res = ObjectMapper.convert_dtype(spec, np.int64(1), 'int64') self.assertTupleEqual(res, (np.int64(1), np.int64)) def test_override_type_numeric_to_uint(self): spec = DatasetSpec('an example dataset', 'numeric', name='data') res = ObjectMapper.convert_dtype(spec, np.uint32(1), 'uint8') self.assertTupleEqual(res, (np.uint32(1), np.uint32)) def test_override_type_numeric_to_uint_list(self): spec = DatasetSpec('an example dataset', 'numeric', name='data') res = ObjectMapper.convert_dtype(spec, np.uint32((1, 2, 3)), 'uint8') np.testing.assert_array_equal(res[0], np.uint32((1, 2, 3))) self.assertEqual(res[1], np.uint32) def test_override_type_none_to_bool(self): spec = DatasetSpec('an example dataset', None, name='data') res = ObjectMapper.convert_dtype(spec, True, 'bool') self.assertTupleEqual(res, (True, np.bool_)) def test_compound_type(self): """Test that convert_dtype passes through arguments if spec dtype is a list without any validation.""" spec_type = [DtypeSpec('an int field', 'f1', 'int'), DtypeSpec('a float field', 'f2', 'float')] spec = DatasetSpec('an example dataset', spec_type, name='data') value = ['a', 1, 2.2] res, ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertListEqual(res, value) self.assertListEqual(ret_dtype, spec_type) def test_isodatetime_spec(self): spec_type = 'isodatetime' spec = DatasetSpec('an example dataset', spec_type, name='data') # NOTE: datetime.isoformat is called on all values with a datetime spec before conversion # see ObjectMapper.get_attr_value value = datetime.isoformat(datetime(2020, 11, 10)) ret, 
ret_dtype = ObjectMapper.convert_dtype(spec, value) self.assertEqual(ret, b'2020-11-10T00:00:00') self.assertIs(type(ret), bytes) self.assertEqual(ret_dtype, 'ascii') ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/build_tests/test_io_manager.py0000644000655200065520000003244100000000000022741 0ustar00circlecicirclecifrom abc import ABCMeta, abstractmethod from hdmf.build import GroupBuilder, DatasetBuilder, ObjectMapper, BuildManager, TypeMap, ContainerConfigurationError from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog from hdmf.spec.spec import ZERO_OR_MANY from hdmf.testing import TestCase from tests.unit.utils import Foo, FooBucket, CORE_NAMESPACE class FooMapper(ObjectMapper): """Maps nested 'attr2' attribute on dataset 'my_data' to Foo.attr2 in constructor and attribute map """ def __init__(self, spec): super().__init__(spec) my_data_spec = spec.get_dataset('my_data') self.map_spec('attr2', my_data_spec.get_attribute('attr2')) class TestBase(TestCase): def setUp(self): self.foo_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Foo', datasets=[ DatasetSpec( doc='an example dataset', dtype='int', name='my_data', attributes=[ AttributeSpec( name='attr2', doc='an example integer attribute', dtype='int' ) ] ) ], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')] ) self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) self.type_map.register_map(Foo, FooMapper) self.manager = BuildManager(self.type_map) class TestBuildManager(TestBase): def test_build(self): container_inst = Foo('my_foo', list(range(10)), 'value1', 10) expected = GroupBuilder( 'my_foo', datasets={ 'my_data': DatasetBuilder( 'my_data', list(range(10)), attributes={'attr2': 10})}, attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', 'object_id': container_inst.object_id}) builder1 = self.manager.build(container_inst) self.assertDictEqual(builder1, expected) def test_build_memoization(self): container_inst = Foo('my_foo', list(range(10)), 'value1', 10) expected = GroupBuilder( 'my_foo', datasets={ 'my_data': DatasetBuilder( 'my_data', list(range(10)), attributes={'attr2': 10})}, attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', 'object_id': container_inst.object_id}) builder1 = self.manager.build(container_inst) builder2 = self.manager.build(container_inst) self.assertDictEqual(builder1, expected) self.assertIs(builder1, builder2) def test_construct(self): builder = GroupBuilder( 'my_foo', datasets={ 'my_data': DatasetBuilder( 'my_data', list(range(10)), attributes={'attr2': 10})}, attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', 'object_id': -1}) container = self.manager.construct(builder) self.assertListEqual(container.my_data, list(range(10))) self.assertEqual(container.attr1, 'value1') self.assertEqual(container.attr2, 10) def test_construct_memoization(self): builder = GroupBuilder( 'my_foo', datasets={'my_data': DatasetBuilder( 'my_data', list(range(10)), 
attributes={'attr2': 10})}, attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', 'object_id': -1}) container1 = self.manager.construct(builder) container2 = self.manager.construct(builder) self.assertIs(container1, container2) class NestedBaseMixin(metaclass=ABCMeta): def setUp(self): super().setUp() self.foo_bucket = FooBucket('test_foo_bucket', [ Foo('my_foo1', list(range(10)), 'value1', 10), Foo('my_foo2', list(range(10, 20)), 'value2', 20)]) self.foo_builders = { 'my_foo1': GroupBuilder('my_foo1', datasets={'my_data': DatasetBuilder( 'my_data', list(range(10)), attributes={'attr2': 10})}, attributes={'attr1': 'value1', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', 'object_id': self.foo_bucket.foos['my_foo1'].object_id}), 'my_foo2': GroupBuilder('my_foo2', datasets={'my_data': DatasetBuilder( 'my_data', list(range(10, 20)), attributes={'attr2': 20})}, attributes={'attr1': 'value2', 'namespace': CORE_NAMESPACE, 'data_type': 'Foo', 'object_id': self.foo_bucket.foos['my_foo2'].object_id}) } self.setUpBucketBuilder() self.setUpBucketSpec() self.spec_catalog.register_spec(self.bucket_spec, 'test.yaml') self.type_map.register_container_type(CORE_NAMESPACE, 'FooBucket', FooBucket) self.type_map.register_map(FooBucket, self.setUpBucketMapper()) self.manager = BuildManager(self.type_map) @abstractmethod def setUpBucketBuilder(self): raise NotImplementedError('Cannot run test unless setUpBucketBuilder is implemented') @abstractmethod def setUpBucketSpec(self): raise NotImplementedError('Cannot run test unless setUpBucketSpec is implemented') @abstractmethod def setUpBucketMapper(self): raise NotImplementedError('Cannot run test unless setUpBucketMapper is implemented') def test_build(self): ''' Test default mapping for an Container that has an Container as an attribute value ''' builder = self.manager.build(self.foo_bucket) self.assertDictEqual(builder, self.bucket_builder) def test_construct(self): container = self.manager.construct(self.bucket_builder) self.assertEqual(container, self.foo_bucket) class TestNestedContainersNoSubgroups(NestedBaseMixin, TestBase): ''' Test BuildManager.build and BuildManager.construct when the Container contains other Containers, but does not keep them in additional subgroups ''' def setUpBucketBuilder(self): self.bucket_builder = GroupBuilder( 'test_foo_bucket', groups=self.foo_builders, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'FooBucket', 'object_id': self.foo_bucket.object_id}) def setUpBucketSpec(self): self.bucket_spec = GroupSpec('A test group specification for a data type containing data type', name="test_foo_bucket", data_type_def='FooBucket', groups=[GroupSpec( 'the Foos in this bucket', data_type_inc='Foo', quantity=ZERO_OR_MANY)]) def setUpBucketMapper(self): return ObjectMapper class TestNestedContainersSubgroup(NestedBaseMixin, TestBase): ''' Test BuildManager.build and BuildManager.construct when the Container contains other Containers that are stored in a subgroup ''' def setUpBucketBuilder(self): tmp_builder = GroupBuilder('foo_holder', groups=self.foo_builders) self.bucket_builder = GroupBuilder( 'test_foo_bucket', groups={'foos': tmp_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'FooBucket', 'object_id': self.foo_bucket.object_id}) def setUpBucketSpec(self): tmp_spec = GroupSpec( 'A subgroup for Foos', name='foo_holder', groups=[GroupSpec('the Foos in this bucket', data_type_inc='Foo', quantity=ZERO_OR_MANY)]) self.bucket_spec = GroupSpec('A test group specification for a data type 
containing data type', name="test_foo_bucket", data_type_def='FooBucket', groups=[tmp_spec]) def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) self.unmap(spec.get_group('foo_holder')) self.map_spec('foos', spec.get_group('foo_holder').get_data_type('Foo')) return BucketMapper class TestNestedContainersSubgroupSubgroup(NestedBaseMixin, TestBase): ''' Test BuildManager.build and BuildManager.construct when the Container contains other Containers that are stored in a subgroup in a subgroup ''' def setUpBucketBuilder(self): tmp_builder = GroupBuilder('foo_holder', groups=self.foo_builders) tmp_builder = GroupBuilder('foo_holder_holder', groups={'foo_holder': tmp_builder}) self.bucket_builder = GroupBuilder( 'test_foo_bucket', groups={'foo_holder': tmp_builder}, attributes={'namespace': CORE_NAMESPACE, 'data_type': 'FooBucket', 'object_id': self.foo_bucket.object_id}) def setUpBucketSpec(self): tmp_spec = GroupSpec('A subgroup for Foos', name='foo_holder', groups=[GroupSpec('the Foos in this bucket', data_type_inc='Foo', quantity=ZERO_OR_MANY)]) tmp_spec = GroupSpec('A subgroup to hold the subgroup', name='foo_holder_holder', groups=[tmp_spec]) self.bucket_spec = GroupSpec('A test group specification for a data type containing data type', name="test_foo_bucket", data_type_def='FooBucket', groups=[tmp_spec]) def setUpBucketMapper(self): class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) self.unmap(spec.get_group('foo_holder_holder')) self.unmap(spec.get_group('foo_holder_holder').get_group('foo_holder')) self.map_spec('foos', spec.get_group('foo_holder_holder').get_group('foo_holder').get_data_type('Foo')) return BucketMapper def test_build(self): ''' Test default mapping for an Container that has an Container as an attribute value ''' builder = self.manager.build(self.foo_bucket) self.assertDictEqual(builder, self.bucket_builder) def test_construct(self): container = self.manager.construct(self.bucket_builder) self.assertEqual(container, self.foo_bucket) class TestNoAttribute(TestBase): def test_build(self): """Test that an error is raised when a spec is mapped to a non-existent container attribute.""" class Unmapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) self.map_spec("unknown", self.spec.get_dataset('my_data')) self.type_map.register_map(Foo, Unmapper) # override container_inst = Foo('my_foo', list(range(10)), 'value1', 10) msg = ("Foo 'my_foo' does not have attribute 'unknown' for mapping to spec: %s" % self.foo_spec.get_dataset('my_data')) with self.assertRaisesWith(ContainerConfigurationError, msg): self.manager.build(container_inst) class TestTypeMap(TestBase): def test_get_ns_dt_missing(self): bldr = GroupBuilder('my_foo', attributes={'attr1': 'value1'}) dt = self.type_map.get_builder_dt(bldr) ns = self.type_map.get_builder_ns(bldr) self.assertIsNone(dt) self.assertIsNone(ns) def test_get_ns_dt(self): bldr = GroupBuilder('my_foo', attributes={'attr1': 'value1', 'namespace': 'CORE', 'data_type': 'Foo', 'object_id': -1}) dt = self.type_map.get_builder_dt(bldr) ns = self.type_map.get_builder_ns(bldr) self.assertEqual(dt, 'Foo') self.assertEqual(ns, 'CORE') # TODO: class TestWildCardNamedSpecs(TestCase): pass ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/build_tests/test_io_map.py0000644000655200065520000010443500000000000022107 0ustar00circlecicircleciimport unittest from abc import ABCMeta, 
abstractmethod from hdmf import Container from hdmf.backends.hdf5 import H5DataIO from hdmf.build import (GroupBuilder, DatasetBuilder, ObjectMapper, BuildManager, TypeMap, LinkBuilder, ReferenceBuilder, MissingRequiredBuildWarning, OrphanContainerBuildError, ContainerConfigurationError) from hdmf.spec import GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, RefSpec from hdmf.testing import TestCase from hdmf.utils import docval, getargs from tests.unit.utils import CORE_NAMESPACE class Bar(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this Bar'}, {'name': 'data', 'type': ('data', 'array_data'), 'doc': 'some data'}, {'name': 'attr1', 'type': str, 'doc': 'an attribute'}, {'name': 'attr2', 'type': int, 'doc': 'another attribute'}, {'name': 'attr3', 'type': float, 'doc': 'a third attribute', 'default': 3.14}, {'name': 'foo', 'type': 'Foo', 'doc': 'a group', 'default': None}) def __init__(self, **kwargs): name, data, attr1, attr2, attr3, foo = getargs('name', 'data', 'attr1', 'attr2', 'attr3', 'foo', kwargs) super().__init__(name=name) self.__data = data self.__attr1 = attr1 self.__attr2 = attr2 self.__attr3 = attr3 self.__foo = foo if self.__foo is not None and self.__foo.parent is None: self.__foo.parent = self def __eq__(self, other): attrs = ('name', 'data', 'attr1', 'attr2', 'attr3', 'foo') return all(getattr(self, a) == getattr(other, a) for a in attrs) def __str__(self): attrs = ('name', 'data', 'attr1', 'attr2', 'attr3', 'foo') return ','.join('%s=%s' % (a, getattr(self, a)) for a in attrs) @property def data_type(self): return 'Bar' @property def data(self): return self.__data @property def attr1(self): return self.__attr1 @property def attr2(self): return self.__attr2 @property def attr3(self): return self.__attr3 @property def foo(self): return self.__foo def remove_foo(self): if self is self.__foo.parent: self._remove_child(self.__foo) class Foo(Container): @property def data_type(self): return 'Foo' class TestGetSubSpec(TestCase): def setUp(self): self.bar_spec = GroupSpec('A test group specification with a data type', data_type_def='Bar') spec_catalog = SpecCatalog() spec_catalog.register_spec(self.bar_spec, 'test.yaml') namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) self.type_map = TypeMap(namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) def test_get_subspec_data_type_noname(self): parent_spec = GroupSpec('Something to hold a Bar', 'bar_bucket', groups=[self.bar_spec]) sub_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'namespace': CORE_NAMESPACE, 'object_id': -1}) GroupBuilder('bar_bucket', groups={'my_bar': sub_builder}) result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, self.bar_spec) def test_get_subspec_named(self): child_spec = GroupSpec('A test group specification with a data type', 'my_subgroup') parent_spec = GroupSpec('Something to hold a Bar', 'my_group', groups=[child_spec]) sub_builder = GroupBuilder('my_subgroup', attributes={'data_type': 'Bar', 'namespace': CORE_NAMESPACE, 'object_id': -1}) GroupBuilder('my_group', groups={'my_bar': sub_builder}) result = self.type_map.get_subspec(parent_spec, sub_builder) self.assertIs(result, child_spec) class TestTypeMap(TestCase): def setUp(self): self.bar_spec = GroupSpec('A test group specification with a 
data type', data_type_def='Bar') self.foo_spec = GroupSpec('A test group specification with data type Foo', data_type_def='Foo') self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) def test_get_map_unique_mappers(self): bar_inst = Bar('my_bar', list(range(10)), 'value1', 10) foo_inst = Foo(name='my_foo') bar_mapper = self.type_map.get_map(bar_inst) foo_mapper = self.type_map.get_map(foo_inst) self.assertIsNot(bar_mapper, foo_mapper) def test_get_map(self): container_inst = Bar('my_bar', list(range(10)), 'value1', 10) mapper = self.type_map.get_map(container_inst) self.assertIsInstance(mapper, ObjectMapper) self.assertIs(mapper.spec, self.bar_spec) mapper2 = self.type_map.get_map(container_inst) self.assertIs(mapper, mapper2) def test_get_map_register(self): class MyMap(ObjectMapper): pass self.type_map.register_map(Bar, MyMap) container_inst = Bar('my_bar', list(range(10)), 'value1', 10) mapper = self.type_map.get_map(container_inst) self.assertIs(mapper.spec, self.bar_spec) self.assertIsInstance(mapper, MyMap) class BarMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) data_spec = spec.get_dataset('data') self.map_spec('attr2', data_spec.get_attribute('attr2')) class TestMapStrings(TestCase): def customSetUp(self, bar_spec): spec_catalog = SpecCatalog() spec_catalog.register_spec(bar_spec, 'test.yaml') namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) return type_map def test_build_1d(self): bar_spec = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', 'text', name='data', shape=(None,), attributes=[AttributeSpec( 'attr2', 'an example integer attribute', 'int')])], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) type_map = self.customSetUp(bar_spec) type_map.register_map(Bar, BarMapper) bar_inst = Bar('my_bar', ['a', 'b', 'c', 'd'], 'value1', 10) builder = type_map.build(bar_inst) self.assertEqual(builder.get('data').data, ['a', 'b', 'c', 'd']) def test_build_scalar(self): bar_spec = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', 'text', name='data', attributes=[AttributeSpec( 'attr2', 'an example integer attribute', 'int')])], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) type_map = self.customSetUp(bar_spec) type_map.register_map(Bar, BarMapper) bar_inst = Bar('my_bar', ['a', 'b', 'c', 'd'], 'value1', 10) builder = type_map.build(bar_inst) self.assertEqual(builder.get('data').data, "['a', 'b', 'c', 'd']") def test_build_dataio(self): bar_spec = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', 'text', 
name='data', shape=(None,), attributes=[AttributeSpec( 'attr2', 'an example integer attribute', 'int')])], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) type_map = self.customSetUp(bar_spec) type_map.register_map(Bar, BarMapper) bar_inst = Bar('my_bar', H5DataIO(['a', 'b', 'c', 'd'], chunks=True), 'value1', 10) builder = type_map.build(bar_inst) self.assertIsInstance(builder.get('data').data, H5DataIO) class ObjectMapperMixin(metaclass=ABCMeta): def setUp(self): self.setUpBarSpec() self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.bar_spec) @abstractmethod def setUpBarSpec(self): raise NotImplementedError('Cannot run test unless setUpBarSpec is implemented') def test_default_mapping(self): attr_map = self.mapper.get_attr_names(self.bar_spec) keys = set(attr_map.keys()) for key in keys: with self.subTest(key=key): self.assertIs(attr_map[key], self.mapper.get_attr_spec(key)) self.assertIs(attr_map[key], self.mapper.get_carg_spec(key)) class TestObjectMapperNested(ObjectMapperMixin, TestCase): def setUpBarSpec(self): self.bar_spec = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', 'int', name='data', attributes=[AttributeSpec( 'attr2', 'an example integer attribute', 'int')])], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) def test_build(self): ''' Test default mapping functionality when object attributes map to an attribute deeper than top-level Builder ''' container_inst = Bar('my_bar', list(range(10)), 'value1', 10) expected = GroupBuilder( name='my_bar', datasets={'data': DatasetBuilder( name='data', data=list(range(10)), attributes={'attr2': 10} )}, attributes={'attr1': 'value1'} ) self._remap_nested_attr() builder = self.mapper.build(container_inst, self.manager) self.assertBuilderEqual(builder, expected) def test_construct(self): ''' Test default mapping functionality when object attributes map to an attribute deeper than top-level Builder ''' expected = Bar('my_bar', list(range(10)), 'value1', 10) builder = GroupBuilder( name='my_bar', datasets={'data': DatasetBuilder( name='data', data=list(range(10)), attributes={'attr2': 10} )}, attributes={'attr1': 'value1', 'data_type': 'Bar', 'namespace': CORE_NAMESPACE, 'object_id': expected.object_id} ) self._remap_nested_attr() container = self.mapper.construct(builder, self.manager) self.assertEqual(container, expected) def test_default_mapping_keys(self): attr_map = self.mapper.get_attr_names(self.bar_spec) keys = set(attr_map.keys()) expected = {'attr1', 'data', 'data__attr2'} self.assertSetEqual(keys, expected) def test_remap_keys(self): self._remap_nested_attr() self.assertEqual(self.mapper.get_attr_spec('attr2'), self.mapper.spec.get_dataset('data').get_attribute('attr2')) self.assertEqual(self.mapper.get_attr_spec('attr1'), self.mapper.spec.get_attribute('attr1')) self.assertEqual(self.mapper.get_attr_spec('data'), self.mapper.spec.get_dataset('data')) def _remap_nested_attr(self): data_spec = self.mapper.spec.get_dataset('data') 
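# A minimal, hedged sketch of what map_spec does at this step: it re-points a container
# attribute name at a spec nested below the top level, so 'attr2' resolves to the 'attr2'
# attribute of the 'data' dataset rather than the auto-generated flattened key 'data__attr2'.
# Names prefixed with _demo are illustrative only and are not part of the test suite.
from hdmf.build import ObjectMapper
from hdmf.spec import AttributeSpec, DatasetSpec, GroupSpec

_demo_spec = GroupSpec(
    'a demo group spec', data_type_def='Bar',
    datasets=[DatasetSpec('an example dataset', 'int', name='data',
                          attributes=[AttributeSpec('attr2', 'an example integer attribute', 'int')])]
)
_demo_mapper = ObjectMapper(_demo_spec)
_demo_mapper.map_spec('attr2', _demo_spec.get_dataset('data').get_attribute('attr2'))
assert _demo_mapper.get_attr_spec('attr2') == _demo_spec.get_dataset('data').get_attribute('attr2')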
self.mapper.map_spec('attr2', data_spec.get_attribute('attr2')) class TestObjectMapperNoNesting(ObjectMapperMixin, TestCase): def setUpBarSpec(self): self.bar_spec = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', 'int', name='data')], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), AttributeSpec('attr2', 'an example integer attribute', 'int')]) def test_build(self): ''' Test default mapping functionality when no attributes are nested ''' container = Bar('my_bar', list(range(10)), 'value1', 10) builder = self.mapper.build(container, self.manager) expected = GroupBuilder('my_bar', datasets={'data': DatasetBuilder('data', list(range(10)))}, attributes={'attr1': 'value1', 'attr2': 10}) self.assertBuilderEqual(builder, expected) def test_build_empty(self): ''' Test default mapping functionality when no attributes are nested ''' container = Bar('my_bar', [], 'value1', 10) builder = self.mapper.build(container, self.manager) expected = GroupBuilder('my_bar', datasets={'data': DatasetBuilder('data', [])}, attributes={'attr1': 'value1', 'attr2': 10}) self.assertBuilderEqual(builder, expected) def test_construct(self): expected = Bar('my_bar', list(range(10)), 'value1', 10) builder = GroupBuilder('my_bar', datasets={'data': DatasetBuilder('data', list(range(10)))}, attributes={'attr1': 'value1', 'attr2': 10, 'data_type': 'Bar', 'namespace': CORE_NAMESPACE, 'object_id': expected.object_id}) container = self.mapper.construct(builder, self.manager) self.assertEqual(container, expected) def test_default_mapping_keys(self): attr_map = self.mapper.get_attr_names(self.bar_spec) keys = set(attr_map.keys()) expected = {'attr1', 'data', 'attr2'} self.assertSetEqual(keys, expected) class TestObjectMapperContainer(ObjectMapperMixin, TestCase): def setUpBarSpec(self): self.bar_spec = GroupSpec('A test group specification with a data type', data_type_def='Bar', groups=[GroupSpec('an example group', data_type_def='Foo')], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), AttributeSpec('attr2', 'an example integer attribute', 'int')]) def test_default_mapping_keys(self): attr_map = self.mapper.get_attr_names(self.bar_spec) keys = set(attr_map.keys()) expected = {'attr1', 'foo', 'attr2'} self.assertSetEqual(keys, expected) class TestLinkedContainer(TestCase): def setUp(self): self.foo_spec = GroupSpec('A test group specification with data type Foo', data_type_def='Foo') self.bar_spec = GroupSpec('A test group specification with a data type Bar', data_type_def='Bar', groups=[self.foo_spec], datasets=[DatasetSpec('an example dataset', 'int', name='data')], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), AttributeSpec('attr2', 'an example integer attribute', 'int')]) self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) self.manager = BuildManager(self.type_map) self.foo_mapper = ObjectMapper(self.foo_spec) self.bar_mapper = 
ObjectMapper(self.bar_spec) def test_build_child_link(self): ''' Test default mapping functionality when one container contains a child link to another container ''' foo_inst = Foo('my_foo') bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10, foo=foo_inst) # bar_inst2.foo should link to bar_inst1.foo bar_inst2 = Bar('my_bar2', list(range(10)), 'value1', 10, foo=foo_inst) foo_builder = self.foo_mapper.build(foo_inst, self.manager) bar1_builder = self.bar_mapper.build(bar_inst1, self.manager) bar2_builder = self.bar_mapper.build(bar_inst2, self.manager) foo_expected = GroupBuilder('my_foo') inner_foo_builder = GroupBuilder('my_foo', attributes={'data_type': 'Foo', 'namespace': CORE_NAMESPACE, 'object_id': foo_inst.object_id}) bar1_expected = GroupBuilder('my_bar1', datasets={'data': DatasetBuilder('data', list(range(10)))}, groups={'foo': inner_foo_builder}, attributes={'attr1': 'value1', 'attr2': 10}) link_foo_builder = LinkBuilder(builder=inner_foo_builder) bar2_expected = GroupBuilder('my_bar2', datasets={'data': DatasetBuilder('data', list(range(10)))}, links={'foo': link_foo_builder}, attributes={'attr1': 'value1', 'attr2': 10}) self.assertBuilderEqual(foo_builder, foo_expected) self.assertBuilderEqual(bar1_builder, bar1_expected) self.assertBuilderEqual(bar2_builder, bar2_expected) @unittest.expectedFailure def test_build_broken_link_parent(self): ''' Test that building a container with a broken link that has a parent raises an error. ''' foo_inst = Foo('my_foo') Bar('my_bar1', list(range(10)), 'value1', 10, foo=foo_inst) # foo_inst.parent is this bar # bar_inst2.foo should link to bar_inst1.foo bar_inst2 = Bar('my_bar2', list(range(10)), 'value1', 10, foo=foo_inst) # TODO bar_inst.foo.parent exists but is never built - this is a tricky edge case that should raise an error with self.assertRaises(OrphanContainerBuildError): self.bar_mapper.build(bar_inst2, self.manager) def test_build_broken_link_no_parent(self): ''' Test that building a container with a broken link that has no parent raises an error. ''' foo_inst = Foo('my_foo') bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10, foo=foo_inst) # foo_inst.parent is this bar # bar_inst2.foo should link to bar_inst1.foo bar_inst2 = Bar('my_bar2', list(range(10)), 'value1', 10, foo=foo_inst) bar_inst1.remove_foo() msg = ("my_bar2 (my_bar2): Linked Foo 'my_foo' has no parent. 
Remove the link or ensure the linked container " "is added properly.") with self.assertRaisesWith(OrphanContainerBuildError, msg): self.bar_mapper.build(bar_inst2, self.manager) class TestReference(TestCase): def setUp(self): self.foo_spec = GroupSpec('A test group specification with data type Foo', data_type_def='Foo') self.bar_spec = GroupSpec('A test group specification with a data type Bar', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', 'int', name='data')], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), AttributeSpec('attr2', 'an example integer attribute', 'int'), AttributeSpec('foo', 'a referenced foo', RefSpec('Foo', 'object'), required=False)]) self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') self.spec_catalog.register_spec(self.bar_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) self.type_map.register_container_type(CORE_NAMESPACE, 'Bar', Bar) self.manager = BuildManager(self.type_map) self.foo_mapper = ObjectMapper(self.foo_spec) self.bar_mapper = ObjectMapper(self.bar_spec) def test_build_attr_ref(self): ''' Test default mapping functionality when one container contains an attribute reference to another container. ''' foo_inst = Foo('my_foo') bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10, foo=foo_inst) bar_inst2 = Bar('my_bar2', list(range(10)), 'value1', 10) foo_builder = self.manager.build(foo_inst, root=True) bar1_builder = self.manager.build(bar_inst1, root=True) # adds refs bar2_builder = self.manager.build(bar_inst2, root=True) foo_expected = GroupBuilder('my_foo', attributes={'data_type': 'Foo', 'namespace': CORE_NAMESPACE, 'object_id': foo_inst.object_id}) bar1_expected = GroupBuilder('n/a', # name doesn't matter datasets={'data': DatasetBuilder('data', list(range(10)))}, attributes={'attr1': 'value1', 'attr2': 10, 'foo': ReferenceBuilder(foo_expected), 'data_type': 'Bar', 'namespace': CORE_NAMESPACE, 'object_id': bar_inst1.object_id}) bar2_expected = GroupBuilder('n/a', # name doesn't matter datasets={'data': DatasetBuilder('data', list(range(10)))}, attributes={'attr1': 'value1', 'attr2': 10, 'data_type': 'Bar', 'namespace': CORE_NAMESPACE, 'object_id': bar_inst2.object_id}) self.assertDictEqual(foo_builder, foo_expected) self.assertDictEqual(bar1_builder, bar1_expected) self.assertDictEqual(bar2_builder, bar2_expected) def test_build_attr_ref_invalid(self): ''' Test default mapping functionality when one container contains an attribute reference to another container. 
        '''
        bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10)
        bar_inst1._Bar__foo = object()  # make foo object a non-container type
        msg = "invalid type for reference 'foo' (<class 'object'>) - must be AbstractContainer"
        with self.assertRaisesWith(ValueError, msg):
            self.bar_mapper.build(bar_inst1, self.manager)


class TestMissingRequiredAttribute(ObjectMapperMixin, TestCase):

    def setUpBarSpec(self):
        self.bar_spec = GroupSpec(
            doc='A test group specification with a data type Bar',
            data_type_def='Bar',
            attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'),
                        AttributeSpec('attr2', 'an example integer attribute', 'int')]
        )

    def test_required_attr_missing(self):
        """Test mapping when one container is missing a required attribute."""
        bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10)
        bar_inst1._Bar__attr1 = None  # make attr1 attribute None
        msg = "Bar 'my_bar1' is missing required value for attribute 'attr1'."
        with self.assertWarnsWith(MissingRequiredBuildWarning, msg):
            builder = self.mapper.build(bar_inst1, self.manager)
        expected = GroupBuilder(
            name='my_bar1',
            attributes={'attr2': 10}
        )
        self.assertBuilderEqual(expected, builder)


class TestMissingRequiredAttributeRef(ObjectMapperMixin, TestCase):

    def setUpBarSpec(self):
        self.bar_spec = GroupSpec(
            doc='A test group specification with a data type Bar',
            data_type_def='Bar',
            attributes=[AttributeSpec('foo', 'a referenced foo', RefSpec('Foo', 'object'))]
        )

    def test_required_attr_ref_missing(self):
        """Test mapping when one container is missing a required attribute reference."""
        bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10)
        msg = "Bar 'my_bar1' is missing required value for attribute 'foo'."
        with self.assertWarnsWith(MissingRequiredBuildWarning, msg):
            builder = self.mapper.build(bar_inst1, self.manager)
        expected = GroupBuilder(
            name='my_bar1',
        )
        self.assertBuilderEqual(expected, builder)


class TestMissingRequiredDataset(ObjectMapperMixin, TestCase):

    def setUpBarSpec(self):
        self.bar_spec = GroupSpec(
            doc='A test group specification with a data type Bar',
            data_type_def='Bar',
            datasets=[DatasetSpec('an example dataset', 'int', name='data')]
        )

    def test_required_dataset_missing(self):
        """Test mapping when one container is missing a required dataset."""
        bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10)
        bar_inst1._Bar__data = None  # make data dataset None
        msg = "Bar 'my_bar1' is missing required value for attribute 'data'."
        with self.assertWarnsWith(MissingRequiredBuildWarning, msg):
            builder = self.mapper.build(bar_inst1, self.manager)
        expected = GroupBuilder(
            name='my_bar1',
        )
        self.assertBuilderEqual(expected, builder)


class TestMissingRequiredGroup(ObjectMapperMixin, TestCase):

    def setUpBarSpec(self):
        self.bar_spec = GroupSpec(
            doc='A test group specification with a data type Bar',
            data_type_def='Bar',
            groups=[GroupSpec('foo', data_type_inc='Foo')]
        )

    def test_required_group_missing(self):
        """Test mapping when one container is missing a required group."""
        bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10)
        msg = "Bar 'my_bar1' is missing required value for attribute 'foo'."
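        # Hedged aside at this step: MissingRequiredBuildWarning is a warning, not an error,
        # so build() below still returns a (partial) GroupBuilder, which the test then inspects.
        # The standalone spec here mirrors setUpBarSpec above; since no quantity is given, the
        # included Foo group defaults to being required. The type name 'BarWithRequiredFoo'
        # is purely illustrative and is not registered anywhere.
        _required_group_spec = GroupSpec(
            doc='A group type that must contain one Foo group',
            data_type_def='BarWithRequiredFoo',
            groups=[GroupSpec('a required Foo group', data_type_inc='Foo')],
        )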
with self.assertWarnsWith(MissingRequiredBuildWarning, msg): builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( name='my_bar1', ) self.assertBuilderEqual(expected, builder) class TestRequiredEmptyGroup(ObjectMapperMixin, TestCase): def setUpBarSpec(self): self.bar_spec = GroupSpec( doc='A test group specification with a data type Bar', data_type_def='Bar', groups=[GroupSpec(name='empty', doc='empty group')], ) def test_required_group_empty(self): """Test mapping when one container has a required empty group.""" bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( name='my_bar1', groups={'empty': GroupBuilder('empty')}, ) self.assertBuilderEqual(expected, builder) class TestOptionalEmptyGroup(ObjectMapperMixin, TestCase): def setUpBarSpec(self): self.bar_spec = GroupSpec( doc='A test group specification with a data type Bar', data_type_def='Bar', groups=[GroupSpec( name='empty', doc='empty group', quantity='?', attributes=[AttributeSpec('attr3', 'an optional float attribute', 'float', required=False)] )] ) def test_optional_group_empty(self): """Test mapping when one container has an optional empty group.""" self.mapper.map_spec('attr3', self.mapper.spec.get_group('empty').get_attribute('attr3')) bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) bar_inst1._Bar__attr3 = None # force attr3 to be None builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( name='my_bar1', ) self.assertBuilderEqual(expected, builder) def test_optional_group_not_empty(self): """Test mapping when one container has an optional not empty group.""" self.mapper.map_spec('attr3', self.mapper.spec.get_group('empty').get_attribute('attr3')) bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10, attr3=1.23) builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( name='my_bar1', groups={'empty': GroupBuilder( name='empty', attributes={'attr3': 1.23}, )}, ) self.assertBuilderEqual(expected, builder) class TestFixedAttributeValue(ObjectMapperMixin, TestCase): def setUpBarSpec(self): self.bar_spec = GroupSpec( doc='A test group specification with a data type Bar', data_type_def='Bar', attributes=[AttributeSpec('attr1', 'an example string attribute', 'text', value='hi'), AttributeSpec('attr2', 'an example integer attribute', 'int')] ) def test_required_attr_missing(self): """Test mapping when one container has a required attribute with a fixed value.""" bar_inst1 = Bar('my_bar1', list(range(10)), 'value1', 10) # attr1=value1 is not processed builder = self.mapper.build(bar_inst1, self.manager) expected = GroupBuilder( name='my_bar1', attributes={'attr1': 'hi', 'attr2': 10} ) self.assertBuilderEqual(builder, expected) class TestObjectMapperBadValue(TestCase): def test_bad_value(self): """Test that an error is raised if the container attribute value for a spec with a data type is not a container or collection of containers. 
""" class Qux(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this Qux'}, {'name': 'foo', 'type': int, 'doc': 'a group'}) def __init__(self, **kwargs): name, foo = getargs('name', 'foo', kwargs) super().__init__(name=name) self.__foo = foo if isinstance(foo, Foo): self.__foo.parent = self @property def foo(self): return self.__foo self.qux_spec = GroupSpec( doc='A test group specification with data type Qux', data_type_def='Qux', groups=[GroupSpec('an example dataset', data_type_inc='Foo')] ) self.foo_spec = GroupSpec('A test group specification with data type Foo', data_type_def='Foo') self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.qux_spec, 'test.yaml') self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Qux', Qux) self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.qux_spec) container = Qux('my_qux', foo=1) msg = "Qux 'my_qux' attribute 'foo' has unexpected type." with self.assertRaisesWith(ContainerConfigurationError, msg): self.mapper.build(container, self.manager) # TODO test passing a Container/Data/other object for a non-container/data array spec ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/build_tests/test_io_map_data.py0000644000655200065520000004661500000000000023105 0ustar00circlecicircleciimport os import h5py import numpy as np from hdmf import Container, Data from hdmf.backends.hdf5 import H5DataIO from hdmf.build import (GroupBuilder, DatasetBuilder, ObjectMapper, BuildManager, TypeMap, ReferenceBuilder, ReferenceTargetNotBuiltError) from hdmf.data_utils import DataChunkIterator from hdmf.spec import (AttributeSpec, DatasetSpec, DtypeSpec, GroupSpec, SpecCatalog, SpecNamespace, NamespaceCatalog, RefSpec) from hdmf.spec.spec import ZERO_OR_MANY from hdmf.testing import TestCase from hdmf.utils import docval, getargs, call_docval_func from tests.unit.utils import Foo, CORE_NAMESPACE class Baz(Data): @docval({'name': 'name', 'type': str, 'doc': 'the name of this Baz'}, {'name': 'data', 'type': (list, h5py.Dataset, 'data', 'array_data'), 'doc': 'some data'}, {'name': 'baz_attr', 'type': str, 'doc': 'an attribute'}) def __init__(self, **kwargs): name, data, baz_attr = getargs('name', 'data', 'baz_attr', kwargs) super().__init__(name=name, data=data) self.__baz_attr = baz_attr @property def baz_attr(self): return self.__baz_attr class BazHolder(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this Baz'}, {'name': 'bazs', 'type': list, 'doc': 'some Baz data', 'default': list()}) def __init__(self, **kwargs): name, bazs = getargs('name', 'bazs', kwargs) super().__init__(name=name) self.__bazs = {b.name: b for b in bazs} # note: collections of groups are unordered in HDF5 for b in bazs: b.parent = self @property def bazs(self): return self.__bazs class BazSpecMixin: def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', 
catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Baz', Baz) self.type_map.register_map(Baz, ObjectMapper) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): raise NotImplementedError('Test must implement this method.') class TestDataMap(BazSpecMixin, TestCase): def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Baz', Baz) self.type_map.register_map(Baz, ObjectMapper) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): self.baz_spec = DatasetSpec( doc='an Baz type', dtype='int', name='MyBaz', data_type_def='Baz', shape=[None], attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] ) def test_build(self): ''' Test default mapping functionality when no attributes are nested ''' container = Baz('MyBaz', list(range(10)), 'abcdefghijklmnopqrstuvwxyz') builder = self.mapper.build(container, self.manager) expected = DatasetBuilder('MyBaz', list(range(10)), attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz'}) self.assertBuilderEqual(builder, expected) def test_build_empty_data(self): """Test building of a Data object with empty data.""" baz_inc_spec = DatasetSpec(doc='doc', data_type_inc='Baz', quantity=ZERO_OR_MANY) baz_holder_spec = GroupSpec(doc='doc', data_type_def='BazHolder', datasets=[baz_inc_spec]) self.spec_catalog.register_spec(baz_holder_spec, 'test.yaml') self.type_map.register_container_type(CORE_NAMESPACE, 'BazHolder', BazHolder) self.holder_mapper = ObjectMapper(baz_holder_spec) baz = Baz('MyBaz', [], 'abcdefghijklmnopqrstuvwxyz') holder = BazHolder('holder', [baz]) builder = self.holder_mapper.build(holder, self.manager) expected = GroupBuilder( name='holder', datasets=[DatasetBuilder( name='MyBaz', data=[], attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz', 'data_type': 'Baz', 'namespace': 'test_core', 'object_id': baz.object_id} )] ) self.assertBuilderEqual(builder, expected) def test_append(self): with h5py.File('test.h5', 'w') as file: test_ds = file.create_dataset('test_ds', data=[1, 2, 3], chunks=True, maxshape=(None,)) container = Baz('MyBaz', test_ds, 'abcdefghijklmnopqrstuvwxyz') container.append(4) np.testing.assert_array_equal(container[:], [1, 2, 3, 4]) os.remove('test.h5') def test_extend(self): with h5py.File('test.h5', 'w') as file: test_ds = file.create_dataset('test_ds', data=[1, 2, 3], chunks=True, maxshape=(None,)) container = Baz('MyBaz', test_ds, 'abcdefghijklmnopqrstuvwxyz') container.extend([4, 5]) np.testing.assert_array_equal(container[:], [1, 2, 3, 4, 5]) os.remove('test.h5') class BazScalar(Data): @docval({'name': 'name', 'type': str, 'doc': 'the name of this BazScalar'}, {'name': 'data', 'type': int, 'doc': 'some data'}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) class TestDataMapScalar(TestCase): def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() 
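# The setUp methods in this file all repeat the same wiring of catalog, namespace, type map,
# and build manager. Below is a hedged sketch of that pattern as a helper; the function name
# _make_manager is illustrative only (hdmf does not provide it), and the imports duplicate
# the module-level ones above.
from hdmf.build import BuildManager, TypeMap
from hdmf.spec import NamespaceCatalog, SpecCatalog, SpecNamespace
from tests.unit.utils import CORE_NAMESPACE

def _make_manager(spec, container_cls, data_type):
    """Register one spec/container pair under CORE_NAMESPACE and return a BuildManager."""
    catalog = SpecCatalog()
    catalog.register_spec(spec, 'test.yaml')
    namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}],
                              version='0.1.0', catalog=catalog)
    ns_catalog = NamespaceCatalog()
    ns_catalog.add_namespace(CORE_NAMESPACE, namespace)
    type_map = TypeMap(ns_catalog)
    type_map.register_container_type(CORE_NAMESPACE, data_type, container_cls)
    return BuildManager(type_map)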
self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'BazScalar', BazScalar) self.type_map.register_map(BazScalar, ObjectMapper) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): self.baz_spec = DatasetSpec( doc='a BazScalar type', dtype='int', name='MyBaz', data_type_def='BazScalar' ) def test_construct_scalar_dataset(self): """Test constructing a Data object with an h5py.Dataset with shape (1, ) for scalar spec.""" with h5py.File('test.h5', 'w') as file: test_ds = file.create_dataset('test_ds', data=[1]) expected = BazScalar( name='MyBaz', data=1, ) builder = DatasetBuilder( name='MyBaz', data=test_ds, attributes={'data_type': 'BazScalar', 'namespace': CORE_NAMESPACE, 'object_id': expected.object_id}, ) container = self.mapper.construct(builder, self.manager) self.assertTrue(np.issubdtype(type(container.data), np.integer)) # as opposed to h5py.Dataset self.assertContainerEqual(container, expected) os.remove('test.h5') class BazScalarCompound(Data): @docval({'name': 'name', 'type': str, 'doc': 'the name of this BazScalar'}, {'name': 'data', 'type': 'array_data', 'doc': 'some data'}) def __init__(self, **kwargs): call_docval_func(super().__init__, kwargs) class TestDataMapScalarCompound(TestCase): def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'BazScalarCompound', BazScalarCompound) self.type_map.register_map(BazScalarCompound, ObjectMapper) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): self.baz_spec = DatasetSpec( doc='a BazScalarCompound type', dtype=[ DtypeSpec( name='id', dtype='uint64', doc='The unique identifier in this table.' ), DtypeSpec( name='attr1', dtype='text', doc='A text attribute.' 
), ], name='MyBaz', data_type_def='BazScalarCompound', ) def test_construct_scalar_compound_dataset(self): """Test construct on a compound h5py.Dataset with shape (1, ) for scalar spec does not resolve the data.""" with h5py.File('test.h5', 'w') as file: comp_type = np.dtype([('id', np.uint64), ('attr1', h5py.special_dtype(vlen=str))]) test_ds = file.create_dataset( name='test_ds', data=np.array((1, 'text'), dtype=comp_type), shape=(1, ), dtype=comp_type ) expected = BazScalarCompound( name='MyBaz', data=(1, 'text'), ) builder = DatasetBuilder( name='MyBaz', data=test_ds, attributes={'data_type': 'BazScalarCompound', 'namespace': CORE_NAMESPACE, 'object_id': expected.object_id}, ) container = self.mapper.construct(builder, self.manager) self.assertEqual(type(container.data), h5py.Dataset) self.assertContainerEqual(container, expected) os.remove('test.h5') class BuildDatasetOfReferencesMixin: def setUp(self): self.setUpBazSpec() self.foo_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Foo', datasets=[ DatasetSpec(name='my_data', doc='an example dataset', dtype='int') ], attributes=[ AttributeSpec(name='attr1', doc='an example string attribute', dtype='text'), AttributeSpec(name='attr2', doc='an example int attribute', dtype='int'), AttributeSpec(name='attr3', doc='an example float attribute', dtype='float') ] ) self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') self.spec_catalog.register_spec(self.foo_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Baz', Baz) self.type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) self.type_map.register_map(Baz, ObjectMapper) self.type_map.register_map(Foo, ObjectMapper) self.manager = BuildManager(self.type_map) class TestBuildUntypedDatasetOfReferences(BuildDatasetOfReferencesMixin, TestCase): def setUpBazSpec(self): self.baz_spec = DatasetSpec( doc='a list of references to Foo objects', dtype=None, name='MyBaz', shape=[None], data_type_def='Baz', attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] ) def test_build(self): ''' Test default mapping functionality when no attributes are nested ''' foo = Foo('my_foo1', [1, 2, 3], 'string', 10) baz = Baz('MyBaz', [foo, None], 'abcdefghijklmnopqrstuvwxyz') foo_builder = self.manager.build(foo) baz_builder = self.manager.build(baz, root=True) expected = DatasetBuilder('MyBaz', [ReferenceBuilder(foo_builder), None], attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz', 'data_type': 'Baz', 'namespace': CORE_NAMESPACE, 'object_id': baz.object_id}) self.assertBuilderEqual(baz_builder, expected) class TestBuildCompoundDatasetOfReferences(BuildDatasetOfReferencesMixin, TestCase): def setUpBazSpec(self): self.baz_spec = DatasetSpec( doc='a list of references to Foo objects', dtype=[ DtypeSpec( name='id', dtype='uint64', doc='The unique identifier in this table.' ), DtypeSpec( name='foo', dtype=RefSpec('Foo', 'object'), doc='The foo in this table.' 
), ], name='MyBaz', shape=[None], data_type_def='Baz', attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] ) def test_build(self): ''' Test default mapping functionality when no attributes are nested ''' foo = Foo('my_foo1', [1, 2, 3], 'string', 10) baz = Baz('MyBaz', [(1, foo)], 'abcdefghijklmnopqrstuvwxyz') foo_builder = self.manager.build(foo) baz_builder = self.manager.build(baz, root=True) expected = DatasetBuilder('MyBaz', [(1, ReferenceBuilder(foo_builder))], attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz', 'data_type': 'Baz', 'namespace': CORE_NAMESPACE, 'object_id': baz.object_id}) self.assertBuilderEqual(baz_builder, expected) class TestBuildTypedDatasetOfReferences(BuildDatasetOfReferencesMixin, TestCase): def setUpBazSpec(self): self.baz_spec = DatasetSpec( doc='a list of references to Foo objects', dtype=RefSpec('Foo', 'object'), name='MyBaz', shape=[None], data_type_def='Baz', attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] ) def test_build(self): ''' Test default mapping functionality when no attributes are nested ''' foo = Foo('my_foo1', [1, 2, 3], 'string', 10) baz = Baz('MyBaz', [foo], 'abcdefghijklmnopqrstuvwxyz') foo_builder = self.manager.build(foo) baz_builder = self.manager.build(baz, root=True) expected = DatasetBuilder('MyBaz', [ReferenceBuilder(foo_builder)], attributes={'baz_attr': 'abcdefghijklmnopqrstuvwxyz', 'data_type': 'Baz', 'namespace': CORE_NAMESPACE, 'object_id': baz.object_id}) self.assertBuilderEqual(baz_builder, expected) class TestBuildDatasetOfReferencesUnbuiltTarget(BuildDatasetOfReferencesMixin, TestCase): def setUpBazSpec(self): self.baz_spec = DatasetSpec( doc='a list of references to Foo objects', dtype=None, name='MyBaz', shape=[None], data_type_def='Baz', attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] ) def test_build(self): ''' Test default mapping functionality when no attributes are nested ''' foo = Foo('my_foo1', [1, 2, 3], 'string', 10) baz = Baz('MyBaz', [foo], 'abcdefghijklmnopqrstuvwxyz') msg = "MyBaz (MyBaz): Could not find already-built Builder for Foo 'my_foo1' in BuildManager" with self.assertRaisesWith(ReferenceTargetNotBuiltError, msg): self.manager.build(baz, root=True) class TestDataIOEdgeCases(TestCase): def setUp(self): self.setUpBazSpec() self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.baz_spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=self.spec_catalog) self.namespace_catalog = NamespaceCatalog() self.namespace_catalog.add_namespace(CORE_NAMESPACE, self.namespace) self.type_map = TypeMap(self.namespace_catalog) self.type_map.register_container_type(CORE_NAMESPACE, 'Baz', Baz) self.type_map.register_map(Baz, ObjectMapper) self.manager = BuildManager(self.type_map) self.mapper = ObjectMapper(self.baz_spec) def setUpBazSpec(self): self.baz_spec = DatasetSpec( doc='an Baz type', dtype=None, name='MyBaz', data_type_def='Baz', shape=[None], attributes=[AttributeSpec('baz_attr', 'an example string attribute', 'text')] ) def test_build_dataio(self): """Test building of a dataset with data_type and no dtype with value DataIO.""" container = Baz('my_baz', H5DataIO(['a', 'b', 'c', 'd'], chunks=True), 'value1') builder = self.type_map.build(container) self.assertIsInstance(builder.get('data'), H5DataIO) def test_build_datachunkiterator(self): """Test building of a dataset with data_type and no dtype with value DataChunkIterator.""" 
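# Hedged aside: a DataChunkIterator carries lazily produced data, possibly of unknown final
# length, and the type map is expected to pass it through to the builder unchanged (asserted
# below); test_build_dataio_datachunkiterator additionally wraps it in H5DataIO so that
# chunking options travel with the iterator. Minimal construction mirroring the values used
# in these tests (variable names are illustrative only):
_demo_iter = DataChunkIterator(['a', 'b', 'c', 'd'])
_demo_wrapped = H5DataIO(_demo_iter, chunks=True)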
container = Baz('my_baz', DataChunkIterator(['a', 'b', 'c', 'd']), 'value1') builder = self.type_map.build(container) self.assertIsInstance(builder.get('data'), DataChunkIterator) def test_build_dataio_datachunkiterator(self): # hdmf#512 """Test building of a dataset with no dtype and no data_type with value DataIO wrapping a DCI.""" container = Baz('my_baz', H5DataIO(DataChunkIterator(['a', 'b', 'c', 'd']), chunks=True), 'value1') builder = self.type_map.build(container) self.assertIsInstance(builder.get('data'), H5DataIO) self.assertIsInstance(builder.get('data').data, DataChunkIterator) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1846273 hdmf-3.1.1/tests/unit/common/0000755000655200065520000000000000000000000016172 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/__init__.py0000644000655200065520000000000000000000000020271 0ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/test_alignedtable.py0000644000655200065520000010054700000000000022225 0ustar00circlecicircleciimport numpy as np from pandas.testing import assert_frame_equal import warnings from hdmf.backends.hdf5 import HDF5IO from hdmf.common import DynamicTable, VectorData, get_manager, AlignedDynamicTable, DynamicTableRegion from hdmf.testing import TestCase, remove_test_file class TestAlignedDynamicTableContainer(TestCase): """ Test the AlignedDynamicTable Container class. NOTE: Functions specific to linked tables, specifically the: * has_foreign_columns * get_foreign_columns * get_linked_tables methods are tested in the test_linkedtables.TestLinkedAlignedDynamicTables class instead of here. 
""" def setUp(self): warnings.simplefilter("always") # Trigger all warnings self.path = 'test_icephys_meta_intracellularrecording.h5' def tearDown(self): remove_test_file(self.path) def test_init(self): """Test that just checks that populating the tables with data works correctly""" AlignedDynamicTable( name='test_aligned_table', description='Test aligned container') def test_init_categories_without_category_tables_error(self): # Test raise error if categories is given without category_tables with self.assertRaisesWith(ValueError, "Categories provided but no category_tables given"): AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', categories=['cat1', 'cat2']) def test_init_length_mismatch_between_categories_and_category_tables(self): # Test length mismatch between categories and category_tables with self.assertRaisesWith(ValueError, "0 category_tables given but 2 categories specified"): AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', categories=['cat1', 'cat2'], category_tables=[]) def test_init_category_table_names_do_not_match_categories(self): # Construct some categories for testing category_names = ['test1', 'test2', 'test3'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=val+t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] # Test add category_table that is not listed in the categories list with self.assertRaisesWith(ValueError, "DynamicTable test3 does not appear in categories ['test1', 'test2', 't3']"): AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', categories=['test1', 'test2', 't3'], # bad name for 'test3' category_tables=categories) def test_init_duplicate_category_table_name(self): # Test duplicate table name with self.assertRaisesWith(ValueError, "Duplicate table name test1 found in input dynamic_tables"): categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=val+t, description=val+t+' description', data=np.arange(10)) for t in ['c1', 'c2', 'c3']] ) for val in ['test1', 'test1', 'test3']] AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', categories=['test1', 'test2', 'test3'], category_tables=categories) def test_init_misaligned_category_tables(self): """Test misaligned category tables""" categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=val+t, description=val+t+' description', data=np.arange(10)) for t in ['c1', 'c2', 'c3']] ) for val in ['test1', 'test2']] categories.append(DynamicTable(name='test3', description="test3 description", columns=[VectorData(name='test3 '+t, description='test3 '+t+' description', data=np.arange(8)) for t in ['c1', 'c2', 'c3']])) with self.assertRaisesWith(ValueError, "Category DynamicTable test3 does not align, it has 8 rows expected 10"): AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', categories=['test1', 'test2', 'test3'], category_tables=categories) def test_init_with_custom_empty_categories(self): """Test that we can create an empty table with custom categories""" category_names = ['test1', 'test2', 'test3'] categories = [DynamicTable(name=val, description=val+" description") for val in category_names] AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories) def 
test_init_with_custom_nonempty_categories(self): """Test that we can create an empty table with custom categories""" category_names = ['test1', 'test2', 'test3'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=val+t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] temp = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories) self.assertEqual(temp.categories, category_names) def test_init_with_custom_nonempty_categories_and_main(self): """ Test that we can create a non-empty table with custom non-empty categories """ category_names = ['test1', 'test2', 'test3'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] temp = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories, columns=[VectorData(name='main_' + t, description='main_'+t+'_description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']]) self.assertEqual(temp.categories, category_names) self.assertTrue('test1' in temp) # test that contains category works self.assertTrue(('test1', 'c1') in temp) # test that contains a column works # test the error case of a tuple with len !=2 with self.assertRaisesWith(ValueError, "Expected tuple of strings of length 2 got tuple of length 3"): ('test1', 'c1', 't3') in temp self.assertTupleEqual(temp.colnames, ('main_c1', 'main_c2', 'main_c3')) # confirm column names def test_init_with_custom_misaligned_categories(self): """Test that we cannot create an empty table with custom categories""" num_rows = 10 val1 = 'test1' val2 = 'test2' categories = [DynamicTable(name=val1, description=val1+" description", columns=[VectorData(name=val1+t, description=val1+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']]), DynamicTable(name=val2, description=val2+" description", columns=[VectorData(name=val2+t, description=val2+t+' description', data=np.arange(num_rows+1)) for t in ['c1', 'c2', 'c3']]) ] with self.assertRaisesWith(ValueError, "Category DynamicTable test2 does not align, it has 11 rows expected 10"): AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories) def test_init_with_duplicate_custom_categories(self): """Test that we can create an empty table with custom categories""" category_names = ['test1', 'test1'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=val+t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] with self.assertRaisesWith(ValueError, "Duplicate table name test1 found in input dynamic_tables"): AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories) def test_init_with_bad_custom_categories(self): """Test that we cannot provide a category that is not a DynamicTable""" num_rows = 10 categories = [ # good category DynamicTable(name='test1', description="test1 description", columns=[VectorData(name='test1'+t, description='test1' + t + ' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ), # use a list as a bad category example [0, 1, 2]] with self.assertRaisesWith(ValueError, "Category 
table with index 1 is not a DynamicTable"): AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories) def test_round_trip_container(self): """Test read and write the container by itself""" category_names = ['test1', 'test2', 'test3'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] curr = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories) with HDF5IO(self.path, manager=get_manager(), mode='w') as io: io.write(curr) with HDF5IO(self.path, manager=get_manager(), mode='r') as io: incon = io.read() self.assertListEqual(incon.categories, curr.categories) for n in category_names: assert_frame_equal(incon[n], curr[n]) def test_add_category(self): """Test that we can correct a non-empty category to an existing table""" category_names = ['test1', 'test2', 'test3'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=val+t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] adt = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories[0:2]) self.assertListEqual(adt.categories, category_names[0:2]) adt.add_category(categories[-1]) self.assertListEqual(adt.categories, category_names) def test_add_category_misaligned_rows(self): """Test that we can correct a non-empty category to an existing table""" category_names = ['test1', 'test2'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=val+t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] adt = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories) self.assertListEqual(adt.categories, category_names) with self.assertRaisesWith(ValueError, "New category DynamicTable does not align, it has 8 rows expected 10"): adt.add_category(DynamicTable(name='test3', description='test3_description', columns=[VectorData(name='test3_'+t, description='test3 '+t+' description', data=np.arange(num_rows - 2)) for t in ['c1', 'c2', 'c3'] ])) def test_add_category_already_in_table(self): category_names = ['test1', 'test2', 'test2'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=val+t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] adt = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories[0:2]) self.assertListEqual(adt.categories, category_names[0:2]) with self.assertRaisesWith(ValueError, "Category test2 already in the table"): adt.add_category(categories[-1]) def test_add_column(self): adt = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', columns=[VectorData(name='test_'+t, description='test_'+t+' description', data=np.arange(10)) for t in ['c1', 'c2', 'c3']]) # Test successful add adt.add_column(name='testA', description='testA', data=np.arange(10)) self.assertTupleEqual(adt.colnames, ('test_c1', 'test_c2', 'test_c3', 'testA')) def test_add_column_bad_category(self): """Test add column with bad 
category""" adt = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', columns=[VectorData(name='test_'+t, description='test_'+t+' description', data=np.arange(10)) for t in ['c1', 'c2', 'c3']]) with self.assertRaisesWith(KeyError, "'Category mycat not in table'"): adt.add_column(category='mycat', name='testA', description='testA', data=np.arange(10)) def test_add_column_bad_length(self): """Test add column that is too short""" adt = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', columns=[VectorData(name='test_'+t, description='test_'+t+' description', data=np.arange(10)) for t in ['c1', 'c2', 'c3']]) # Test successful add with self.assertRaisesWith(ValueError, "column must have the same number of rows as 'id'"): adt.add_column(name='testA', description='testA', data=np.arange(8)) def test_add_column_to_subcategory(self): """Test adding a column to a subcategory""" category_names = ['test1', 'test2', 'test3'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=val+t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] adt = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories) self.assertListEqual(adt.categories, category_names) # Test successful add adt.add_column(category='test2', name='testA', description='testA', data=np.arange(10)) self.assertTupleEqual(adt.get_category('test2').colnames, ('test2c1', 'test2c2', 'test2c3', 'testA')) def test_add_row(self): """Test adding a row to a non_empty table""" category_names = ['test1', ] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2']] ) for val in category_names] temp = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories, columns=[VectorData(name='main_' + t, description='main_'+t+'_description', data=np.arange(num_rows)) for t in ['c1', 'c2']]) self.assertListEqual(temp.categories, category_names) # Test successful add temp.add_row(test1=dict(c1=1, c2=2), main_c1=3, main_c2=5) self.assertListEqual(temp[10].iloc[0].tolist(), [3, 5, 10, 1, 2]) # Test successful add version 2 temp.add_row(data=dict(test1=dict(c1=1, c2=2), main_c1=4, main_c2=5)) self.assertListEqual(temp[11].iloc[0].tolist(), [4, 5, 11, 1, 2]) # Test missing categories data with self.assertRaises(KeyError) as ke: temp.add_row(main_c1=3, main_c2=5) self.assertTrue("row data keys do not match" in str(ke.exception)) def test_get_item(self): """Test getting elements from the table""" category_names = ['test1', ] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=t, description=val+t+' description', data=np.arange(num_rows) + i + 3) for i, t in enumerate(['c1', 'c2'])] ) for val in category_names] temp = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories, columns=[VectorData(name='main_' + t, description='main_'+t+'_description', data=np.arange(num_rows)+2) for t in ['c1', 'c2']]) self.assertListEqual(temp.categories, category_names) # Test slicing with a single index self.assertListEqual(temp[5].iloc[0].tolist(), [7, 7, 5, 8, 9]) # Test slice with list self.assertListEqual(temp[[5, 7]].iloc[0].tolist(), [7, 7, 5, 8, 
9]) self.assertListEqual(temp[[5, 7]].iloc[1].tolist(), [9, 9, 7, 10, 11]) # Test slice with slice self.assertListEqual(temp[5:7].iloc[0].tolist(), [7, 7, 5, 8, 9]) self.assertListEqual(temp[5:7].iloc[1].tolist(), [8, 8, 6, 9, 10]) # Test slice with numpy index arrya self.assertListEqual(temp[np.asarray([5, 8])].iloc[0].tolist(), [7, 7, 5, 8, 9]) self.assertListEqual(temp[np.asarray([5, 8])].iloc[1].tolist(), [10, 10, 8, 11, 12]) # Test slicing for a single column self.assertListEqual(temp['main_c1'][:].tolist(), (np.arange(num_rows)+2).tolist()) # Test slicing for a single category assert_frame_equal(temp['test1'], categories[0].to_dataframe()) # Test getting the main table assert_frame_equal(temp[None], temp.to_dataframe()) # Test getting a specific column self.assertListEqual(temp['test1', 'c1'][:].tolist(), (np.arange(num_rows) + 3).tolist()) # Test getting a specific cell self.assertEqual(temp[None, 'main_c1', 1], 3) self.assertEqual(temp[1, None, 'main_c1'], 3) # Test bad selection tuple with self.assertRaisesWith(ValueError, "Expected tuple of length 2 of the form [category, column], [row, category], " "[row, (category, column)] or a tuple of length 3 of the form " "[category, column, row], [row, category, column]"): temp[('main_c1',)] # Test selecting a single cell or row of a category table by having a # [int, str] or [int, (str, str)] type selection # Select row 0 from category 'test1' re = temp[0, 'test1'] self.assertListEqual(re.columns.to_list(), ['id', 'c1', 'c2']) self.assertListEqual(re.index.names, [('test_aligned_table', 'id')]) self.assertListEqual(re.values.tolist()[0], [0, 3, 4]) # Select a single cell from a columm self.assertEqual(temp[1, ('test_aligned_table', 'main_c1')], 3) def test_to_dataframe(self): """Test that the to_dataframe method works""" category_names = ['test1', 'test2', 'test3'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] adt = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories, columns=[VectorData(name='main_' + t, description='main_'+t+'_description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']]) # Test the to_dataframe method with default settings tdf = adt.to_dataframe() self.assertListEqual(tdf.index.tolist(), list(range(10))) self.assertTupleEqual(tdf.index.name, ('test_aligned_table', 'id')) expected_cols = [('test_aligned_table', 'main_c1'), ('test_aligned_table', 'main_c2'), ('test_aligned_table', 'main_c3'), ('test1', 'id'), ('test1', 'c1'), ('test1', 'c2'), ('test1', 'c3'), ('test2', 'id'), ('test2', 'c1'), ('test2', 'c2'), ('test2', 'c3'), ('test3', 'id'), ('test3', 'c1'), ('test3', 'c2'), ('test3', 'c3')] tdf_cols = tdf.columns.tolist() for v in zip(expected_cols, tdf_cols): self.assertTupleEqual(v[0], v[1]) # test the to_dataframe method with ignore_category_ids set to True tdf = adt.to_dataframe(ignore_category_ids=True) self.assertListEqual(tdf.index.tolist(), list(range(10))) self.assertTupleEqual(tdf.index.name, ('test_aligned_table', 'id')) expected_cols = [('test_aligned_table', 'main_c1'), ('test_aligned_table', 'main_c2'), ('test_aligned_table', 'main_c3'), ('test1', 'c1'), ('test1', 'c2'), ('test1', 'c3'), ('test2', 'c1'), ('test2', 'c2'), ('test2', 'c3'), ('test3', 'c1'), ('test3', 'c2'), ('test3', 'c3')] tdf_cols = tdf.columns.tolist() for v in zip(expected_cols, tdf_cols): 
self.assertTupleEqual(v[0], v[1]) def test_nested_aligned_dynamic_table_not_allowed(self): """ Test that using and AlignedDynamicTable as category for an AlignedDynamicTable is not allowed """ # create an AlignedDynamicTable as category subsubcol1 = VectorData(name='sub_sub_column1', description='test sub sub column', data=['test11', 'test12']) sub_category = DynamicTable(name='sub_category1', description='test subcategory table', columns=[subsubcol1, ]) subcol1 = VectorData(name='sub_column1', description='test-subcolumn', data=['test1', 'test2']) adt_category = AlignedDynamicTable( name='category1', description='test using AlignedDynamicTable as a category', columns=[subcol1, ], category_tables=[sub_category, ]) # Create a regular column for our main AlignedDynamicTable col1 = VectorData(name='column1', description='regular test column', data=['test1', 'test2']) # test 1: Make sure we can't add the AlignedDynamicTable category on init msg = ("Category table with index %i is an AlignedDynamicTable. " "Nesting of AlignedDynamicTable is currently not supported." % 0) with self.assertRaisesWith(ValueError, msg): # create the nested AlignedDynamicTable with our adt_category as a sub-category AlignedDynamicTable( name='nested_adt', description='test nesting AlignedDynamicTable', columns=[col1, ], category_tables=[adt_category, ]) # test 2: Make sure we can't add the AlignedDynamicTable category via add_category adt = AlignedDynamicTable( name='nested_adt', description='test nesting AlignedDynamicTable', columns=[col1, ]) msg = "Category is an AlignedDynamicTable. Nesting of AlignedDynamicTable is currently not supported." with self.assertRaisesWith(ValueError, msg): adt.add_category(adt_category) def test_dynamictable_region_to_aligneddynamictable(self): """ Test to ensure data is being retrieved correctly when pointing to an AlignedDynamicTable. In particular, make sure that all columns are being used, including those of the category tables, not just the ones from the main table. 
""" temp_table = DynamicTable(name='t1', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', description='my test table', category_tables=[temp_table], colnames=['a1', 'a2'], columns=[VectorData(name='a1', description='c1', data=np.arange(4)), VectorData(name='a2', description='c1', data=np.arange(4))]) dtr = DynamicTableRegion(name='test', description='test', data=np.arange(4), table=temp_aligned_table) dtr_df = dtr[:] # Full number of rows self.assertEqual(len(dtr_df), 4) # Test num columns: 2 columns from the main table, 2 columns from the category, 1 id columns from the category self.assertEqual(len(dtr_df.columns), 5) # Test that the data is correct for i, v in enumerate([('my_aligned_table', 'a1'), ('my_aligned_table', 'a2'), ('t1', 'id'), ('t1', 'c1'), ('t1', 'c2')]): self.assertTupleEqual(dtr_df.columns[i], v) # Test the column data for c in dtr_df.columns: self.assertListEqual(dtr_df[c].to_list(), list(range(4))) def test_get_colnames(self): """ Test the AlignedDynamicTable.get_colnames function """ category_names = ['test1', 'test2', 'test3'] num_rows = 10 categories = [DynamicTable(name=val, description=val+" description", columns=[VectorData(name=t, description=val+t+' description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']] ) for val in category_names] adt = AlignedDynamicTable( name='test_aligned_table', description='Test aligned container', category_tables=categories, columns=[VectorData(name='main_' + t, description='main_'+t+'_description', data=np.arange(num_rows)) for t in ['c1', 'c2', 'c3']]) # Default, only get the colnames of the main table. 
Same as adt.colnames property expected_colnames = ('main_c1', 'main_c2', 'main_c3') self.assertTupleEqual(adt.get_colnames(), expected_colnames) # Same as default because if we don't include the catgories than ignore_category_ids has no effect self.assertTupleEqual(adt.get_colnames(include_category_tables=False, ignore_category_ids=True), expected_colnames) # Full set of columns expected_colnames = [('test_aligned_table', 'main_c1'), ('test_aligned_table', 'main_c2'), ('test_aligned_table', 'main_c3'), ('test1', 'id'), ('test1', 'c1'), ('test1', 'c2'), ('test1', 'c3'), ('test2', 'id'), ('test2', 'c1'), ('test2', 'c2'), ('test2', 'c3'), ('test3', 'id'), ('test3', 'c1'), ('test3', 'c2'), ('test3', 'c3')] self.assertListEqual(adt.get_colnames(include_category_tables=True, ignore_category_ids=False), expected_colnames) # All columns without the id columns of the category tables expected_colnames = [('test_aligned_table', 'main_c1'), ('test_aligned_table', 'main_c2'), ('test_aligned_table', 'main_c3'), ('test1', 'c1'), ('test1', 'c2'), ('test1', 'c3'), ('test2', 'c1'), ('test2', 'c2'), ('test2', 'c3'), ('test3', 'c1'), ('test3', 'c2'), ('test3', 'c3')] self.assertListEqual(adt.get_colnames(include_category_tables=True, ignore_category_ids=True), expected_colnames) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/test_common.py0000644000655200065520000000061700000000000021077 0ustar00circlecicirclecifrom hdmf import Data, Container from hdmf.common import get_type_map from hdmf.testing import TestCase class TestCommonTypeMap(TestCase): def test_base_types(self): tm = get_type_map() cls = tm.get_dt_container_cls('Container', 'hdmf-common') self.assertIs(cls, Container) cls = tm.get_dt_container_cls('Data', 'hdmf-common') self.assertIs(cls, Data) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/test_common_io.py0000644000655200065520000000572700000000000021575 0ustar00circlecicirclecifrom h5py import File from hdmf.backends.hdf5 import HDF5IO from hdmf.common import Container, get_manager from hdmf.spec import NamespaceCatalog from hdmf.testing import TestCase, remove_test_file from tests.unit.utils import get_temp_filepath class TestCacheSpec(TestCase): """Test caching spec specifically with the namespaces provided by hdmf.common. See also TestCacheSpec in tests/unit/test_io_hdf5_h5tools.py. 
""" def setUp(self): self.manager = get_manager() self.path = get_temp_filepath() self.container = Container('dummy') def tearDown(self): remove_test_file(self.path) def test_write_no_cache_spec(self): """Roundtrip test for not writing spec.""" with HDF5IO(self.path, manager=self.manager, mode="a") as io: io.write(self.container, cache_spec=False) with File(self.path, 'r') as f: self.assertNotIn('specifications', f) def test_write_cache_spec(self): """Roundtrip test for writing spec and reading it back in.""" with HDF5IO(self.path, manager=self.manager, mode="a") as io: io.write(self.container) with File(self.path, 'r') as f: self.assertIn('specifications', f) self._check_spec() def test_write_cache_spec_injected(self): """Roundtrip test for writing spec and reading it back in when HDF5IO is passed an open h5py.File.""" with File(self.path, 'w') as fil: with HDF5IO(self.path, manager=self.manager, file=fil, mode='a') as io: io.write(self.container) with File(self.path, 'r') as f: self.assertIn('specifications', f) self._check_spec() def _check_spec(self): ns_catalog = NamespaceCatalog() HDF5IO.load_namespaces(ns_catalog, self.path) self.maxDiff = None for namespace in self.manager.namespace_catalog.namespaces: with self.subTest(namespace=namespace): original_ns = self.manager.namespace_catalog.get_namespace(namespace) cached_ns = ns_catalog.get_namespace(namespace) ns_fields_to_check = list(original_ns.keys()) ns_fields_to_check.remove('schema') # schema fields will not match, so reset for ns_field in ns_fields_to_check: with self.subTest(namespace_field=ns_field): self.assertEqual(original_ns[ns_field], cached_ns[ns_field]) for dt in original_ns.get_registered_types(): with self.subTest(data_type=dt): original_spec = original_ns.get_spec(dt) cached_spec = cached_ns.get_spec(dt) with self.subTest('Data type spec is read back in'): self.assertIsNotNone(cached_spec) with self.subTest('Cached spec matches original spec'): self.assertDictEqual(original_spec, cached_spec) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/test_generate_table.py0000644000655200065520000002630500000000000022552 0ustar00circlecicircleciimport numpy as np import os import shutil import tempfile from hdmf.backends.hdf5 import HDF5IO from hdmf.build import BuildManager, TypeMap from hdmf.common import get_type_map, DynamicTable from hdmf.spec import GroupSpec, DatasetSpec, SpecCatalog, SpecNamespace, NamespaceCatalog from hdmf.testing import TestCase from hdmf.validate import ValidatorMap from tests.unit.utils import CORE_NAMESPACE class TestDynamicDynamicTable(TestCase): def setUp(self): self.dt_spec = GroupSpec( 'A test extension that contains a dynamic table', data_type_def='TestTable', data_type_inc='DynamicTable', datasets=[ DatasetSpec( data_type_inc='VectorData', name='my_col', doc='a test column', dtype='float' ), DatasetSpec( data_type_inc='VectorData', name='indexed_col', doc='a test column', dtype='float' ), DatasetSpec( data_type_inc='VectorIndex', name='indexed_col_index', doc='a test column', ), DatasetSpec( data_type_inc='VectorData', name='optional_col1', doc='a test column', dtype='float', quantity='?', ), DatasetSpec( data_type_inc='VectorData', name='optional_col2', doc='a test column', dtype='float', quantity='?', ) ] ) self.dt_spec2 = GroupSpec( 'A test extension that contains a dynamic table', data_type_def='TestDTRTable', data_type_inc='DynamicTable', datasets=[ DatasetSpec( data_type_inc='DynamicTableRegion', 
name='ref_col', doc='a test column', ), DatasetSpec( data_type_inc='DynamicTableRegion', name='indexed_ref_col', doc='a test column', ), DatasetSpec( data_type_inc='VectorIndex', name='indexed_ref_col_index', doc='a test column', ), DatasetSpec( data_type_inc='DynamicTableRegion', name='optional_ref_col', doc='a test column', quantity='?' ), DatasetSpec( data_type_inc='DynamicTableRegion', name='optional_indexed_ref_col', doc='a test column', quantity='?' ), DatasetSpec( data_type_inc='VectorIndex', name='optional_indexed_ref_col_index', doc='a test column', quantity='?' ), DatasetSpec( data_type_inc='VectorData', name='optional_col3', doc='a test column', dtype='float', quantity='?', ) ] ) from hdmf.spec.write import YAMLSpecWriter writer = YAMLSpecWriter(outdir='.') self.spec_catalog = SpecCatalog() self.spec_catalog.register_spec(self.dt_spec, 'test.yaml') self.spec_catalog.register_spec(self.dt_spec2, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [ dict( namespace='hdmf-common', ), dict(source='test.yaml'), ], version='0.1.0', catalog=self.spec_catalog ) self.test_dir = tempfile.mkdtemp() spec_fpath = os.path.join(self.test_dir, 'test.yaml') namespace_fpath = os.path.join(self.test_dir, 'test-namespace.yaml') writer.write_spec(dict(groups=[self.dt_spec, self.dt_spec2]), spec_fpath) writer.write_namespace(self.namespace, namespace_fpath) self.namespace_catalog = NamespaceCatalog() hdmf_typemap = get_type_map() self.type_map = TypeMap(self.namespace_catalog) self.type_map.merge(hdmf_typemap, ns_catalog=True) self.type_map.load_namespaces(namespace_fpath) self.manager = BuildManager(self.type_map) self.TestTable = self.type_map.get_dt_container_cls('TestTable', CORE_NAMESPACE) self.TestDTRTable = self.type_map.get_dt_container_cls('TestDTRTable', CORE_NAMESPACE) def tearDown(self) -> None: shutil.rmtree(self.test_dir) def test_dynamic_table(self): assert issubclass(self.TestTable, DynamicTable) assert self.TestTable.__columns__[0] == dict( name='my_col', description='a test column', required=True ) def test_forbids_incorrect_col(self): test_table = self.TestTable(name='test_table', description='my test table') with self.assertRaises(ValueError): test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], incorrect_col=5) def test_dynamic_column(self): test_table = self.TestTable(name='test_table', description='my test table') test_table.add_column('dynamic_column', 'this is a dynamic column') test_table.add_row( my_col=3.0, indexed_col=[1.0, 3.0], dynamic_column=4, optional_col2=.5, ) test_table.add_row( my_col=4.0, indexed_col=[2.0, 4.0], dynamic_column=4, optional_col2=.5, ) np.testing.assert_array_equal(test_table['indexed_col'].target.data, [1., 3., 2., 4.]) np.testing.assert_array_equal(test_table['dynamic_column'].data, [4, 4]) def test_optional_col(self): test_table = self.TestTable(name='test_table', description='my test table') test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) def test_dynamic_table_region(self): test_table = self.TestTable(name='test_table', description='my test table') test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) test_dtr_table = self.TestDTRTable(name='test_dtr_table', description='my table', target_tables={'ref_col': test_table, 'indexed_ref_col': test_table}) self.assertIs(test_dtr_table['ref_col'].table, test_table) 
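# the indexed reference column is returned as a VectorIndex, so its DynamicTableRegion is reached via .target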
self.assertIs(test_dtr_table['indexed_ref_col'].target.table, test_table) test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1]) test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1]) np.testing.assert_array_equal(test_dtr_table['indexed_ref_col'].target.data, [0, 1, 0, 1]) np.testing.assert_array_equal(test_dtr_table['ref_col'].data, [0, 0]) def test_dynamic_table_region_optional(self): test_table = self.TestTable(name='test_table', description='my test table') test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) test_dtr_table = self.TestDTRTable(name='test_dtr_table', description='my table', target_tables={'optional_ref_col': test_table, 'optional_indexed_ref_col': test_table}) self.assertIs(test_dtr_table['optional_ref_col'].table, test_table) self.assertIs(test_dtr_table['optional_indexed_ref_col'].target.table, test_table) test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_ref_col=0, optional_indexed_ref_col=[0, 1]) test_dtr_table.add_row(ref_col=0, indexed_ref_col=[0, 1], optional_ref_col=0, optional_indexed_ref_col=[0, 1]) np.testing.assert_array_equal(test_dtr_table['optional_indexed_ref_col'].target.data, [0, 1, 0, 1]) np.testing.assert_array_equal(test_dtr_table['optional_ref_col'].data, [0, 0]) def test_dynamic_table_region_bad_target_col(self): test_table = self.TestTable(name='test_table', description='my test table') test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) msg = r"^'bad' is not the name of a predefined column of table .*" with self.assertRaisesRegex(ValueError, msg): self.TestDTRTable(name='test_dtr_table', description='my table', target_tables={'bad': test_table}) def test_dynamic_table_region_non_dtr_target(self): test_table = self.TestTable(name='test_table', description='my test table') test_table.add_row(my_col=3.0, indexed_col=[1.0, 3.0], optional_col2=.5) test_table.add_row(my_col=4.0, indexed_col=[2.0, 4.0], optional_col2=.5) msg = "Column 'optional_col3' must be a DynamicTableRegion to have a target table." 
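# 'optional_col3' is defined as plain VectorData in the spec, so assigning it a target table must fail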
with self.assertRaisesWith(ValueError, msg): self.TestDTRTable(name='test_dtr_table', description='my table', target_tables={'optional_col3': test_table}) def test_roundtrip(self): # NOTE this does not use H5RoundTripMixin because this requires custom validation test_table = self.TestTable(name='test_table', description='my test table') test_table.add_column('dynamic_column', 'this is a dynamic column') test_table.add_row( my_col=3.0, indexed_col=[1.0, 3.0], dynamic_column=4, optional_col2=.5, ) self.filename = os.path.join(self.test_dir, 'test_TestTable.h5') with HDF5IO(self.filename, manager=self.manager, mode='w') as write_io: write_io.write(test_table, cache_spec=True) self.reader = HDF5IO(self.filename, manager=self.manager, mode='r') read_container = self.reader.read() self.assertIsNotNone(str(test_table)) # added as a test to make sure printing works self.assertIsNotNone(str(read_container)) # make sure we get a completely new object self.assertNotEqual(id(test_table), id(read_container)) # the name of the root container of a file is always 'root' (see h5tools.py ROOT_NAME) # thus, ignore the name of the container when comparing original container vs read container self.assertContainerEqual(read_container, test_table, ignore_name=True) builder = self.reader.read_builder() # TODO fix ValueError: No specification for 'Container' in namespace 'test_core' validator = ValidatorMap(self.manager.namespace_catalog.get_namespace(name=CORE_NAMESPACE)) errors = validator.validate(builder) if errors: for err in errors: raise Exception(err) self.reader.close() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/test_linkedtables.py0000644000655200065520000012700000000000000022244 0ustar00circlecicircleci""" Module for testing functions specific to tables containing DynamicTableRegion columns """ import numpy as np from hdmf.common import DynamicTable, AlignedDynamicTable, VectorData, DynamicTableRegion, VectorIndex from hdmf.testing import TestCase from hdmf.utils import docval, popargs, get_docval, call_docval_func from hdmf.common.hierarchicaltable import to_hierarchical_dataframe, drop_id_columns, flatten_column_index from pandas.testing import assert_frame_equal class DynamicTableSingleDTR(DynamicTable): """Test table class that references a single foreign table""" __columns__ = ( {'name': 'child_table_ref1', 'description': 'Column with a references to the next level in the hierarchy', 'required': True, 'index': True, 'table': True}, ) @docval({'name': 'name', 'type': str, 'doc': 'The name of the table'}, {'name': 'child_table1', 'type': DynamicTable, 'doc': 'the child DynamicTable this DynamicTableSingleDTR point to.'}, *get_docval(DynamicTable.__init__, 'id', 'columns', 'colnames')) def __init__(self, **kwargs): # Define default name and description settings kwargs['description'] = (kwargs['name'] + " DynamicTableSingleDTR") # Initialize the DynamicTable call_docval_func(super(DynamicTableSingleDTR, self).__init__, kwargs) if self['child_table_ref1'].target.table is None: self['child_table_ref1'].target.table = popargs('child_table1', kwargs) class DynamicTableMultiDTR(DynamicTable): """Test table class that references multiple related tables""" __columns__ = ( {'name': 'child_table_ref1', 'description': 'Column with a references to the next level in the hierarchy', 'required': True, 'index': True, 'table': True}, {'name': 'child_table_ref2', 'description': 'Column with a references to the next level in the hierarchy', 
'required': True, 'index': True, 'table': True}, ) @docval({'name': 'name', 'type': str, 'doc': 'The name of the table'}, {'name': 'child_table1', 'type': DynamicTable, 'doc': 'the child DynamicTable this DynamicTableSingleDTR point to.'}, {'name': 'child_table2', 'type': DynamicTable, 'doc': 'the child DynamicTable this DynamicTableSingleDTR point to.'}, *get_docval(DynamicTable.__init__, 'id', 'columns', 'colnames')) def __init__(self, **kwargs): # Define default name and description settings kwargs['description'] = (kwargs['name'] + " DynamicTableSingleDTR") # Initialize the DynamicTable call_docval_func(super(DynamicTableMultiDTR, self).__init__, kwargs) if self['child_table_ref1'].target.table is None: self['child_table_ref1'].target.table = popargs('child_table1', kwargs) if self['child_table_ref2'].target.table is None: self['child_table_ref2'].target.table = popargs('child_table2', kwargs) class TestLinkedAlignedDynamicTables(TestCase): """ Test functionality specific to AlignedDynamicTables containing DynamicTableRegion columns. Since these functions only implements front-end convenient functions for DynamicTable we do not need to worry about I/O here (that is tested elsewere), but it is sufficient if we test with container class. The only time I/O becomes relevant is on read in case that, e.g., a h5py.Dataset may behave differently than a numpy array. """ def setUp(self): """ Create basic set of linked tables consisting of aligned_table | +--> category0 ---> table_level_0_0 | +--> category1 ---> table_level_0_1 """ # Level 0 0 table. I.e., first table on level 0 self.table_level0_0 = DynamicTable(name='level0_0', description="level0_0 DynamicTable") self.table_level0_0.add_row(id=10) self.table_level0_0.add_row(id=11) self.table_level0_0.add_row(id=12) self.table_level0_0.add_row(id=13) self.table_level0_0.add_column(data=['tag1', 'tag2', 'tag2', 'tag1', 'tag3', 'tag4', 'tag5'], name='tags', description='custom tags', index=[1, 2, 4, 7]) self.table_level0_0.add_column(data=np.arange(4), name='myid', description='custom ids', index=False) # Level 0 1 table. 
I.e., second table on level 0 self.table_level0_1 = DynamicTable(name='level0_1', description="level0_1 DynamicTable") self.table_level0_1.add_row(id=14) self.table_level0_1.add_row(id=15) self.table_level0_1.add_row(id=16) self.table_level0_1.add_row(id=17) self.table_level0_1.add_column(data=['tag1', 'tag1', 'tag2', 'tag2', 'tag3', 'tag3', 'tag4'], name='tags', description='custom tags', index=[2, 4, 6, 7]) self.table_level0_1.add_column(data=np.arange(4), name='myid', description='custom ids', index=False) # category 0 table self.category0 = DynamicTableSingleDTR(name='category0', child_table1=self.table_level0_0) self.category0.add_row(id=0, child_table_ref1=[0, ]) self.category0.add_row(id=1, child_table_ref1=[1, 2]) self.category0.add_row(id=1, child_table_ref1=[3, ]) self.category0.add_column(data=[10, 11, 12], name='filter', description='filter value', index=False) # category 1 table self.category1 = DynamicTableSingleDTR(name='category1', child_table1=self.table_level0_1) self.category1.add_row(id=0, child_table_ref1=[0, 1]) self.category1.add_row(id=1, child_table_ref1=[2, 3]) self.category1.add_row(id=1, child_table_ref1=[1, 3]) self.category1.add_column(data=[1, 2, 3], name='filter', description='filter value', index=False) # Aligned table self.aligned_table = AlignedDynamicTable(name='my_aligned_table', description='my test table', columns=[VectorData(name='a1', description='a1', data=np.arange(3)), ], colnames=['a1', ], category_tables=[self.category0, self.category1]) def tearDown(self): del self.table_level0_0 del self.table_level0_1 del self.category0 del self.category1 del self.aligned_table def test_to_hierarchical_dataframe(self): """Test that converting an AlignedDynamicTable with links works""" hier_df = to_hierarchical_dataframe(self.aligned_table) self.assertListEqual(hier_df.columns.to_list(), [('level0_0', 'id'), ('level0_0', 'tags'), ('level0_0', 'myid')]) self.assertListEqual(hier_df.index.names, [('my_aligned_table', 'id'), ('my_aligned_table', ('my_aligned_table', 'a1')), ('my_aligned_table', ('category0', 'id')), ('my_aligned_table', ('category0', 'filter')), ('my_aligned_table', ('category1', 'id')), ('my_aligned_table', ('category1', 'child_table_ref1')), ('my_aligned_table', ('category1', 'filter'))]) self.assertListEqual(hier_df.index.to_list(), [(0, 0, 0, 10, 0, (0, 1), 1), (1, 1, 1, 11, 1, (2, 3), 2), (1, 1, 1, 11, 1, (2, 3), 2), (2, 2, 1, 12, 1, (1, 3), 3)]) self.assertListEqual(hier_df[('level0_0', 'tags')].values.tolist(), [['tag1'], ['tag2'], ['tag2', 'tag1'], ['tag3', 'tag4', 'tag5']]) def test_has_foreign_columns_in_category_tables(self): """Test confirming working order for DynamicTableRegions in subtables""" self.assertTrue(self.aligned_table.has_foreign_columns()) self.assertFalse(self.aligned_table.has_foreign_columns(ignore_category_tables=True)) def test_has_foreign_columns_false(self): """Test false if there are no DynamicTableRegionColumns""" temp_table = DynamicTable(name='t1', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', description='my test table', category_tables=[temp_table], colnames=['a1', 'a2'], columns=[VectorData(name='a1', description='c1', data=np.arange(4)), VectorData(name='a2', description='c2', data=np.arange(4))]) self.assertFalse(temp_aligned_table.has_foreign_columns()) 
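# with no DynamicTableRegion columns anywhere, ignoring the category tables makes no difference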
self.assertFalse(temp_aligned_table.has_foreign_columns(ignore_category_tables=True)) def test_has_foreign_column_in_main_table(self): temp_table = DynamicTable(name='t1', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', description='my test table', category_tables=[temp_table], colnames=['a1', 'a2'], columns=[VectorData(name='a1', description='c1', data=np.arange(4)), DynamicTableRegion(name='a2', description='c2', data=np.arange(4), table=temp_table)]) self.assertTrue(temp_aligned_table.has_foreign_columns()) self.assertTrue(temp_aligned_table.has_foreign_columns(ignore_category_tables=True)) def test_get_foreign_columns(self): # check without subcateogries foreign_cols = self.aligned_table.get_foreign_columns(ignore_category_tables=True) self.assertListEqual(foreign_cols, []) # check with subcateogries foreign_cols = self.aligned_table.get_foreign_columns() self.assertEqual(len(foreign_cols), 2) for i, v in enumerate([('category0', 'child_table_ref1'), ('category1', 'child_table_ref1')]): self.assertTupleEqual(foreign_cols[i], v) def test_get_foreign_columns_none(self): """Test false if there are no DynamicTableRegionColumns""" temp_table = DynamicTable(name='t1', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', description='my test table', category_tables=[temp_table], colnames=['a1', 'a2'], columns=[VectorData(name='a1', description='c1', data=np.arange(4)), VectorData(name='a2', description='c2', data=np.arange(4))]) self.assertListEqual(temp_aligned_table.get_foreign_columns(), []) self.assertListEqual(temp_aligned_table.get_foreign_columns(ignore_category_tables=True), []) def test_get_foreign_column_in_main_and_category_table(self): temp_table0 = DynamicTable(name='t0', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) temp_table = DynamicTable(name='t1', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), DynamicTableRegion(name='c2', description='c2', data=np.arange(4), table=temp_table0)]) temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', description='my test table', category_tables=[temp_table], colnames=['a1', 'a2'], columns=[VectorData(name='a1', description='c1', data=np.arange(4)), DynamicTableRegion(name='a2', description='c2', data=np.arange(4), table=temp_table)]) # We should get both the DynamicTableRegion from the main table and the category 't1' self.assertListEqual(temp_aligned_table.get_foreign_columns(), [(None, 'a2'), ('t1', 'c2')]) # We should only get the column from the main table self.assertListEqual(temp_aligned_table.get_foreign_columns(ignore_category_tables=True), [(None, 'a2')]) def test_get_linked_tables(self): # check without subcateogries linked_table = self.aligned_table.get_linked_tables(ignore_category_tables=True) self.assertListEqual(linked_table, []) # check with subcateogries linked_tables = self.aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) self.assertTupleEqual((linked_tables[0].source_table.name, linked_tables[0].source_column.name, 
linked_tables[0].target_table.name), ('category0', 'child_table_ref1', 'level0_0')) self.assertTupleEqual((linked_tables[1].source_table.name, linked_tables[1].source_column.name, linked_tables[1].target_table.name), ('category1', 'child_table_ref1', 'level0_1')) def test_get_linked_tables_none(self): """Test false if there are no DynamicTableRegionColumns""" temp_table = DynamicTable(name='t1', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', description='my test table', category_tables=[temp_table], colnames=['a1', 'a2'], columns=[VectorData(name='a1', description='c1', data=np.arange(4)), VectorData(name='a2', description='c2', data=np.arange(4))]) self.assertListEqual(temp_aligned_table.get_linked_tables(), []) self.assertListEqual(temp_aligned_table.get_linked_tables(ignore_category_tables=True), []) def test_get_linked_tables_complex_link(self): temp_table0 = DynamicTable(name='t0', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) temp_table = DynamicTable(name='t1', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), DynamicTableRegion(name='c2', description='c2', data=np.arange(4), table=temp_table0)]) temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', description='my test table', category_tables=[temp_table], colnames=['a1', 'a2'], columns=[VectorData(name='a1', description='c1', data=np.arange(4)), DynamicTableRegion(name='a2', description='c2', data=np.arange(4), table=temp_table)]) # NOTE: in this example templ_aligned_table both points to temp_table and at the # same time contains temp_table as a category. 
This could lead to temp_table # visited multiple times and we want to make sure this doesn't happen # We should get both the DynamicTableRegion from the main table and the category 't1' linked_tables = temp_aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) for i, v in enumerate([('my_aligned_table', 'a2', 't1'), ('t1', 'c2', 't0')]): self.assertTupleEqual((linked_tables[i].source_table.name, linked_tables[i].source_column.name, linked_tables[i].target_table.name), v) # Now, since our main table links to the category table the result should remain the same # even if we ignore the category table linked_tables = temp_aligned_table.get_linked_tables(ignore_category_tables=True) self.assertEqual(len(linked_tables), 2) for i, v in enumerate([('my_aligned_table', 'a2', 't1'), ('t1', 'c2', 't0')]): self.assertTupleEqual((linked_tables[i].source_table.name, linked_tables[i].source_column.name, linked_tables[i].target_table.name), v) def test_get_linked_tables_simple_link(self): temp_table0 = DynamicTable(name='t0', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), VectorData(name='c2', description='c2', data=np.arange(4))]) temp_table = DynamicTable(name='t1', description='t1', colnames=['c1', 'c2'], columns=[VectorData(name='c1', description='c1', data=np.arange(4)), DynamicTableRegion(name='c2', description='c2', data=np.arange(4), table=temp_table0)]) temp_aligned_table = AlignedDynamicTable(name='my_aligned_table', description='my test table', category_tables=[temp_table], colnames=['a1', 'a2'], columns=[VectorData(name='a1', description='c1', data=np.arange(4)), DynamicTableRegion(name='a2', description='c2', data=np.arange(4), table=temp_table0)]) # NOTE: in this example temp_aligned_table and temp_table both point to temp_table0 # We should get both the DynamicTableRegion from the main table and the category 't1' linked_tables = temp_aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) for i, v in enumerate([('my_aligned_table', 'a2', 't0'), ('t1', 'c2', 't0')]): self.assertTupleEqual((linked_tables[i].source_table.name, linked_tables[i].source_column.name, linked_tables[i].target_table.name), v) # Since no table ever link to our category temp_table we should only get the link from our # main table here, in contrast to what happens in the test_get_linked_tables_complex_link case linked_tables = temp_aligned_table.get_linked_tables() self.assertEqual(len(linked_tables), 2) for i, v in enumerate([('my_aligned_table', 'a2', 't0'), ]): self.assertTupleEqual((linked_tables[i].source_table.name, linked_tables[i].source_column.name, linked_tables[i].target_table.name), v) class TestHierarchicalTable(TestCase): def setUp(self): """ Create basic set of linked tables consisting of super_parent_table ---> parent_table ---> aligned_table | +--> category0 """ # Level 0 0 table. 
I.e., first table on level 0 self.category0 = DynamicTable(name='level0_0', description="level0_0 DynamicTable") self.category0.add_row(id=10) self.category0.add_row(id=11) self.category0.add_row(id=12) self.category0.add_row(id=13) self.category0.add_column(data=['tag1', 'tag2', 'tag2', 'tag1', 'tag3', 'tag4', 'tag5'], name='tags', description='custom tags', index=[1, 2, 4, 7]) self.category0.add_column(data=np.arange(4), name='myid', description='custom ids', index=False) # Aligned table self.aligned_table = AlignedDynamicTable(name='aligned_table', description='parent_table', columns=[VectorData(name='a1', description='a1', data=np.arange(4)), ], colnames=['a1', ], category_tables=[self.category0, ]) # Parent table self.parent_table = DynamicTable(name='parent_table', description='parent_table', columns=[VectorData(name='p1', description='p1', data=np.arange(4)), DynamicTableRegion(name='l1', description='l1', data=np.arange(4), table=self.aligned_table)]) # Super-parent table dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(4), table=self.parent_table) vi_dtr_sp = VectorIndex(name='sl1_index', data=[1, 2, 3], target=dtr_sp) self.super_parent_table = DynamicTable(name='super_parent_table', description='super_parent_table', columns=[VectorData(name='sp1', description='sp1', data=np.arange(3)), dtr_sp, vi_dtr_sp]) def tearDown(self): del self.category0 del self.aligned_table del self.parent_table def test_to_hierarchical_dataframe_no_dtr_on_top_level(self): # Cover the case where our top dtr is flat (i.e., without a VectorIndex) dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(4), table=self.parent_table) spttable = DynamicTable(name='super_parent_table', description='super_parent_table', columns=[VectorData(name='sp1', description='sp1', data=np.arange(4)), dtr_sp]) hier_df = to_hierarchical_dataframe(spttable).reset_index() expected_columns = [('super_parent_table', 'id'), ('super_parent_table', 'sp1'), ('parent_table', 'id'), ('parent_table', 'p1'), ('aligned_table', 'id'), ('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'id')), ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))] self.assertListEqual(hier_df.columns.to_list(), expected_columns) def test_to_hierarchical_dataframe_indexed_dtr_on_last_level(self): # Parent table dtr_p1 = DynamicTableRegion(name='l1', description='l1', data=np.arange(4), table=self.aligned_table) vi_dtr_p1 = VectorIndex(name='sl1_index', data=[1, 2, 3], target=dtr_p1) p1 = DynamicTable(name='parent_table', description='parent_table', columns=[VectorData(name='p1', description='p1', data=np.arange(3)), dtr_p1, vi_dtr_p1]) # Super-parent table dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(4), table=p1) vi_dtr_sp = VectorIndex(name='sl1_index', data=[1, 2, 3], target=dtr_sp) spt = DynamicTable(name='super_parent_table', description='super_parent_table', columns=[VectorData(name='sp1', description='sp1', data=np.arange(3)), dtr_sp, vi_dtr_sp]) hier_df = to_hierarchical_dataframe(spt).reset_index() expected_columns = [('super_parent_table', 'id'), ('super_parent_table', 'sp1'), ('parent_table', 'id'), ('parent_table', 'p1'), ('aligned_table', 'id'), ('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'id')), ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))] self.assertListEqual(hier_df.columns.to_list(), expected_columns) # make sure we have the right columns 
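# spot-check the ragged 'tags' values to confirm the indexed data is resolved row by row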
self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'tags'))].to_list(), [['tag1'], ['tag2'], ['tag2', 'tag1']]) def test_to_hierarchical_dataframe_empty_tables(self): # Setup empty tables with the following hierarchy # super_parent_table ---> parent_table ---> child_table a1 = DynamicTable(name='level0_0', description="level0_0 DynamicTable", columns=[VectorData(name='l0', description='l0', data=[])]) p1 = DynamicTable(name='parent_table', description='parent_table', columns=[DynamicTableRegion(name='l1', description='l1', data=[], table=a1), VectorData(name='p1c', description='l0', data=[])]) dtr_sp = DynamicTableRegion(name='sl1', description='sl1', data=np.arange(4), table=p1) vi_dtr_sp = VectorIndex(name='sl1_index', data=[], target=dtr_sp) spt = DynamicTable(name='super_parent_table', description='super_parent_table', columns=[dtr_sp, vi_dtr_sp, VectorData(name='sptc', description='l0', data=[])]) # Convert to hierarchical dataframe and make sure we get the right columns hier_df = to_hierarchical_dataframe(spt).reset_index() expected_columns = [('super_parent_table', 'id'), ('super_parent_table', 'sptc'), ('parent_table', 'id'), ('parent_table', 'p1c'), ('level0_0', 'id'), ('level0_0', 'l0')] self.assertListEqual(hier_df.columns.to_list(), expected_columns) def test_to_hierarchical_dataframe_multilevel(self): hier_df = to_hierarchical_dataframe(self.super_parent_table).reset_index() expected_cols = [('super_parent_table', 'id'), ('super_parent_table', 'sp1'), ('parent_table', 'id'), ('parent_table', 'p1'), ('aligned_table', 'id'), ('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'id')), ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))] # Check that we have all the columns self.assertListEqual(hier_df.columns.to_list(), expected_cols) # Spot-check the data in two columns self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'tags'))].to_list(), [['tag1'], ['tag2'], ['tag2', 'tag1']]) self.assertListEqual(hier_df[('aligned_table', ('aligned_table', 'a1'))].to_list(), list(range(3))) def test_to_hierarchical_dataframe(self): hier_df = to_hierarchical_dataframe(self.parent_table) self.assertEqual(len(hier_df), 4) self.assertEqual(len(hier_df.columns), 5) self.assertEqual(len(hier_df.index.names), 2) columns = [('aligned_table', 'id'), ('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'id')), ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))] for i, c in enumerate(hier_df.columns): self.assertTupleEqual(c, columns[i]) index_names = [('parent_table', 'id'), ('parent_table', 'p1')] self.assertListEqual(hier_df.index.names, index_names) self.assertListEqual(hier_df.index.to_list(), [(i, i) for i in range(4)]) self.assertListEqual(hier_df[('aligned_table', ('aligned_table', 'a1'))].to_list(), list(range(4))) self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'id'))].to_list(), list(range(10, 14))) self.assertListEqual(hier_df[('aligned_table', ('level0_0', 'myid'))].to_list(), list(range(4))) tags = [['tag1'], ['tag2'], ['tag2', 'tag1'], ['tag3', 'tag4', 'tag5']] for i, v in enumerate(hier_df[('aligned_table', ('level0_0', 'tags'))].to_list()): self.assertListEqual(v, tags[i]) def test_to_hierarchical_dataframe_flat_table(self): hier_df = to_hierarchical_dataframe(self.category0) assert_frame_equal(hier_df, self.category0.to_dataframe()) hier_df = to_hierarchical_dataframe(self.aligned_table) assert_frame_equal(hier_df, self.aligned_table.to_dataframe()) def 
test_drop_id_columns(self): hier_df = to_hierarchical_dataframe(self.parent_table) cols = hier_df.columns.to_list() mod_df = drop_id_columns(hier_df, inplace=False) expected_cols = [('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))] self.assertListEqual(hier_df.columns.to_list(), cols) # Test that no columns are dropped with inplace=False self.assertListEqual(mod_df.columns.to_list(), expected_cols) # Assert that we got back a modified dataframe drop_id_columns(hier_df, inplace=True) self.assertListEqual(hier_df.columns.to_list(), expected_cols) flat_df = to_hierarchical_dataframe(self.parent_table).reset_index(inplace=False) drop_id_columns(flat_df, inplace=True) self.assertListEqual(flat_df.columns.to_list(), [('parent_table', 'p1'), ('aligned_table', ('aligned_table', 'a1')), ('aligned_table', ('level0_0', 'tags')), ('aligned_table', ('level0_0', 'myid'))]) def test_flatten_column_index(self): hier_df = to_hierarchical_dataframe(self.parent_table).reset_index() cols = hier_df.columns.to_list() expexted_cols = [('parent_table', 'id'), ('parent_table', 'p1'), ('aligned_table', 'id'), ('aligned_table', 'aligned_table', 'a1'), ('aligned_table', 'level0_0', 'id'), ('aligned_table', 'level0_0', 'tags'), ('aligned_table', 'level0_0', 'myid')] df = flatten_column_index(hier_df, inplace=False) # Test that our columns have not changed with inplace=False self.assertListEqual(hier_df.columns.to_list(), cols) self.assertListEqual(df.columns.to_list(), expexted_cols) # make sure we got back a modified dataframe flatten_column_index(hier_df, inplace=True) # make sure we can also directly flatten inplace self.assertListEqual(hier_df.columns.to_list(), expexted_cols) # Test that we can apply flatten_column_index again on our already modified dataframe to reduce the levels flatten_column_index(hier_df, inplace=True, max_levels=2) expexted_cols = [('parent_table', 'id'), ('parent_table', 'p1'), ('aligned_table', 'id'), ('aligned_table', 'a1'), ('level0_0', 'id'), ('level0_0', 'tags'), ('level0_0', 'myid')] self.assertListEqual(hier_df.columns.to_list(), expexted_cols) # Test that we can directly reduce the max_levels to just 1 hier_df = to_hierarchical_dataframe(self.parent_table).reset_index() flatten_column_index(hier_df, inplace=True, max_levels=1) expexted_cols = ['id', 'p1', 'id', 'a1', 'id', 'tags', 'myid'] self.assertListEqual(hier_df.columns.to_list(), expexted_cols) def test_flatten_column_index_already_flat_index(self): hier_df = to_hierarchical_dataframe(self.parent_table).reset_index() flatten_column_index(hier_df, inplace=True, max_levels=1) expexted_cols = ['id', 'p1', 'id', 'a1', 'id', 'tags', 'myid'] self.assertListEqual(hier_df.columns.to_list(), expexted_cols) # Now try to flatten the already flat columns again to make sure nothing changes flatten_column_index(hier_df, inplace=True, max_levels=1) self.assertListEqual(hier_df.columns.to_list(), expexted_cols) def test_flatten_column_index_bad_maxlevels(self): hier_df = to_hierarchical_dataframe(self.parent_table) with self.assertRaisesWith(ValueError, 'max_levels must be greater than 0'): flatten_column_index(dataframe=hier_df, inplace=True, max_levels=-1) with self.assertRaisesWith(ValueError, 'max_levels must be greater than 0'): flatten_column_index(dataframe=hier_df, inplace=True, max_levels=0) class TestLinkedDynamicTables(TestCase): """ Test functionality specific to DynamicTables containing DynamicTableRegion columns. 
Since these functions only implement front-end convenience functions for DynamicTable, we do not need to worry about I/O here (that is tested elsewhere), but it is sufficient if we test with the container class. The only time I/O becomes relevant is on read in case that, e.g., a h5py.Dataset may behave differently than a numpy array. """ def setUp(self): """ Create basic set of linked tables consisting of table_level2 ---> table_level1 ----> table_level_0_0 \ ------> table_level_0_1 """ self.table_level0_0 = DynamicTable(name='level0_0', description="level0_0 DynamicTable") self.table_level0_1 = DynamicTable(name='level0_1', description="level0_1 DynamicTable") self.table_level1 = DynamicTableMultiDTR(name='level1', child_table1=self.table_level0_0, child_table2=self.table_level0_1) self.table_level2 = DynamicTableSingleDTR(name='level2', child_table1=self.table_level1) def tearDown(self): del self.table_level0_0 del self.table_level0_1 del self.table_level1 del self.table_level2 def popolate_tables(self): """Helper function to populate the tables generated in setUp with some simple data""" # Level 0 0 table. I.e., first table on level 0 self.table_level0_0.add_row(id=10) self.table_level0_0.add_row(id=11) self.table_level0_0.add_row(id=12) self.table_level0_0.add_row(id=13) self.table_level0_0.add_column(data=['tag1', 'tag2', 'tag2', 'tag1', 'tag3', 'tag4', 'tag5'], name='tags', description='custom tags', index=[1, 2, 4, 7]) self.table_level0_0.add_column(data=np.arange(4), name='myid', description='custom ids', index=False) # Level 0 1 table. I.e., second table on level 0 self.table_level0_1.add_row(id=14) self.table_level0_1.add_row(id=15) self.table_level0_1.add_row(id=16) self.table_level0_1.add_row(id=17) self.table_level0_1.add_column(data=['tag1', 'tag1', 'tag2', 'tag2', 'tag3', 'tag3', 'tag4'], name='tags', description='custom tags', index=[2, 4, 6, 7]) self.table_level0_1.add_column(data=np.arange(4), name='myid', description='custom ids', index=False) # Level 1 table self.table_level1.add_row(id=0, child_table_ref1=[0, 1], child_table_ref2=[0]) self.table_level1.add_row(id=1, child_table_ref1=[2], child_table_ref2=[1, 2]) self.table_level1.add_row(id=2, child_table_ref1=[3], child_table_ref2=[3]) self.table_level1.add_column(data=['tag1', 'tag2', 'tag2'], name='tag', description='custom tag', index=False) self.table_level1.add_column(data=['tag1', 'tag2', 'tag2', 'tag3', 'tag3', 'tag4', 'tag5'], name='tags', description='custom tags', index=[2, 4, 7]) # Level 2 data self.table_level2.add_row(id=0, child_table_ref1=[0, ]) self.table_level2.add_row(id=1, child_table_ref1=[1, 2]) self.table_level2.add_column(data=[10, 12], name='filter', description='filter value', index=False) def test_populate_table_hierarchy(self): """Test that just checks that populating the tables with data works correctly""" self.popolate_tables() # Check level0 0 data self.assertListEqual(self.table_level0_0.id[:], np.arange(10, 14, 1).tolist()) self.assertListEqual(self.table_level0_0['tags'][:], [['tag1'], ['tag2'], ['tag2', 'tag1'], ['tag3', 'tag4', 'tag5']]) self.assertListEqual(self.table_level0_0['myid'][:].tolist(), np.arange(0, 4, 1).tolist()) # Check level0 1 data self.assertListEqual(self.table_level0_1.id[:], np.arange(14, 18, 1).tolist()) self.assertListEqual(self.table_level0_1['tags'][:], [['tag1', 'tag1'], ['tag2', 'tag2'], ['tag3', 'tag3'], ['tag4']]) self.assertListEqual(self.table_level0_1['myid'][:].tolist(), np.arange(0, 4, 1).tolist()) # Check level1 data
self.assertListEqual(self.table_level1.id[:], np.arange(0, 3, 1).tolist()) self.assertListEqual(self.table_level1['tag'][:], ['tag1', 'tag2', 'tag2']) self.assertTrue(self.table_level1['child_table_ref1'].target.table is self.table_level0_0) self.assertTrue(self.table_level1['child_table_ref2'].target.table is self.table_level0_1) self.assertEqual(len(self.table_level1['child_table_ref1'].target.table), 4) self.assertEqual(len(self.table_level1['child_table_ref2'].target.table), 4) # Check level2 data self.assertListEqual(self.table_level2.id[:], np.arange(0, 2, 1).tolist()) self.assertListEqual(self.table_level2['filter'][:], [10, 12]) self.assertTrue(self.table_level2['child_table_ref1'].target.table is self.table_level1) self.assertEqual(len(self.table_level2['child_table_ref1'].target.table), 3) def test_get_foreign_columns(self): """Test DynamicTable.get_foreign_columns""" self.popolate_tables() self.assertListEqual(self.table_level0_0.get_foreign_columns(), []) self.assertListEqual(self.table_level0_1.get_foreign_columns(), []) self.assertListEqual(self.table_level1.get_foreign_columns(), ['child_table_ref1', 'child_table_ref2']) self.assertListEqual(self.table_level2.get_foreign_columns(), ['child_table_ref1']) def test_has_foreign_columns(self): """Test DynamicTable.get_foreign_columns""" self.popolate_tables() self.assertFalse(self.table_level0_0.has_foreign_columns()) self.assertFalse(self.table_level0_1.has_foreign_columns()) self.assertTrue(self.table_level1.has_foreign_columns()) self.assertTrue(self.table_level2.has_foreign_columns()) def test_get_linked_tables(self): """Test DynamicTable.get_linked_tables""" self.popolate_tables() # check level0_0 self.assertListEqual(self.table_level0_0.get_linked_tables(), []) # check level0_0 self.assertListEqual(self.table_level0_1.get_linked_tables(), []) # check level1 temp = self.table_level1.get_linked_tables() self.assertEqual(len(temp), 2) self.assertEqual(temp[0].source_table.name, self.table_level1.name) self.assertEqual(temp[0].source_column.name, 'child_table_ref1') self.assertEqual(temp[0].target_table.name, self.table_level0_0.name) self.assertEqual(temp[1].source_table.name, self.table_level1.name) self.assertEqual(temp[1].source_column.name, 'child_table_ref2') self.assertEqual(temp[1].target_table.name, self.table_level0_1.name) # check level2 temp = self.table_level2.get_linked_tables() self.assertEqual(len(temp), 3) self.assertEqual(temp[0].source_table.name, self.table_level2.name) self.assertEqual(temp[0].source_column.name, 'child_table_ref1') self.assertEqual(temp[0].target_table.name, self.table_level1.name) self.assertEqual(temp[1].source_table.name, self.table_level1.name) self.assertEqual(temp[1].source_column.name, 'child_table_ref1') self.assertEqual(temp[1].target_table.name, self.table_level0_0.name) self.assertEqual(temp[2].source_table.name, self.table_level1.name) self.assertEqual(temp[2].source_column.name, 'child_table_ref2') self.assertEqual(temp[2].target_table.name, self.table_level0_1.name) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/test_multi.py0000644000655200065520000000103600000000000020735 0ustar00circlecicirclecifrom hdmf.common import SimpleMultiContainer from hdmf.container import Container, Data from hdmf.testing import TestCase, H5RoundTripMixin class SimpleMultiContainerRoundTrip(H5RoundTripMixin, TestCase): def setUpContainer(self): containers = [ Container('container1'), Container('container2'), Data('data1', 
[0, 1, 2, 3, 4]), Data('data2', [0.0, 1.0, 2.0, 3.0, 4.0]), ] multi_container = SimpleMultiContainer('multi', containers) return multi_container ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/test_resources.py0000644000655200065520000004125300000000000021622 0ustar00circlecicircleciimport pandas as pd from hdmf.common.resources import ExternalResources, Key, Resource from hdmf import Data from hdmf.testing import TestCase, H5RoundTripMixin import numpy as np import unittest class TestExternalResources(H5RoundTripMixin, TestCase): def setUpContainer(self): er = ExternalResources('terms') er.add_ref( container='uuid1', field='field1', key='key1', resource_name='resource11', resource_uri='resource_uri11', entity_id="id11", entity_uri='url11') er.add_ref( container='uuid2', field='field2', key='key2', resource_name='resource21', resource_uri='resource_uri21', entity_id="id12", entity_uri='url21') return er @unittest.skip('Outdated do to privatization') def test_piecewise_add(self): er = ExternalResources('terms') # this is the term the user wants to use. They will need to specify this key = er._add_key('mouse') resource1 = er._add_resource(resource='resource0', uri='resource_uri0') # the user will have to supply this info as well. This is the information # needed to retrieve info about the controled term er._add_entity(key, resource1, '10090', 'uri') # The user can also pass in the container or it can be wrapped up under NWBFILE obj = er._add_object('object', 'species') # This could also be wrapped up under NWBFile er._add_object_key(obj, key) self.assertEqual(er.keys.data, [('mouse',)]) self.assertEqual(er.entities.data, [(0, 0, '10090', 'uri')]) self.assertEqual(er.objects.data, [('object', 'species')]) def test_add_ref(self): er = ExternalResources('terms') data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) er.add_ref( container=data, field='', key='key1', resource_name='resource1', resource_uri='uri1', entity_id='entity_id1', entity_uri='entity1') self.assertEqual(er.keys.data, [('key1',)]) self.assertEqual(er.resources.data, [('resource1', 'uri1')]) self.assertEqual(er.entities.data, [(0, 0, 'entity_id1', 'entity1')]) self.assertEqual(er.objects.data, [(data.object_id, '')]) def test_add_ref_duplicate_resource(self): er = ExternalResources('terms') er.add_ref( container='uuid1', field='field1', key='key1', resource_name='resource0', resource_uri='uri0', entity_id='entity_id1', entity_uri='entity1') resource_list = er.resources.which(resource='resource0') self.assertEqual(len(resource_list), 1) def test_add_ref_bad_arg(self): er = ExternalResources('terms') resource1 = er._add_resource(resource='resource0', uri='resource_uri0') # The contents of the message are not important. 
Just make sure an error is raised with self.assertRaises(ValueError): er.add_ref( 'uuid1', 'field1', 'key1', resource_name='resource1', resource_uri='uri1', entity_id='resource_id1') with self.assertRaises(ValueError): er.add_ref('uuid1', 'field1', 'key1', resource_name='resource1', resource_uri='uri1', entity_uri='uri1') with self.assertRaises(ValueError): er.add_ref('uuid1', 'field1', 'key1', resource_name='resource1', resource_uri='uri1') with self.assertRaises(TypeError): er.add_ref('uuid1', 'field1') with self.assertRaises(ValueError): er.add_ref('uuid1', 'field1', 'key1', resource_name='resource1') with self.assertRaises(ValueError): er.add_ref( 'uuid1', 'field1', 'key1', resources_idx=resource1, resource_name='resource1', resource_uri='uri1') def test_add_ref_two_resources(self): er = ExternalResources('terms') er.add_ref( container='uuid1', field='field1', key='key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') er.add_ref( container='uuid1', field='field1', key=er.get_key(key_name='key1'), resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') self.assertEqual(er.keys.data, [('key1',)]) self.assertEqual(er.resources.data, [('resource1', 'resource_uri1'), ('resource2', 'resource_uri2')]) self.assertEqual(er.objects.data, [('uuid1', 'field1')]) self.assertEqual(er.entities.data, [(0, 0, 'id11', 'url11'), (0, 1, 'id12', 'url21')]) def test_get_resources(self): er = ExternalResources('terms') er.add_ref( container='uuid1', field='field1', key='key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') resource = er.get_resource('resource1') self.assertIsInstance(resource, Resource) with self.assertRaises(ValueError): er.get_resource('unknown_resource') def test_add_ref_two_keys(self): er = ExternalResources('terms') er.add_ref( container='uuid1', field='field1', key='key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') er.add_ref( container='uuid2', field='field2', key='key2', resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') self.assertEqual(er.keys.data, [('key1',), ('key2',)]) self.assertEqual(er.resources.data, [('resource1', 'resource_uri1'), ('resource2', 'resource_uri2')]) self.assertEqual(er.entities.data, [(0, 0, 'id11', 'url11'), (1, 1, 'id12', 'url21')]) self.assertEqual(er.objects.data, [('uuid1', 'field1'), ('uuid2', 'field2')]) def test_add_ref_same_key_diff_objfield(self): er = ExternalResources('terms') er.add_ref( container='uuid1', field='field1', key='key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') er.add_ref( container='uuid2', field='field2', key='key1', resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') self.assertEqual(er.keys.data, [('key1',), ('key1',)]) self.assertEqual(er.entities.data, [(0, 0, 'id11', 'url11'), (1, 1, 'id12', 'url21')]) self.assertEqual(er.resources.data, [('resource1', 'resource_uri1'), ('resource2', 'resource_uri2')]) self.assertEqual(er.objects.data, [('uuid1', 'field1'), ('uuid2', 'field2')]) def test_add_ref_same_keyname(self): er = ExternalResources('terms') er.add_ref( container='uuid1', field='field1', key='key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') er.add_ref( container='uuid2', field='field2', key='key1', resource_name='resource2', 
resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') er.add_ref( container='uuid3', field='field3', key='key1', resource_name='resource3', resource_uri='resource_uri3', entity_id="id13", entity_uri='url31') self.assertEqual(er.keys.data, [('key1',), ('key1',), ('key1',)]) self.assertEqual(er.resources.data, [('resource1', 'resource_uri1'), ('resource2', 'resource_uri2'), ('resource3', 'resource_uri3')]) self.assertEqual( er.entities.data, [(0, 0, 'id11', 'url11'), (1, 1, 'id12', 'url21'), (2, 2, 'id13', 'url31')]) self.assertEqual(er.objects.data, [('uuid1', 'field1'), ('uuid2', 'field2'), ('uuid3', 'field3')]) def test_get_keys(self): er = ExternalResources('terms') er.add_ref( container='uuid1', field='field1', key='key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') er.add_ref( container='uuid2', field='field2', key='key2', resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') er.add_ref( container='uuid1', field='field1', key=er.get_key(key_name='key1'), resource_name='resource3', resource_uri='resource_uri3', entity_id="id13", entity_uri='url31') received = er.get_keys() expected = pd.DataFrame( data=[['key1', 0, 'id11', 'url11'], ['key1', 2, 'id13', 'url31'], ['key2', 1, 'id12', 'url21']], columns=['key_name', 'resources_idx', 'entity_id', 'entity_uri']) pd.testing.assert_frame_equal(received, expected) def test_get_keys_subset(self): er = ExternalResources('terms') er.add_ref( container='uuid1', field='field1', key='key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') er.add_ref( container='uuid2', field='field2', key='key2', resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') er.add_ref( container='uuid1', field='field1', key=er.get_key(key_name='key1'), resource_name='resource3', resource_uri='resource_uri3', entity_id="id13", entity_uri='url31') key = er.keys.row[0] received = er.get_keys(keys=key) expected = pd.DataFrame( data=[['key1', 0, 'id11', 'url11'], ['key1', 2, 'id13', 'url31']], columns=['key_name', 'resources_idx', 'entity_id', 'entity_uri']) pd.testing.assert_frame_equal(received, expected) def test_get_object_resources(self): er = ExternalResources('terms') data = Data(name='data_name', data=np.array([('Mus musculus', 9, 81.0), ('Homo sapien', 3, 27.0)], dtype=[('species', 'U14'), ('age', 'i4'), ('weight', 'f4')])) er.add_ref(container=data, field='data/species', key='Mus musculus', resource_name='NCBI_Taxonomy', resource_uri='https://www.ncbi.nlm.nih.gov/taxonomy', entity_id='NCBI:txid10090', entity_uri='https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=10090') received = er.get_object_resources(data, 'data/species') expected = pd.DataFrame( data=[[0, 0, 'NCBI:txid10090', 'https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=10090']], columns=['keys_idx', 'resource_idx', 'entity_id', 'entity_uri']) pd.testing.assert_frame_equal(received, expected) def test_object_key_unqiueness(self): er = ExternalResources('terms') data = Data(name='data_name', data=np.array([('Mus musculus', 9, 81.0), ('Homo sapien', 3, 27.0)], dtype=[('species', 'U14'), ('age', 'i4'), ('weight', 'f4')])) er.add_ref(container=data, field='data/species', key='Mus musculus', resource_name='NCBI_Taxonomy', resource_uri='https://www.ncbi.nlm.nih.gov/taxonomy', entity_id='NCBI:txid10090', entity_uri='https://www.ncbi.nlm.nih.gov/Taxonomy/Browser/wwwtax.cgi?id=10090') existing_key = 
er.get_key('Mus musculus') er.add_ref(container=data, field='data/species', key=existing_key, resource_name='resource2', resource_uri='resource_uri2', entity_id='entity2', entity_uri='entity_uri2') self.assertEqual(er.object_keys.data, [(0, 0)]) def test_check_object_field_add(self): er = ExternalResources('terms') data = Data(name="species", data=['Homo sapiens', 'Mus musculus']) er._check_object_field('uuid1', 'field1') er._check_object_field(data, 'field2') self.assertEqual(er.objects.data, [('uuid1', 'field1'), (data.object_id, 'field2')]) class TestExternalResourcesGetKey(TestCase): def setUp(self): self.er = ExternalResources('terms') def test_get_key(self): self.er.add_ref( 'uuid1', 'field1', 'key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') self.er.add_ref( 'uuid2', 'field2', 'key1', resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') keys = self.er.get_key('key1', 'uuid2', 'field2') self.assertIsInstance(keys, Key) self.assertEqual(keys.idx, 1) def test_get_key_bad_arg(self): self.er._add_key('key2') self.er.add_ref( 'uuid1', 'field1', 'key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') with self.assertRaises(ValueError): self.er.get_key('key2', 'uuid1', 'field1') @unittest.skip('Outdated do to privatization') def test_get_key_without_container(self): self.er = ExternalResources('terms') self.er._add_key('key1') keys = self.er.get_key('key1') self.assertIsInstance(keys, Key) def test_get_key_w_object_info(self): self.er.add_ref( 'uuid1', 'field1', 'key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') self.er.add_ref( 'uuid2', 'field2', 'key1', resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') keys = self.er.get_key('key1', 'uuid1', 'field1') self.assertIsInstance(keys, Key) self.assertEqual(keys.key, 'key1') def test_get_key_w_bad_object_info(self): self.er.add_ref( 'uuid1', 'field1', 'key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') self.er.add_ref( 'uuid2', 'field2', 'key1', resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') with self.assertRaisesRegex(ValueError, "No key with name 'key2'"): self.er.get_key('key2', 'uuid1', 'field1') def test_get_key_doesnt_exist(self): self.er.add_ref( 'uuid1', 'field1', 'key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') self.er.add_ref( 'uuid2', 'field2', 'key1', resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url21') with self.assertRaisesRegex(ValueError, "key 'bad_key' does not exist"): self.er.get_key('bad_key') @unittest.skip('Outdated do to privatization') def test_get_key_same_keyname_all(self): self.er = ExternalResources('terms') key1 = self.er._add_key('key1') key2 = self.er._add_key('key1') self.er.add_ref( 'uuid1', 'field1', key1, resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') self.er.add_ref( 'uuid2', 'field2', key2, resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url12') self.er.add_ref( 'uuid1', 'field1', self.er.get_key('key1', 'uuid1', 'field1'), resource_name='resource3', resource_uri='resource_uri3', entity_id="id13", entity_uri='url13') keys = self.er.get_key('key1') self.assertIsInstance(keys, Key) 
self.assertEqual(keys[0].key, 'key1') self.assertEqual(keys[1].key, 'key1') def test_get_key_same_keyname_specific(self): self.er = ExternalResources('terms') self.er.add_ref( 'uuid1', 'field1', 'key1', resource_name='resource1', resource_uri='resource_uri1', entity_id="id11", entity_uri='url11') self.er.add_ref( 'uuid2', 'field2', 'key2', resource_name='resource2', resource_uri='resource_uri2', entity_id="id12", entity_uri='url12') self.er.add_ref( 'uuid1', 'field1', self.er.get_key('key1', 'uuid1', 'field1'), resource_name='resource3', resource_uri='resource_uri3', entity_id="id13", entity_uri='url13') keys = self.er.get_key('key1', 'uuid1', 'field1') self.assertIsInstance(keys, Key) self.assertEqual(keys.key, 'key1') self.assertEqual(self.er.keys.data, [('key1',), ('key2',)]) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/test_sparse.py0000644000655200065520000001012000000000000021072 0ustar00circlecicircleciimport numpy as np import scipy.sparse as sps from hdmf.common import CSRMatrix from hdmf.testing import TestCase, H5RoundTripMixin class TestCSRMatrix(TestCase): def test_from_sparse_matrix(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) shape = (3, 3) expected = CSRMatrix(data, indices, indptr, shape) sps_mat = sps.csr_matrix((data, indices, indptr), shape=shape) received = CSRMatrix(sps_mat) self.assertContainerEqual(received, expected, ignore_hdmf_attrs=True) def test_2d_data(self): data = np.array([[1, 0, 2], [0, 0, 3], [4, 5, 6]]) csr_mat = CSRMatrix(data) sps_mat = sps.csr_matrix(data) np.testing.assert_array_equal(csr_mat.data, sps_mat.data) def test_getattrs(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2], dtype=np.int32) indptr = np.array([0, 2, 3, 6], dtype=np.int32) shape = (3, 3) csr_mat = CSRMatrix(data, indices, indptr, shape) np.testing.assert_array_equal(data, csr_mat.data) np.testing.assert_array_equal(indices, csr_mat.indices) np.testing.assert_array_equal(indptr, csr_mat.indptr) np.testing.assert_array_equal(shape, csr_mat.shape) self.assertEqual(csr_mat.indices.dtype.type, np.uint32) self.assertEqual(csr_mat.indptr.dtype.type, np.uint32) # NOTE: shape is stored internally in scipy.sparse.spmat as a tuple of ints. this is then converted to ndarray # but precision differs by OS self.assertTrue(np.issubdtype(csr_mat.shape.dtype.type, np.unsignedinteger)) def test_to_spmat(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) shape = (3, 3) csr_mat = CSRMatrix(data, indices, indptr, shape) spmat_array = csr_mat.to_spmat().toarray() expected = np.asarray([[1, 0, 2], [0, 0, 3], [4, 5, 6]]) np.testing.assert_array_equal(spmat_array, expected) def test_constructor_indices_missing(self): data = np.array([1, 2, 3, 4, 5, 6]) msg = "Must specify 'indptr', 'indices', and 'shape' arguments when passing data array." with self.assertRaisesWith(ValueError, msg): CSRMatrix(data) def test_constructor_bad_indices(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, -2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) shape = (3, 3) msg = "Cannot convert 'indices' to an array of unsigned integers." 
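
# Sketch of the CSRMatrix usage covered by TestCSRMatrix: the container is
# built from CSR components (data, indices, indptr, shape) and can be turned
# back into a scipy sparse matrix with to_spmat(). The values reuse the 3x3
# example from these tests; keyword argument names follow the error messages
# asserted here.
import numpy as np
from hdmf.common import CSRMatrix

csr_demo = CSRMatrix(data=np.array([1, 2, 3, 4, 5, 6]),
                     indices=np.array([0, 2, 2, 0, 1, 2]),
                     indptr=np.array([0, 2, 3, 6]),
                     shape=(3, 3))
dense = csr_demo.to_spmat().toarray()  # [[1, 0, 2], [0, 0, 3], [4, 5, 6]]
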
with self.assertRaisesWith(ValueError, msg): CSRMatrix(data, indices, indptr, shape) def test_constructor_bad_indices_dim(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([[0, 2, 2, 0, 1, 2]]) indptr = np.array([0, 2, 3, 6]) shape = (3, 3) msg = "'indices' must be a 1D array of unsigned integers." with self.assertRaisesWith(ValueError, msg): CSRMatrix(data, indices, indptr, shape) def test_constructor_bad_shape(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) shape = (3, ) msg = "'shape' argument must specify two and only two dimensions." with self.assertRaisesWith(ValueError, msg): CSRMatrix(data, indices, indptr, shape) def test_array_bad_dim(self): data = np.array([[[1, 2], [3, 4], [5, 6]]]) indices = np.array([0, 2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) msg = "'data' argument cannot be ndarray of dimensionality > 2." with self.assertRaisesWith(ValueError, msg): CSRMatrix(data, indices, indptr, (3, 3)) class TestCSRMatrixRoundTrip(H5RoundTripMixin, TestCase): def setUpContainer(self): data = np.array([1, 2, 3, 4, 5, 6]) indices = np.array([0, 2, 2, 0, 1, 2]) indptr = np.array([0, 2, 3, 6]) return CSRMatrix(data, indices, indptr, (3, 3)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/common/test_table.py0000644000655200065520000031067100000000000020702 0ustar00circlecicirclecifrom collections import OrderedDict import h5py import numpy as np import os import pandas as pd import unittest from hdmf import Container from hdmf.backends.hdf5 import H5DataIO, HDF5IO from hdmf.backends.hdf5.h5tools import H5_TEXT, H5PY_3 from hdmf.common import (DynamicTable, VectorData, VectorIndex, ElementIdentifiers, EnumData, DynamicTableRegion, get_manager, SimpleMultiContainer) from hdmf.testing import TestCase, H5RoundTripMixin, remove_test_file from hdmf.utils import StrDataset from tests.unit.utils import get_temp_filepath class TestDynamicTable(TestCase): def setUp(self): self.spec = [ {'name': 'foo', 'description': 'foo column'}, {'name': 'bar', 'description': 'bar column'}, {'name': 'baz', 'description': 'baz column'}, ] self.data = [ [1, 2, 3, 4, 5], [10.0, 20.0, 30.0, 40.0, 50.0], ['cat', 'dog', 'bird', 'fish', 'lizard'] ] def with_table_columns(self): cols = [VectorData(**d) for d in self.spec] table = DynamicTable("with_table_columns", 'a test table', columns=cols) return table def with_columns_and_data(self): columns = [ VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, self.data) ] return DynamicTable("with_columns_and_data", 'a test table', columns=columns) def with_spec(self): table = DynamicTable("with_spec", 'a test table', columns=self.spec) return table def check_empty_table(self, table): self.assertIsInstance(table.columns, tuple) self.assertIsInstance(table.columns[0], VectorData) self.assertEqual(len(table.columns), 3) self.assertTupleEqual(table.colnames, ('foo', 'bar', 'baz')) def test_constructor_table_columns(self): table = self.with_table_columns() self.assertEqual(table.name, 'with_table_columns') self.check_empty_table(table) def test_constructor_spec(self): table = self.with_spec() self.assertEqual(table.name, 'with_spec') self.check_empty_table(table) def check_table(self, table): self.assertEqual(len(table), 5) self.assertEqual(table.columns[0].data, [1, 2, 3, 4, 5]) self.assertEqual(table.columns[1].data, [10.0, 20.0, 30.0, 40.0, 50.0]) self.assertEqual(table.columns[2].data, 
['cat', 'dog', 'bird', 'fish', 'lizard']) self.assertEqual(table.id.data, [0, 1, 2, 3, 4]) self.assertTrue(hasattr(table, 'baz')) def test_constructor_ids_default(self): columns = [VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, self.data)] table = DynamicTable("with_spec", 'a test table', columns=columns) self.check_table(table) def test_constructor_ids(self): columns = [VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, self.data)] table = DynamicTable("with_columns", 'a test table', id=[0, 1, 2, 3, 4], columns=columns) self.check_table(table) def test_constructor_ElementIdentifier_ids(self): columns = [VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, self.data)] ids = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) table = DynamicTable("with_columns", 'a test table', id=ids, columns=columns) self.check_table(table) def test_constructor_ids_bad_ids(self): columns = [VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, self.data)] msg = "must provide same number of ids as length of columns" with self.assertRaisesWith(ValueError, msg): DynamicTable("with_columns", 'a test table', id=[0, 1], columns=columns) def test_constructor_bad_columns(self): columns = ['bad_column'] msg = "'columns' must be a list of dict, VectorData, DynamicTableRegion, or VectorIndex" with self.assertRaisesWith(ValueError, msg): DynamicTable("with_columns", 'a test table', columns=columns) def test_constructor_unequal_length_columns(self): columns = [VectorData(name='col1', description='desc', data=[1, 2, 3]), VectorData(name='col2', description='desc', data=[1, 2])] msg = "columns must be the same length" with self.assertRaisesWith(ValueError, msg): DynamicTable("with_columns", 'a test table', columns=columns) def test_constructor_colnames(self): """Test that passing colnames correctly sets the order of the columns.""" cols = [VectorData(**d) for d in self.spec] table = DynamicTable("with_columns", 'a test table', columns=cols, colnames=['baz', 'bar', 'foo']) self.assertTupleEqual(table.columns, tuple(cols[::-1])) def test_constructor_colnames_no_columns(self): """Test that passing colnames without columns raises an error.""" msg = "Must supply 'columns' if specifying 'colnames'" with self.assertRaisesWith(ValueError, msg): DynamicTable("with_columns", 'a test table', colnames=['baz', 'bar', 'foo']) def test_constructor_colnames_vectorindex(self): """Test that passing colnames with a VectorIndex column puts the index in the right location in columns.""" cols = [VectorData(**d) for d in self.spec] ind = VectorIndex(name='foo_index', data=list(), target=cols[0]) cols.append(ind) table = DynamicTable("with_columns", 'a test table', columns=cols, colnames=['baz', 'bar', 'foo']) self.assertTupleEqual(table.columns, (cols[2], cols[1], ind, cols[0])) def test_constructor_colnames_vectorindex_rev(self): """Test that passing colnames with a VectorIndex column puts the index in the right location in columns.""" cols = [VectorData(**d) for d in self.spec] ind = VectorIndex(name='foo_index', data=list(), target=cols[0]) cols.insert(0, ind) # put index before its target table = DynamicTable("with_columns", 'a test table', columns=cols, colnames=['baz', 'bar', 'foo']) self.assertTupleEqual(table.columns, (cols[3], cols[2], ind, cols[1])) def test_constructor_dup_index(self): """Test that passing two indices for the same column raises an error.""" cols = [VectorData(**d) 
for d in self.spec] cols.append(VectorIndex(name='foo_index', data=list(), target=cols[0])) cols.append(VectorIndex(name='foo_index2', data=list(), target=cols[0])) msg = "'columns' contains index columns with the same target: ['foo', 'foo']" with self.assertRaisesWith(ValueError, msg): DynamicTable("with_columns", 'a test table', columns=cols) def test_constructor_index_missing_target(self): """Test that passing an index without its target raises an error.""" cols = [VectorData(**d) for d in self.spec] missing_col = cols.pop(2) cols.append(VectorIndex(name='foo_index', data=list(), target=missing_col)) msg = "Found VectorIndex 'foo_index' but not its target 'baz'" with self.assertRaisesWith(ValueError, msg): DynamicTable("with_columns", 'a test table', columns=cols) def add_rows(self, table): table.add_row({'foo': 1, 'bar': 10.0, 'baz': 'cat'}) table.add_row({'foo': 2, 'bar': 20.0, 'baz': 'dog'}) table.add_row({'foo': 3, 'bar': 30.0, 'baz': 'bird'}) table.add_row({'foo': 4, 'bar': 40.0, 'baz': 'fish'}) table.add_row({'foo': 5, 'bar': 50.0, 'baz': 'lizard'}) def test_add_row(self): table = self.with_spec() self.add_rows(table) self.check_table(table) def test_get(self): table = self.with_spec() self.add_rows(table) self.assertIsInstance(table.get('foo'), VectorData) self.assertEqual(table.get('foo'), table['foo']) def test_get_not_found(self): table = self.with_spec() self.add_rows(table) self.assertIsNone(table.get('qux')) def test_get_not_found_default(self): table = self.with_spec() self.add_rows(table) self.assertEqual(table.get('qux', 1), 1) def test_get_item(self): table = self.with_spec() self.add_rows(table) self.check_table(table) def test_add_column(self): table = self.with_spec() table.add_column(name='qux', description='qux column') self.assertTupleEqual(table.colnames, ('foo', 'bar', 'baz', 'qux')) self.assertTrue(hasattr(table, 'qux')) def test_add_column_twice(self): table = self.with_spec() table.add_column(name='qux', description='qux column') msg = "column 'qux' already exists in DynamicTable 'with_spec'" with self.assertRaisesWith(ValueError, msg): table.add_column(name='qux', description='qux column') def test_add_column_vectorindex(self): table = self.with_spec() table.add_column(name='qux', description='qux column') ind = VectorIndex(name='bar', data=list(), target=table['qux']) msg = ("Passing a VectorIndex in for index may lead to unexpected behavior. 
This functionality will be " "deprecated in a future version of HDMF.") with self.assertWarnsWith(FutureWarning, msg): table.add_column(name='bad', description='bad column', index=ind) def test_add_column_multi_index(self): table = self.with_spec() table.add_column(name='qux', description='qux column', index=2) table.add_row(foo=5, bar=50.0, baz='lizard', qux=[ [1, 2, 3], [1, 2, 3, 4] ]) table.add_row(foo=5, bar=50.0, baz='lizard', qux=[ [1, 2] ] ) def test_auto_multi_index_required(self): class TestTable(DynamicTable): __columns__ = (dict(name='qux', description='qux column', index=3, required=True),) table = TestTable('table_name', 'table_description') self.assertIsInstance(table.qux, VectorData) # check that the attribute is set self.assertIsInstance(table.qux_index, VectorIndex) # check that the attribute is set self.assertIsInstance(table.qux_index_index, VectorIndex) # check that the attribute is set self.assertIsInstance(table.qux_index_index_index, VectorIndex) # check that the attribute is set table.add_row( qux=[ [ [1, 2, 3], [1, 2, 3, 4] ] ] ) table.add_row( qux=[ [ [1, 2] ] ] ) expected = [ [ [ [1, 2, 3], [1, 2, 3, 4] ] ], [ [ [1, 2] ] ] ] self.assertListEqual(table['qux'][:], expected) self.assertEqual(table.qux_index_index_index.data, [1, 2]) def test_auto_multi_index(self): class TestTable(DynamicTable): __columns__ = (dict(name='qux', description='qux column', index=3),) # this is optional table = TestTable('table_name', 'table_description') self.assertIsNone(table.qux) # these are reserved as attributes but not yet initialized self.assertIsNone(table.qux_index) self.assertIsNone(table.qux_index_index) self.assertIsNone(table.qux_index_index_index) table.add_row( qux=[ [ [1, 2, 3], [1, 2, 3, 4] ] ] ) table.add_row( qux=[ [ [1, 2] ] ] ) expected = [ [ [ [1, 2, 3], [1, 2, 3, 4] ] ], [ [ [1, 2] ] ] ] self.assertListEqual(table['qux'][:], expected) self.assertEqual(table.qux_index_index_index.data, [1, 2]) def test_getitem_row_num(self): table = self.with_spec() self.add_rows(table) row = table[2] self.assertTupleEqual(row.shape, (1, 3)) self.assertTupleEqual(tuple(row.iloc[0]), (3, 30.0, 'bird')) def test_getitem_row_slice(self): table = self.with_spec() self.add_rows(table) rows = table[1:3] self.assertIsInstance(rows, pd.DataFrame) self.assertTupleEqual(rows.shape, (2, 3)) self.assertTupleEqual(tuple(rows.iloc[1]), (3, 30.0, 'bird')) def test_getitem_row_slice_with_step(self): table = self.with_spec() self.add_rows(table) rows = table[0:5:2] self.assertIsInstance(rows, pd.DataFrame) self.assertTupleEqual(rows.shape, (3, 3)) self.assertEqual(rows.iloc[2][0], 5) self.assertEqual(rows.iloc[2][1], 50.0) self.assertEqual(rows.iloc[2][2], 'lizard') def test_getitem_invalid_keytype(self): table = self.with_spec() self.add_rows(table) with self.assertRaises(KeyError): _ = table[0.1] def test_getitem_col_select_and_row_slice(self): table = self.with_spec() self.add_rows(table) col = table[1:3, 'bar'] self.assertEqual(len(col), 2) self.assertEqual(col[0], 20.0) self.assertEqual(col[1], 30.0) def test_getitem_column(self): table = self.with_spec() self.add_rows(table) col = table['bar'] self.assertEqual(col[0], 10.0) self.assertEqual(col[1], 20.0) self.assertEqual(col[2], 30.0) self.assertEqual(col[3], 40.0) self.assertEqual(col[4], 50.0) def test_getitem_list_idx(self): table = self.with_spec() self.add_rows(table) row = table[[0, 2, 4]] self.assertEqual(len(row), 3) self.assertTupleEqual(tuple(row.iloc[0]), (1, 10.0, 'cat')) self.assertTupleEqual(tuple(row.iloc[1]), (3, 30.0, 
'bird')) self.assertTupleEqual(tuple(row.iloc[2]), (5, 50.0, 'lizard')) def test_getitem_point_idx_colname(self): table = self.with_spec() self.add_rows(table) val = table[2, 'bar'] self.assertEqual(val, 30.0) def test_getitem_point_idx(self): table = self.with_spec() self.add_rows(table) row = table[2] self.assertTupleEqual(tuple(row.iloc[0]), (3, 30.0, 'bird')) def test_getitem_point_idx_colidx(self): table = self.with_spec() self.add_rows(table) val = table[2, 2] self.assertEqual(val, 30.0) def test_pandas_roundtrip(self): df = pd.DataFrame({ 'a': [1, 2, 3, 4], 'b': ['a', 'b', 'c', '4'] }, index=pd.Index(name='an_index', data=[2, 4, 6, 8])) table = DynamicTable.from_dataframe(df, 'foo') obtained = table.to_dataframe() self.assertTrue(df.equals(obtained)) def test_to_dataframe(self): table = self.with_columns_and_data() data = OrderedDict() for name in table.colnames: if name == 'foo': data[name] = [1, 2, 3, 4, 5] elif name == 'bar': data[name] = [10.0, 20.0, 30.0, 40.0, 50.0] elif name == 'baz': data[name] = ['cat', 'dog', 'bird', 'fish', 'lizard'] expected_df = pd.DataFrame(data) obtained_df = table.to_dataframe() self.assertTrue(expected_df.equals(obtained_df)) def test_from_dataframe(self): df = pd.DataFrame({ 'foo': [1, 2, 3, 4, 5], 'bar': [10.0, 20.0, 30.0, 40.0, 50.0], 'baz': ['cat', 'dog', 'bird', 'fish', 'lizard'] }).loc[:, ('foo', 'bar', 'baz')] obtained_table = DynamicTable.from_dataframe(df, 'test') self.check_table(obtained_table) def test_from_dataframe_eq(self): expected = DynamicTable('test_table', 'the expected table') expected.add_column('a', '2d column') expected.add_column('b', '1d column') expected.add_row(a=[1, 2, 3], b='4') expected.add_row(a=[1, 2, 3], b='5') expected.add_row(a=[1, 2, 3], b='6') df = pd.DataFrame({ 'a': [[1, 2, 3], [1, 2, 3], [1, 2, 3]], 'b': ['4', '5', '6'] }) coldesc = {'a': '2d column', 'b': '1d column'} received = DynamicTable.from_dataframe(df, 'test_table', table_description='the expected table', column_descriptions=coldesc) self.assertContainerEqual(expected, received, ignore_hdmf_attrs=True) def test_from_dataframe_dup_attr(self): """ Test that when a DynamicTable is generated from a dataframe where one of the column names is an existing DynamicTable attribute (e.g., description), that the table can be created, the existing attribute is not altered, a warning is raised, and the column can still be accessed using the table[col_name] syntax. 
""" df = pd.DataFrame({ 'parent': [1, 2, 3, 4, 5], 'name': [10.0, 20.0, 30.0, 40.0, 50.0], 'description': ['cat', 'dog', 'bird', 'fish', 'lizard'] }) # technically there are three separate warnings but just catch one here msg1 = ("An attribute 'parent' already exists on DynamicTable 'test' so this column cannot be accessed " "as an attribute, e.g., table.parent; it can only be accessed using other methods, e.g., " "table['parent'].") with self.assertWarnsWith(UserWarning, msg1): table = DynamicTable.from_dataframe(df, 'test') self.assertEqual(table.name, 'test') self.assertEqual(table.description, '') self.assertIsNone(table.parent) self.assertEqual(table['name'].name, 'name') self.assertEqual(table['description'].name, 'description') self.assertEqual(table['parent'].name, 'parent') def test_missing_columns(self): table = self.with_spec() with self.assertRaises(ValueError): table.add_row({'bar': 60.0, 'foo': [6]}, None) def test_enforce_unique_id_error(self): table = self.with_spec() table.add_row(id=10, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) with self.assertRaises(ValueError): table.add_row(id=10, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) def test_not_enforce_unique_id_error(self): table = self.with_spec() table.add_row(id=10, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=False) try: table.add_row(id=10, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=False) except ValueError as e: self.fail("add row with non unique id raised error %s" % str(e)) def test_bad_id_type_error(self): table = self.with_spec() with self.assertRaises(TypeError): table.add_row(id=10.1, data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) with self.assertRaises(TypeError): table.add_row(id='str', data={'foo': 1, 'bar': 10.0, 'baz': 'cat'}, enforce_unique_id=True) def test_extra_columns(self): table = self.with_spec() with self.assertRaises(ValueError): table.add_row({'bar': 60.0, 'foo': 6, 'baz': 'oryx', 'qax': -1}, None) def test_nd_array_to_df(self): data = np.array([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) col = VectorData(name='data', description='desc', data=data) df = DynamicTable('test', 'desc', np.arange(3, dtype='int'), (col, )).to_dataframe() df2 = pd.DataFrame({'data': [x for x in data]}, index=pd.Index(name='id', data=[0, 1, 2])) pd.testing.assert_frame_equal(df, df2) def test_id_search(self): table = self.with_spec() data = [{'foo': 1, 'bar': 10.0, 'baz': 'cat'}, {'foo': 2, 'bar': 20.0, 'baz': 'dog'}, {'foo': 3, 'bar': 30.0, 'baz': 'bird'}, # id=2 {'foo': 4, 'bar': 40.0, 'baz': 'fish'}, {'foo': 5, 'bar': 50.0, 'baz': 'lizard'} # id=4 ] for i in data: table.add_row(i) res = table[table.id == [2, 4]] self.assertEqual(len(res), 2) self.assertTupleEqual(tuple(res.iloc[0]), (3, 30.0, 'bird')) self.assertTupleEqual(tuple(res.iloc[1]), (5, 50.0, 'lizard')) def test_repr(self): table = self.with_spec() expected = """with_spec hdmf.common.table.DynamicTable at 0x%d Fields: colnames: ['foo' 'bar' 'baz'] columns: ( foo , bar , baz ) description: a test table """ expected = expected % id(table) self.assertEqual(str(table), expected) def test_add_column_existing_attr(self): table = self.with_table_columns() attrs = ['name', 'description', 'parent', 'id', 'fields'] # just a few for attr in attrs: with self.subTest(attr=attr): msg = ("An attribute '%s' already exists on DynamicTable 'with_table_columns' so this column cannot be " "accessed as an attribute, e.g., table.%s; it can only be accessed using other methods, " "e.g., 
table['%s']." % (attr, attr, attr)) with self.assertWarnsWith(UserWarning, msg): table.add_column(name=attr, description='') def test_init_columns_existing_attr(self): attrs = ['name', 'description', 'parent', 'id', 'fields'] # just a few for attr in attrs: with self.subTest(attr=attr): cols = [VectorData(name=attr, description='')] msg = ("An attribute '%s' already exists on DynamicTable 'test_table' so this column cannot be " "accessed as an attribute, e.g., table.%s; it can only be accessed using other methods, " "e.g., table['%s']." % (attr, attr, attr)) with self.assertWarnsWith(UserWarning, msg): DynamicTable("test_table", 'a test table', columns=cols) def test_colnames_none(self): table = DynamicTable('table0', 'an example table') self.assertTupleEqual(table.colnames, tuple()) self.assertTupleEqual(table.columns, tuple()) def test_index_out_of_bounds(self): table = self.with_columns_and_data() msg = "Row index out of range for DynamicTable 'with_columns_and_data' (length 5)." with self.assertRaisesWith(IndexError, msg): table[5] def test_no_df_nested(self): table = self.with_columns_and_data() msg = 'DynamicTable.get() with df=False and index=False is not yet supported.' with self.assertRaisesWith(ValueError, msg): table.get(0, df=False, index=False) def test_multidim_col(self): multidim_data = [ [[1, 2], [3, 4], [5, 6]], ((1, 2), (3, 4), (5, 6)), [(1, 'a', True), (2, 'b', False), (3, 'c', True)], ] columns = [ VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, multidim_data) ] table = DynamicTable("with_columns_and_data", 'a test table', columns=columns) df = table.to_dataframe() df2 = pd.DataFrame({'foo': multidim_data[0], 'bar': multidim_data[1], 'baz': multidim_data[2]}, index=pd.Index(name='id', data=[0, 1, 2])) pd.testing.assert_frame_equal(df, df2) df3 = pd.DataFrame({'foo': [multidim_data[0][0]], 'bar': [multidim_data[1][0]], 'baz': [multidim_data[2][0]]}, index=pd.Index(name='id', data=[0])) pd.testing.assert_frame_equal(table.get(0), df3) def test_multidim_col_one_elt_list(self): data = [[1, 2]] col = VectorData(name='data', description='desc', data=data) table = DynamicTable('test', 'desc', columns=(col, )) df = table.to_dataframe() df2 = pd.DataFrame({'data': [x for x in data]}, index=pd.Index(name='id', data=[0])) pd.testing.assert_frame_equal(df, df2) pd.testing.assert_frame_equal(table.get(0), df2) def test_multidim_col_one_elt_tuple(self): data = [(1, 2)] col = VectorData(name='data', description='desc', data=data) table = DynamicTable('test', 'desc', columns=(col, )) df = table.to_dataframe() df2 = pd.DataFrame({'data': [x for x in data]}, index=pd.Index(name='id', data=[0])) pd.testing.assert_frame_equal(df, df2) pd.testing.assert_frame_equal(table.get(0), df2) def test_eq(self): columns = [ VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, self.data) ] test_table = DynamicTable("with_columns_and_data", 'a test table', columns=columns) table = self.with_columns_and_data() self.assertTrue(table == test_table) def test_eq_from_df(self): df = pd.DataFrame({ 'foo': [1, 2, 3, 4, 5], 'bar': [10.0, 20.0, 30.0, 40.0, 50.0], 'baz': ['cat', 'dog', 'bird', 'fish', 'lizard'] }).loc[:, ('foo', 'bar', 'baz')] test_table = DynamicTable.from_dataframe(df, 'with_columns_and_data', table_description='a test table') table = self.with_columns_and_data() self.assertTrue(table == test_table) def test_eq_diff_missing_col(self): columns = [ VectorData(name=s['name'], description=s['description'], data=d) 
for s, d in zip(self.spec, self.data) ] del columns[-1] test_table = DynamicTable("with_columns_and_data", 'a test table', columns=columns) table = self.with_columns_and_data() self.assertFalse(table == test_table) def test_eq_diff_name(self): columns = [ VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, self.data) ] test_table = DynamicTable("wrong name", 'a test table', columns=columns) table = self.with_columns_and_data() self.assertFalse(table == test_table) def test_eq_diff_desc(self): columns = [ VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, self.data) ] test_table = DynamicTable("with_columns_and_data", 'wrong description', columns=columns) table = self.with_columns_and_data() self.assertFalse(table == test_table) def test_eq_bad_type(self): container = Container('test_container') table = self.with_columns_and_data() self.assertFalse(table == container) class TestDynamicTableRoundTrip(H5RoundTripMixin, TestCase): def setUpContainer(self): table = DynamicTable('table0', 'an example table') table.add_column('foo', 'an int column') table.add_column('bar', 'a float column') table.add_column('baz', 'a string column') table.add_column('qux', 'a boolean column') table.add_column('corge', 'a doubly indexed int column', index=2) table.add_column('quux', 'an enum column', enum=True) table.add_row(foo=27, bar=28.0, baz="cat", corge=[[1, 2, 3], [4, 5, 6]], qux=True, quux='a') table.add_row(foo=37, bar=38.0, baz="dog", corge=[[11, 12, 13], [14, 15, 16]], qux=False, quux='b') return table def test_index_out_of_bounds(self): table = self.roundtripContainer() msg = "Row index 5 out of range for DynamicTable 'root' (length 2)." with self.assertRaisesWith(IndexError, msg): table[5] class TestEmptyDynamicTableRoundTrip(H5RoundTripMixin, TestCase): """Test roundtripping a DynamicTable with no rows and no columns.""" def setUpContainer(self): table = DynamicTable('table0', 'an example table') return table class TestDynamicTableRegion(TestCase): def setUp(self): self.spec = [ {'name': 'foo', 'description': 'foo column'}, {'name': 'bar', 'description': 'bar column'}, {'name': 'baz', 'description': 'baz column'}, ] self.data = [ [1, 2, 3, 4, 5], [10.0, 20.0, 30.0, 40.0, 50.0], ['cat', 'dog', 'bird', 'fish', 'lizard'] ] def with_columns_and_data(self): columns = [ VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec, self.data) ] return DynamicTable("with_columns_and_data", 'a test table', columns=columns) def test_indexed_dynamic_table_region(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [1, 2, 2], 'desc', table=table) fetch_ids = dynamic_table_region[:3].index.values self.assertListEqual(fetch_ids.tolist(), [1, 2, 2]) def test_dynamic_table_region_iteration(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 3, 4], 'desc', table=table) for ii, item in enumerate(dynamic_table_region): self.assertTrue(table[ii].equals(item)) def test_dynamic_table_region_shape(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 3, 4], 'desc', table=table) self.assertTupleEqual(dynamic_table_region.shape, (5, 3)) def test_dynamic_table_region_to_dataframe(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) res = dynamic_table_region.to_dataframe() 
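
# Sketch of the DynamicTableRegion behavior exercised in this class: a region
# is a list of row indices into another DynamicTable, and slicing or
# to_dataframe() resolves those rows. Table and column names here are
# illustrative, not taken from a real schema.
from hdmf.common import DynamicTable, VectorData, DynamicTableRegion

species_table = DynamicTable('species', 'demo table', columns=[
    VectorData(name='name', description='species name', data=['cat', 'dog', 'bird'])])
region_demo = DynamicTableRegion('dtr', [0, 2, 2], 'rows of interest', table=species_table)
print(region_demo.to_dataframe())   # rows 0, 2, 2 of the species table
print(region_demo[0:2, 'name'])     # ['cat', 'bird']
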
self.assertListEqual(res.index.tolist(), [0, 1, 2, 2]) self.assertListEqual(res['foo'].tolist(), [1, 2, 3, 3]) self.assertListEqual(res['bar'].tolist(), [10.0, 20.0, 30.0, 30.0]) self.assertListEqual(res['baz'].tolist(), ['cat', 'dog', 'bird', 'bird']) def test_dynamic_table_region_to_dataframe_exclude_cols(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) res = dynamic_table_region.to_dataframe(exclude={'baz', 'foo'}) self.assertListEqual(res.index.tolist(), [0, 1, 2, 2]) self.assertEqual(len(res.columns), 1) self.assertListEqual(res['bar'].tolist(), [10.0, 20.0, 30.0, 30.0]) def test_dynamic_table_region_getitem_slice(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) res = dynamic_table_region[1:3] self.assertListEqual(res.index.tolist(), [1, 2]) self.assertListEqual(res['foo'].tolist(), [2, 3]) self.assertListEqual(res['bar'].tolist(), [20.0, 30.0]) self.assertListEqual(res['baz'].tolist(), ['dog', 'bird']) def test_dynamic_table_region_getitem_single_row_by_index(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) res = dynamic_table_region[2] self.assertListEqual(res.index.tolist(), [2, ]) self.assertListEqual(res['foo'].tolist(), [3, ]) self.assertListEqual(res['bar'].tolist(), [30.0, ]) self.assertListEqual(res['baz'].tolist(), ['bird', ]) def test_dynamic_table_region_getitem_single_cell(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) res = dynamic_table_region[2, 'foo'] self.assertEqual(res, 3) res = dynamic_table_region[1, 'baz'] self.assertEqual(res, 'dog') def test_dynamic_table_region_getitem_slice_of_column(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) res = dynamic_table_region[0:3, 'foo'] self.assertListEqual(res, [1, 2, 3]) res = dynamic_table_region[1:3, 'baz'] self.assertListEqual(res, ['dog', 'bird']) def test_dynamic_table_region_getitem_bad_index(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) with self.assertRaises(ValueError): _ = dynamic_table_region[True] def test_dynamic_table_region_table_prop(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) self.assertEqual(table, dynamic_table_region.table) def test_dynamic_table_region_set_table_prop(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc') dynamic_table_region.table = table self.assertEqual(table, dynamic_table_region.table) def test_dynamic_table_region_set_table_prop_to_none(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) try: dynamic_table_region.table = None except AttributeError: self.fail("DynamicTableRegion table setter raised AttributeError unexpectedly!") @unittest.skip('we no longer check data contents for performance reasons') def test_dynamic_table_region_set_with_bad_data(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [5, 1], 'desc') # index 5 is out of range with self.assertRaises(IndexError): dynamic_table_region.table = table self.assertIsNone(dynamic_table_region.table) 
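
# Sketch of a ragged column, as exercised by the indexed-column tests in this
# file: add_column(..., index=True) creates a VectorIndex over a VectorData so
# each row can hold a variable-length list. The table and column names are
# illustrative.
from hdmf.common import DynamicTable

trials_demo = DynamicTable('trials', 'demo table with a ragged column')
trials_demo.add_column('spike_times', 'per-trial event times', index=True)
trials_demo.add_row(spike_times=[0.1, 0.2, 0.3])
trials_demo.add_row(spike_times=[1.5])
print(trials_demo['spike_times'][0])  # [0.1, 0.2, 0.3]
print(trials_demo['spike_times'][1])  # [1.5]
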
def test_repr(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [1, 2, 2], 'desc', table=table) expected = """dtr hdmf.common.table.DynamicTableRegion at 0x%d Target table: with_columns_and_data hdmf.common.table.DynamicTable at 0x%d """ expected = expected % (id(dynamic_table_region), id(table)) self.assertEqual(str(dynamic_table_region), expected) def test_no_df_nested(self): table = self.with_columns_and_data() dynamic_table_region = DynamicTableRegion('dtr', [0, 1, 2, 2], 'desc', table=table) msg = 'DynamicTableRegion.get() with df=False and index=False is not yet supported.' with self.assertRaisesWith(ValueError, msg): dynamic_table_region.get(0, df=False, index=False) class DynamicTableRegionRoundTrip(H5RoundTripMixin, TestCase): def make_tables(self): self.spec2 = [ {'name': 'qux', 'description': 'qux column'}, {'name': 'quz', 'description': 'quz column'}, ] self.data2 = [ ['qux_1', 'qux_2'], ['quz_1', 'quz_2'], ] target_columns = [ VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec2, self.data2) ] target_table = DynamicTable("target_table", 'example table to target with a DynamicTableRegion', columns=target_columns) self.spec1 = [ {'name': 'foo', 'description': 'foo column'}, {'name': 'bar', 'description': 'bar column'}, {'name': 'baz', 'description': 'baz column'}, {'name': 'dtr', 'description': 'DTR'}, ] self.data1 = [ [1, 2, 3, 4, 5], [10.0, 20.0, 30.0, 40.0, 50.0], ['cat', 'dog', 'bird', 'fish', 'lizard'] ] columns = [ VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(self.spec1, self.data1) ] columns.append(DynamicTableRegion(name='dtr', description='example DynamicTableRegion', data=[0, 1, 1, 0, 1], table=target_table)) table = DynamicTable("table_with_dtr", 'a test table that has a DynamicTableRegion', columns=columns) return table, target_table def setUp(self): self.table, self.target_table = self.make_tables() super().setUp() def setUpContainer(self): multi_container = SimpleMultiContainer('multi', [self.table, self.target_table]) return multi_container def _get(self, arg): mc = self.roundtripContainer() table = mc.containers['table_with_dtr'] return table.get(arg) def _get_nested(self, arg): mc = self.roundtripContainer() table = mc.containers['table_with_dtr'] return table.get(arg, index=False) def _get_nodf(self, arg): mc = self.roundtripContainer() table = mc.containers['table_with_dtr'] return table.get(arg, df=False) def _getitem(self, arg): mc = self.roundtripContainer() table = mc.containers['table_with_dtr'] return table[arg] def test_getitem_oor(self): msg = 'Row index 12 out of range for DynamicTable \'table_with_dtr\' (length 5).' 
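
# Sketch of the HDF5 round trip that H5RoundTripMixin automates for the
# *RoundTrip test classes in this file: write a container with HDF5IO and read
# it back. get_manager() supplies the type map for the hdmf-common types; the
# file name is a placeholder.
from hdmf.backends.hdf5 import HDF5IO
from hdmf.common import DynamicTable, get_manager

rt_table = DynamicTable('table0', 'an example table')
rt_table.add_column('foo', 'an int column')
rt_table.add_row(foo=27)

with HDF5IO('example.h5', manager=get_manager(), mode='w') as io:
    io.write(rt_table)
with HDF5IO('example.h5', manager=get_manager(), mode='r') as io:
    read_table = io.read()
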
with self.assertRaisesWith(IndexError, msg): self._getitem(12) def test_getitem_badcol(self): with self.assertRaisesWith(KeyError, '\'boo\''): self._getitem('boo') def _assert_two_elem_df(self, rec): columns = ['foo', 'bar', 'baz', 'dtr'] data = [[1, 10.0, 'cat', 0], [2, 20.0, 'dog', 1]] exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name='id', data=[0, 1])) pd.testing.assert_frame_equal(rec, exp, check_dtype=False) def _assert_one_elem_df(self, rec): columns = ['foo', 'bar', 'baz', 'dtr'] data = [[1, 10.0, 'cat', 0]] exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name='id', data=[0])) pd.testing.assert_frame_equal(rec, exp, check_dtype=False) def _assert_two_elem_df_nested(self, rec): nested_columns = ['qux', 'quz'] nested_data = [['qux_1', 'quz_1'], ['qux_2', 'quz_2']] nested_df = pd.DataFrame(data=nested_data, columns=nested_columns, index=pd.Series(name='id', data=[0, 1])) columns = ['foo', 'bar', 'baz'] data = [[1, 10.0, 'cat'], [2, 20.0, 'dog']] exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name='id', data=[0, 1])) # remove nested dataframe and test each df separately pd.testing.assert_frame_equal(rec['dtr'][0], nested_df.iloc[[0]]) pd.testing.assert_frame_equal(rec['dtr'][1], nested_df.iloc[[1]]) del rec['dtr'] pd.testing.assert_frame_equal(rec, exp, check_dtype=False) def _assert_one_elem_df_nested(self, rec): nested_columns = ['qux', 'quz'] nested_data = [['qux_1', 'quz_1'], ['qux_2', 'quz_2']] nested_df = pd.DataFrame(data=nested_data, columns=nested_columns, index=pd.Series(name='id', data=[0, 1])) columns = ['foo', 'bar', 'baz'] data = [[1, 10.0, 'cat']] exp = pd.DataFrame(data=data, columns=columns, index=pd.Series(name='id', data=[0])) # remove nested dataframe and test each df separately pd.testing.assert_frame_equal(rec['dtr'][0], nested_df.iloc[[0]]) del rec['dtr'] pd.testing.assert_frame_equal(rec, exp, check_dtype=False) ##################### # tests DynamicTableRegion.__getitem__ def test_getitem_int(self): rec = self._getitem(0) self._assert_one_elem_df(rec) def test_getitem_list_single(self): rec = self._getitem([0]) self._assert_one_elem_df(rec) def test_getitem_list(self): rec = self._getitem([0, 1]) self._assert_two_elem_df(rec) def test_getitem_slice(self): rec = self._getitem(slice(0, 2, None)) self._assert_two_elem_df(rec) ##################### # tests DynamicTableRegion.get, return a DataFrame def test_get_int(self): rec = self._get(0) self._assert_one_elem_df(rec) def test_get_list_single(self): rec = self._get([0]) self._assert_one_elem_df(rec) def test_get_list(self): rec = self._get([0, 1]) self._assert_two_elem_df(rec) def test_get_slice(self): rec = self._get(slice(0, 2, None)) self._assert_two_elem_df(rec) ##################### # tests DynamicTableRegion.get, return a DataFrame with nested DataFrame def test_get_nested_int(self): rec = self._get_nested(0) self._assert_one_elem_df_nested(rec) def test_get_nested_list_single(self): rec = self._get_nested([0]) self._assert_one_elem_df_nested(rec) def test_get_nested_list(self): rec = self._get_nested([0, 1]) self._assert_two_elem_df_nested(rec) def test_get_nested_slice(self): rec = self._get_nested(slice(0, 2, None)) self._assert_two_elem_df_nested(rec) ##################### # tests DynamicTableRegion.get, DO NOT return a DataFrame def test_get_nodf_int(self): rec = self._get_nodf(0) exp = [0, 1, 10.0, 'cat', 0] self.assertListEqual(rec, exp) def _assert_list_of_ndarray_equal(self, l1, l2): """ This is a helper function for test_get_nodf_list and 
test_get_nodf_slice. It compares ndarrays from a list of ndarrays """ for a1, a2 in zip(l1, l2): if isinstance(a1, list): self._assert_list_of_ndarray_equal(a1, a2) else: np.testing.assert_array_equal(a1, a2) def test_get_nodf_list_single(self): rec = self._get_nodf([0]) exp = [np.array([0]), np.array([1]), np.array([10.0]), np.array(['cat']), np.array([0])] self._assert_list_of_ndarray_equal(exp, rec) def test_get_nodf_list(self): rec = self._get_nodf([0, 1]) exp = [np.array([0, 1]), np.array([1, 2]), np.array([10.0, 20.0]), np.array(['cat', 'dog']), np.array([0, 1])] self._assert_list_of_ndarray_equal(exp, rec) def test_get_nodf_slice(self): rec = self._get_nodf(slice(0, 2, None)) exp = [np.array([0, 1]), np.array([1, 2]), np.array([10.0, 20.0]), np.array(['cat', 'dog']), np.array([0, 1])] self._assert_list_of_ndarray_equal(exp, rec) def test_getitem_int_str(self): """Test DynamicTableRegion.__getitem__ with (int, str).""" mc = self.roundtripContainer() table = mc.containers['table_with_dtr'] rec = table['dtr'][0, 'qux'] self.assertEqual(rec, 'qux_1') def test_getitem_str(self): """Test DynamicTableRegion.__getitem__ with str.""" mc = self.roundtripContainer() table = mc.containers['table_with_dtr'] rec = table['dtr']['qux'] self.assertIs(rec, mc.containers['target_table']['qux']) class TestElementIdentifiers(TestCase): def test_identifier_search_single_list(self): e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) a = (e == [1]) np.testing.assert_array_equal(a, [1]) def test_identifier_search_single_int(self): e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) a = (e == 2) np.testing.assert_array_equal(a, [2]) def test_identifier_search_single_list_not_found(self): e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) a = (e == [10]) np.testing.assert_array_equal(a, []) def test_identifier_search_single_int_not_found(self): e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) a = (e == 10) np.testing.assert_array_equal(a, []) def test_identifier_search_single_list_all_match(self): e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) a = (e == [1, 2, 3]) np.testing.assert_array_equal(a, [1, 2, 3]) def test_identifier_search_single_list_partial_match(self): e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) a = (e == [1, 2, 10]) np.testing.assert_array_equal(a, [1, 2]) a = (e == [-1, 2, 10]) np.testing.assert_array_equal(a, [2, ]) def test_identifier_search_with_element_identifier(self): e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) a = (e == ElementIdentifiers('ids', [1, 2, 10])) np.testing.assert_array_equal(a, [1, 2]) def test_identifier_search_with_bad_ids(self): e = ElementIdentifiers('ids', [0, 1, 2, 3, 4]) with self.assertRaises(TypeError): _ = (e == 0.1) with self.assertRaises(TypeError): _ = (e == 'test') class SubTable(DynamicTable): __columns__ = ( {'name': 'col1', 'description': 'required column', 'required': True}, {'name': 'col2', 'description': 'optional column'}, {'name': 'col3', 'description': 'required, indexed column', 'required': True, 'index': True}, {'name': 'col4', 'description': 'optional, indexed column', 'index': True}, {'name': 'col5', 'description': 'required region', 'required': True, 'table': True}, {'name': 'col6', 'description': 'optional region', 'table': True}, {'name': 'col7', 'description': 'required, indexed region', 'required': True, 'index': True, 'table': True}, {'name': 'col8', 'description': 'optional, indexed region', 'index': True, 'table': True}, {'name': 'col10', 'description': 'optional, indexed enum column', 'index': True, 'class': EnumData}, {'name': 'col11', 
'description': 'optional, enumerable column', 'enum': True, 'index': True}, ) class SubSubTable(SubTable): __columns__ = ( {'name': 'col9', 'description': 'required column', 'required': True}, # TODO handle edge case where subclass re-defines a column from superclass # {'name': 'col2', 'description': 'optional column subsub', 'required': True}, # make col2 required ) class TestDynamicTableClassColumns(TestCase): """Test functionality related to the predefined __columns__ field of a DynamicTable class.""" def test_init(self): """Test that required columns, and not optional columns, in __columns__ are created on init.""" table = SubTable(name='subtable', description='subtable description') self.assertEqual(table.colnames, ('col1', 'col3', 'col5', 'col7')) # test different access methods. note: table.get('col1') is equivalent to table['col1'] self.assertEqual(table.col1.description, 'required column') self.assertEqual(table.col3.description, 'required, indexed column') self.assertEqual(table.col5.description, 'required region') self.assertEqual(table.col7.description, 'required, indexed region') self.assertEqual(table['col1'].description, 'required column') # self.assertEqual(table['col3'].description, 'required, indexed column') # TODO this should work self.assertIsNone(table.col2) self.assertIsNone(table.col4) self.assertIsNone(table.col4_index) self.assertIsNone(table.col6) self.assertIsNone(table.col8) self.assertIsNone(table.col8_index) self.assertIsNone(table.col11) self.assertIsNone(table.col11_index) # uninitialized optional predefined columns cannot be accessed in this manner with self.assertRaisesWith(KeyError, "'col2'"): table['col2'] def test_gather_columns_inheritance(self): """Test that gathering columns across a type hierarchy works.""" table = SubSubTable(name='subtable', description='subtable description') self.assertEqual(table.colnames, ('col1', 'col3', 'col5', 'col7', 'col9')) def test_bad_predefined_columns(self): """Test that gathering columns across a type hierarchy works.""" msg = "'__columns__' must be of type tuple, found " with self.assertRaisesWith(TypeError, msg): class BadSubTable(DynamicTable): __columns__ = [] def test_add_req_column(self): """Test that adding a required column from __columns__ raises an error.""" table = SubTable(name='subtable', description='subtable description') msg = "column 'col1' already exists in SubTable 'subtable'" with self.assertRaisesWith(ValueError, msg): table.add_column(name='col1', description='column #1') def test_add_req_ind_column(self): """Test that adding a required, indexed column from __columns__ raises an error.""" table = SubTable(name='subtable', description='subtable description') msg = "column 'col3' already exists in SubTable 'subtable'" with self.assertRaisesWith(ValueError, msg): table.add_column(name='col3', description='column #3') def test_add_opt_column(self): """Test that adding an optional column from __columns__ with matching specs except for description works.""" table = SubTable(name='subtable', description='subtable description') table.add_column(name='col2', description='column #2') # override __columns__ description self.assertEqual(table.col2.description, 'column #2') table.add_column(name='col4', description='column #4', index=True) self.assertEqual(table.col4.description, 'column #4') table.add_column(name='col6', description='column #6', table=True) self.assertEqual(table.col6.description, 'column #6') table.add_column(name='col8', description='column #8', index=True, table=True) 
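
# Sketch of the predefined __columns__ mechanism that TestDynamicTableClassColumns
# covers: a DynamicTable subclass declares its columns up front, required
# columns are created on init, and optional ones stay None until added or
# populated. The class and column names below are illustrative.
from hdmf.common import DynamicTable

class TrialsTable(DynamicTable):
    __columns__ = (
        {'name': 'start_time', 'description': 'start of trial', 'required': True},
        {'name': 'tags', 'description': 'optional ragged tags', 'index': True},
    )

trials_table = TrialsTable(name='trials', description='demo table')
assert trials_table.colnames == ('start_time',)  # required column exists on init
assert trials_table.tags is None                 # optional column not yet created
trials_table.add_column('tags', 'optional ragged tags', index=True)
trials_table.add_row(start_time=0.0, tags=['a', 'b'])
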
self.assertEqual(table.col8.description, 'column #8') table.add_column(name='col10', description='column #10', index=True, col_cls=EnumData) self.assertIsInstance(table.col10, EnumData) table.add_column(name='col11', description='column #11', enum=True, index=True) self.assertIsInstance(table.col11, EnumData) def test_add_opt_column_mismatched_table_true(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" table = SubTable(name='subtable', description='subtable description') msg = ("Column 'col2' is predefined in SubTable with table=False which does not match the entered table " "argument. The predefined table spec will be ignored. " "Please ensure the new column complies with the spec. " "This will raise an error in a future version of HDMF.") with self.assertWarnsWith(UserWarning, msg): table.add_column(name='col2', description='column #2', table=True) self.assertEqual(table.col2.description, 'column #2') self.assertEqual(type(table.col2), DynamicTableRegion) # not VectorData def test_add_opt_column_mismatched_table_table(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" table = SubTable(name='subtable', description='subtable description') msg = ("Column 'col2' is predefined in SubTable with table=False which does not match the entered table " "argument. The predefined table spec will be ignored. " "Please ensure the new column complies with the spec. " "This will raise an error in a future version of HDMF.") with self.assertWarnsWith(UserWarning, msg): table.add_column(name='col2', description='column #2', table=DynamicTable('dummy', 'dummy')) self.assertEqual(table.col2.description, 'column #2') self.assertEqual(type(table.col2), DynamicTableRegion) # not VectorData def test_add_opt_column_mismatched_index_true(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" table = SubTable(name='subtable', description='subtable description') msg = ("Column 'col2' is predefined in SubTable with index=False which does not match the entered index " "argument. The predefined index spec will be ignored. " "Please ensure the new column complies with the spec. " "This will raise an error in a future version of HDMF.") with self.assertWarnsWith(UserWarning, msg): table.add_column(name='col2', description='column #2', index=True) self.assertEqual(table.col2.description, 'column #2') self.assertEqual(type(table.get('col2')), VectorIndex) # not VectorData def test_add_opt_column_mismatched_index_data(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" table = SubTable(name='subtable', description='subtable description') table.add_row(col1='a', col3='c', col5='e', col7='g') table.add_row(col1='a', col3='c', col5='e', col7='g') msg = ("Column 'col2' is predefined in SubTable with index=False which does not match the entered index " "argument. The predefined index spec will be ignored. " "Please ensure the new column complies with the spec. 
" "This will raise an error in a future version of HDMF.") with self.assertWarnsWith(UserWarning, msg): table.add_column(name='col2', description='column #2', data=[1, 2, 3], index=[1, 2]) self.assertEqual(table.col2.description, 'column #2') self.assertEqual(type(table.get('col2')), VectorIndex) # not VectorData def test_add_opt_column_mismatched_col_cls(self): """Test that adding an optional column from __columns__ with non-matched table raises a warning.""" table = SubTable(name='subtable', description='subtable description') msg = ("Column 'col10' is predefined in SubTable with class= " "which does not match the entered col_cls " "argument. The predefined class spec will be ignored. " "Please ensure the new column complies with the spec. " "This will raise an error in a future version of HDMF.") with self.assertWarnsWith(UserWarning, msg): table.add_column(name='col10', description='column #10', index=True) self.assertEqual(table.col10.description, 'column #10') self.assertEqual(type(table.col10), VectorData) self.assertEqual(type(table.get('col10')), VectorIndex) def test_add_opt_column_twice(self): """Test that adding an optional column from __columns__ twice fails the second time.""" table = SubTable(name='subtable', description='subtable description') table.add_column(name='col2', description='column #2') msg = "column 'col2' already exists in SubTable 'subtable'" with self.assertRaisesWith(ValueError, msg): table.add_column(name='col2', description='column #2b') def test_add_opt_column_after_data(self): """Test that adding an optional column from __columns__ with data works.""" table = SubTable(name='subtable', description='subtable description') table.add_row(col1='a', col3='c', col5='e', col7='g') table.add_column(name='col2', description='column #2', data=('b', )) self.assertTupleEqual(table.col2.data, ('b', )) def test_add_opt_ind_column_after_data(self): """Test that adding an optional, indexed column from __columns__ with data works.""" table = SubTable(name='subtable', description='subtable description') table.add_row(col1='a', col3='c', col5='e', col7='g') # TODO this use case is tricky and should not be allowed # table.add_column(name='col4', description='column #4', data=(('b', 'b2'), )) def test_add_row_opt_column(self): """Test that adding a row with an optional column works.""" table = SubTable(name='subtable', description='subtable description') table.add_row(col1='a', col2='b', col3='c', col4=('d1', 'd2'), col5='e', col7='g') table.add_row(col1='a', col2='b2', col3='c', col4=('d3', 'd4'), col5='e', col7='g') self.assertTupleEqual(table.colnames, ('col1', 'col3', 'col5', 'col7', 'col2', 'col4')) self.assertEqual(table.col2.description, 'optional column') self.assertEqual(table.col4.description, 'optional, indexed column') self.assertListEqual(table.col2.data, ['b', 'b2']) # self.assertListEqual(table.col4.data, [('d1', 'd2'), ('d3', 'd4')]) # TODO this should work def test_add_row_opt_column_after_data(self): """Test that adding a row with an optional column after adding a row without the column raises an error.""" table = SubTable(name='subtable', description='subtable description') table.add_row(col1='a', col3='c', col5='e', col7='g') msg = "column must have the same number of rows as 'id'" # TODO improve error message with self.assertRaisesWith(ValueError, msg): table.add_row(col1='a', col2='b', col3='c', col5='e', col7='g') def test_init_columns_add_req_column(self): """Test that passing a required column to init works.""" col1 = VectorData(name='col1', 
description='column #1') # override __columns__ description table = SubTable(name='subtable', description='subtable description', columns=[col1]) self.assertEqual(table.colnames, ('col1', 'col3', 'col5', 'col7')) self.assertEqual(table.col1.description, 'column #1') self.assertTrue(hasattr(table, 'col1')) def test_init_columns_add_req_column_mismatch_index(self): """Test that passing a required column that does not match the predefined column specs raises an error.""" col1 = VectorData(name='col1', description='column #1') # override __columns__ description col1_ind = VectorIndex(name='col1_index', data=list(), target=col1) # TODO raise an error SubTable(name='subtable', description='subtable description', columns=[col1_ind, col1]) def test_init_columns_add_req_column_mismatch_table(self): """Test that passing a required column that does not match the predefined column specs raises an error.""" dummy_table = DynamicTable(name='dummy', description='dummy table') col1 = DynamicTableRegion(name='col1', data=list(), description='column #1', table=dummy_table) # TODO raise an error SubTable(name='subtable', description='subtable description', columns=[col1]) def test_init_columns_add_opt_column(self): """Test that passing an optional column to init works.""" col2 = VectorData(name='col2', description='column #2') # override __columns__ description table = SubTable(name='subtable', description='subtable description', columns=[col2]) self.assertEqual(table.colnames, ('col2', 'col1', 'col3', 'col5', 'col7')) self.assertEqual(table.col2.description, 'column #2') def test_init_columns_add_dup_column(self): """Test that passing two columns with the same name raises an error.""" col1 = VectorData(name='col1', description='column #1') # override __columns__ description col1_ind = VectorIndex(name='col1', data=list(), target=col1) msg = "'columns' contains columns with duplicate names: ['col1', 'col1']" with self.assertRaisesWith(ValueError, msg): SubTable(name='subtable', description='subtable description', columns=[col1_ind, col1]) class TestEnumData(TestCase): def test_init(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c'], data=np.array([0, 0, 1, 1, 2, 2])) self.assertIsInstance(ed.elements, VectorData) def test_get(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c'], data=np.array([0, 0, 1, 1, 2, 2])) dat = ed[2] self.assertEqual(dat, 'b') dat = ed[-1] self.assertEqual(dat, 'c') dat = ed[0] self.assertEqual(dat, 'a') def test_get_list(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c'], data=np.array([0, 0, 1, 1, 2, 2])) dat = ed[[0, 1, 2]] np.testing.assert_array_equal(dat, ['a', 'a', 'b']) def test_get_list_join(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c'], data=np.array([0, 0, 1, 1, 2, 2])) dat = ed.get([0, 1, 2], join=True) self.assertEqual(dat, 'aab') def test_get_list_indices(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c'], data=np.array([0, 0, 1, 1, 2, 2])) dat = ed.get([0, 1, 2], index=True) np.testing.assert_array_equal(dat, [0, 0, 1]) def test_get_2d(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c'], data=np.array([[0, 0], [1, 1], [2, 2]])) dat = ed[0] np.testing.assert_array_equal(dat, ['a', 'a']) def test_get_2d_w_2d(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c'], data=np.array([[0, 0], [1, 1], [2, 2]])) dat = ed[[0, 1]] np.testing.assert_array_equal(dat, [['a', 'a'], ['b', 'b']]) def 
test_add_row(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c']) ed.add_row('b') ed.add_row('a') ed.add_row('c') np.testing.assert_array_equal(ed.data, np.array([1, 0, 2], dtype=np.uint8)) def test_add_row_index(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c']) ed.add_row(1, index=True) ed.add_row(0, index=True) ed.add_row(2, index=True) np.testing.assert_array_equal(ed.data, np.array([1, 0, 2], dtype=np.uint8)) class TestIndexedEnumData(TestCase): def test_init(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c'], data=np.array([0, 0, 1, 1, 2, 2])) idx = VectorIndex('enum_index', [2, 4, 6], target=ed) self.assertIsInstance(ed.elements, VectorData) self.assertIsInstance(idx.target, EnumData) def test_add_row(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c']) idx = VectorIndex('enum_index', list(), target=ed) idx.add_row(['a', 'a', 'a']) idx.add_row(['b', 'b']) idx.add_row(['c', 'c', 'c', 'c']) np.testing.assert_array_equal(idx[0], ['a', 'a', 'a']) np.testing.assert_array_equal(idx[1], ['b', 'b']) np.testing.assert_array_equal(idx[2], ['c', 'c', 'c', 'c']) def test_add_row_index(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c']) idx = VectorIndex('enum_index', list(), target=ed) idx.add_row([0, 0, 0], index=True) idx.add_row([1, 1], index=True) idx.add_row([2, 2, 2, 2], index=True) np.testing.assert_array_equal(idx[0], ['a', 'a', 'a']) np.testing.assert_array_equal(idx[1], ['b', 'b']) np.testing.assert_array_equal(idx[2], ['c', 'c', 'c', 'c']) @unittest.skip("feature is not yet supported") def test_add_2d_row_index(self): ed = EnumData('cv_data', 'a test EnumData', elements=['a', 'b', 'c']) idx = VectorIndex('enum_index', list(), target=ed) idx.add_row([['a', 'a'], ['a', 'a'], ['a', 'a']]) idx.add_row([['b', 'b'], ['b', 'b']]) idx.add_row([['c', 'c'], ['c', 'c'], ['c', 'c'], ['c', 'c']]) np.testing.assert_array_equal(idx[0], [['a', 'a'], ['a', 'a'], ['a', 'a']]) np.testing.assert_array_equal(idx[1], [['b', 'b'], ['b', 'b']]) np.testing.assert_array_equal(idx[2], [['c', 'c'], ['c', 'c'], ['c', 'c'], ['c', 'c']]) class SelectionTestMixin: def setUp(self): # table1 contains a non-ragged DTR and a ragged DTR, both of which point to table2 # table2 contains a non-ragged DTR and a ragged DTR, both of which point to table3 self.table3 = DynamicTable( name='table3', description='a test table', id=[20, 21, 22] ) self.table3.add_column('foo', 'scalar column', data=self._wrap([20.0, 21.0, 22.0])) self.table3.add_column('bar', 'ragged column', index=self._wrap([2, 3, 6]), data=self._wrap(['t11', 't12', 't21', 't31', 't32', 't33'])) self.table3.add_column('baz', 'multi-dimension column', data=self._wrap([[210.0, 211.0, 212.0], [220.0, 221.0, 222.0], [230.0, 231.0, 232.0]])) # generate expected dataframe for table3 data = OrderedDict() data['foo'] = [20.0, 21.0, 22.0] data['bar'] = [['t11', 't12'], ['t21'], ['t31', 't32', 't33']] data['baz'] = [[210.0, 211.0, 212.0], [220.0, 221.0, 222.0], [230.0, 231.0, 232.0]] idx = [20, 21, 22] self.table3_df = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) self.table2 = DynamicTable( name='table2', description='a test table', id=[10, 11, 12] ) self.table2.add_column('foo', 'scalar column', data=self._wrap([10.0, 11.0, 12.0])) self.table2.add_column('bar', 'ragged column', index=self._wrap([2, 3, 6]), data=self._wrap(['s11', 's12', 's21', 's31', 's32', 's33'])) self.table2.add_column('baz', 'multi-dimension column', 
data=self._wrap([[110.0, 111.0, 112.0], [120.0, 121.0, 122.0], [130.0, 131.0, 132.0]])) self.table2.add_column('qux', 'DTR column', table=self.table3, data=self._wrap([0, 1, 0])) self.table2.add_column('corge', 'ragged DTR column', index=self._wrap([2, 3, 6]), table=self.table3, data=self._wrap([0, 1, 2, 0, 1, 2])) # TODO test when ragged DTR indices are not in ascending order # generate expected dataframe for table2 *without DTR* data = OrderedDict() data['foo'] = [10.0, 11.0, 12.0] data['bar'] = [['s11', 's12'], ['s21'], ['s31', 's32', 's33']] data['baz'] = [[110.0, 111.0, 112.0], [120.0, 121.0, 122.0], [130.0, 131.0, 132.0]] idx = [10, 11, 12] self.table2_df = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) self.table1 = DynamicTable( name='table1', description='a table to test slicing', id=[0, 1] ) self.table1.add_column('foo', 'scalar column', data=self._wrap([0.0, 1.0])) self.table1.add_column('bar', 'ragged column', index=self._wrap([2, 3]), data=self._wrap(['r11', 'r12', 'r21'])) self.table1.add_column('baz', 'multi-dimension column', data=self._wrap([[10.0, 11.0, 12.0], [20.0, 21.0, 22.0]])) self.table1.add_column('qux', 'DTR column', table=self.table2, data=self._wrap([0, 1])) self.table1.add_column('corge', 'ragged DTR column', index=self._wrap([2, 3]), table=self.table2, data=self._wrap([0, 1, 2])) self.table1.add_column('barz', 'ragged column of tuples (cpd type)', index=self._wrap([2, 3]), data=self._wrap([(1.0, 11), (2.0, 12), (3.0, 21)])) # generate expected dataframe for table1 *without DTR* data = OrderedDict() data['foo'] = self._wrap_check([0.0, 1.0]) data['bar'] = [self._wrap_check(['r11', 'r12']), self._wrap_check(['r21'])] data['baz'] = [self._wrap_check([10.0, 11.0, 12.0]), self._wrap_check([20.0, 21.0, 22.0])] data['barz'] = [self._wrap_check([(1.0, 11), (2.0, 12)]), self._wrap_check([(3.0, 21)])] idx = [0, 1] self.table1_df = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) def _check_two_rows_df(self, rec): data = OrderedDict() data['foo'] = self._wrap_check([0.0, 1.0]) data['bar'] = [self._wrap_check(['r11', 'r12']), self._wrap_check(['r21'])] data['baz'] = [self._wrap_check([10.0, 11.0, 12.0]), self._wrap_check([20.0, 21.0, 22.0])] data['qux'] = self._wrap_check([0, 1]) data['corge'] = [self._wrap_check([0, 1]), self._wrap_check([2])] data['barz'] = [self._wrap_check([(1.0, 11), (2.0, 12)]), self._wrap_check([(3.0, 21)])] idx = [0, 1] exp = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) pd.testing.assert_frame_equal(rec, exp) def _check_two_rows_df_nested(self, rec): # first level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal qux_series = rec['qux'] corge_series = rec['corge'] del rec['qux'] del rec['corge'] idx = [0, 1] pd.testing.assert_frame_equal(rec, self.table1_df.loc[idx]) # second level: compare the nested columns separately self.assertEqual(len(qux_series), 2) rec_qux1 = qux_series[0] rec_qux2 = qux_series[1] self._check_table2_first_row_qux(rec_qux1) self._check_table2_second_row_qux(rec_qux2) self.assertEqual(len(corge_series), 2) rec_corge1 = corge_series[0] rec_corge2 = corge_series[1] self._check_table2_first_row_corge(rec_corge1) self._check_table2_second_row_corge(rec_corge2) def _check_one_row_df(self, rec): data = OrderedDict() data['foo'] = self._wrap_check([0.0]) data['bar'] = [self._wrap_check(['r11', 'r12'])] data['baz'] = [self._wrap_check([10.0, 11.0, 12.0])] data['qux'] = self._wrap_check([0]) data['corge'] = [self._wrap_check([0, 1])] data['barz'] = 
[self._wrap_check([(1.0, 11), (2.0, 12)])] idx = [0] exp = pd.DataFrame(data=data, index=pd.Index(name='id', data=idx)) pd.testing.assert_frame_equal(rec, exp) def _check_one_row_df_nested(self, rec): # first level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal qux_series = rec['qux'] corge_series = rec['corge'] del rec['qux'] del rec['corge'] idx = [0] pd.testing.assert_frame_equal(rec, self.table1_df.loc[idx]) # second level: compare the nested columns separately self.assertEqual(len(qux_series), 1) rec_qux = qux_series[0] self._check_table2_first_row_qux(rec_qux) self.assertEqual(len(corge_series), 1) rec_corge = corge_series[0] self._check_table2_first_row_corge(rec_corge) def _check_table2_first_row_qux(self, rec_qux): # second level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal qux_qux_series = rec_qux['qux'] qux_corge_series = rec_qux['corge'] del rec_qux['qux'] del rec_qux['corge'] qux_idx = [10] pd.testing.assert_frame_equal(rec_qux, self.table2_df.loc[qux_idx]) # third level: compare the nested columns separately self.assertEqual(len(qux_qux_series), 1) pd.testing.assert_frame_equal(qux_qux_series[qux_idx[0]], self.table3_df.iloc[[0]]) self.assertEqual(len(qux_corge_series), 1) pd.testing.assert_frame_equal(qux_corge_series[qux_idx[0]], self.table3_df.iloc[[0, 1]]) def _check_table2_second_row_qux(self, rec_qux): # second level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal qux_qux_series = rec_qux['qux'] qux_corge_series = rec_qux['corge'] del rec_qux['qux'] del rec_qux['corge'] qux_idx = [11] pd.testing.assert_frame_equal(rec_qux, self.table2_df.loc[qux_idx]) # third level: compare the nested columns separately self.assertEqual(len(qux_qux_series), 1) pd.testing.assert_frame_equal(qux_qux_series[qux_idx[0]], self.table3_df.iloc[[1]]) self.assertEqual(len(qux_corge_series), 1) pd.testing.assert_frame_equal(qux_corge_series[qux_idx[0]], self.table3_df.iloc[[2]]) def _check_table2_first_row_corge(self, rec_corge): # second level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal corge_qux_series = rec_corge['qux'] corge_corge_series = rec_corge['corge'] del rec_corge['qux'] del rec_corge['corge'] corge_idx = [10, 11] pd.testing.assert_frame_equal(rec_corge, self.table2_df.loc[corge_idx]) # third level: compare the nested columns separately self.assertEqual(len(corge_qux_series), 2) pd.testing.assert_frame_equal(corge_qux_series[corge_idx[0]], self.table3_df.iloc[[0]]) pd.testing.assert_frame_equal(corge_qux_series[corge_idx[1]], self.table3_df.iloc[[1]]) self.assertEqual(len(corge_corge_series), 2) pd.testing.assert_frame_equal(corge_corge_series[corge_idx[0]], self.table3_df.iloc[[0, 1]]) pd.testing.assert_frame_equal(corge_corge_series[corge_idx[1]], self.table3_df.iloc[[2]]) def _check_table2_second_row_corge(self, rec_corge): # second level: cache nested df cols and remove them before calling pd.testing.assert_frame_equal corge_qux_series = rec_corge['qux'] corge_corge_series = rec_corge['corge'] del rec_corge['qux'] del rec_corge['corge'] corge_idx = [12] pd.testing.assert_frame_equal(rec_corge, self.table2_df.loc[corge_idx]) # third level: compare the nested columns separately self.assertEqual(len(corge_qux_series), 1) pd.testing.assert_frame_equal(corge_qux_series[corge_idx[0]], self.table3_df.iloc[[0]]) self.assertEqual(len(corge_corge_series), 1) pd.testing.assert_frame_equal(corge_corge_series[corge_idx[0]], self.table3_df.iloc[[0, 
1, 2]]) def _check_two_rows_no_df(self, rec): self.assertEqual(rec[0], [0, 1]) np.testing.assert_array_equal(rec[1], self._wrap_check([0.0, 1.0])) expected = [self._wrap_check(['r11', 'r12']), self._wrap_check(['r21'])] self._assertNestedRaggedArrayEqual(rec[2], expected) np.testing.assert_array_equal(rec[3], self._wrap_check([[10.0, 11.0, 12.0], [20.0, 21.0, 22.0]])) np.testing.assert_array_equal(rec[4], self._wrap_check([0, 1])) expected = [self._wrap_check([0, 1]), self._wrap_check([2])] for i, j in zip(rec[5], expected): np.testing.assert_array_equal(i, j) def _check_one_row_no_df(self, rec): self.assertEqual(rec[0], 0) self.assertEqual(rec[1], 0.0) np.testing.assert_array_equal(rec[2], self._wrap_check(['r11', 'r12'])) np.testing.assert_array_equal(rec[3], self._wrap_check([10.0, 11.0, 12.0])) self.assertEqual(rec[4], 0) np.testing.assert_array_equal(rec[5], self._wrap_check([0, 1])) np.testing.assert_array_equal(rec[6], self._wrap_check([(1.0, 11), (2.0, 12)])) def _check_one_row_multiselect_no_df(self, rec): # difference from _check_one_row_no_df is that everything is wrapped in a list self.assertEqual(rec[0], [0]) self.assertEqual(rec[1], [0.0]) np.testing.assert_array_equal(rec[2], [self._wrap_check(['r11', 'r12'])]) np.testing.assert_array_equal(rec[3], [self._wrap_check([10.0, 11.0, 12.0])]) self.assertEqual(rec[4], [0]) np.testing.assert_array_equal(rec[5], [self._wrap_check([0, 1])]) np.testing.assert_array_equal(rec[6], [self._wrap_check([(1.0, 11), (2.0, 12)])]) def _assertNestedRaggedArrayEqual(self, arr1, arr2): """ This is a helper function for _check_two_rows_no_df. It compares arrays or lists containing numpy arrays that may be ragged """ self.assertEqual(type(arr1), type(arr2)) self.assertEqual(len(arr1), len(arr2)) if isinstance(arr1, np.ndarray): if arr1.dtype == object: # both are arrays containing arrays, lists, or h5py.Dataset strings for i, j in zip(arr1, arr2): self._assertNestedRaggedArrayEqual(i, j) elif np.issubdtype(arr1.dtype, np.number): np.testing.assert_allclose(arr1, arr2) else: np.testing.assert_array_equal(arr1, arr2) elif isinstance(arr1, list): for i, j in zip(arr1, arr2): self._assertNestedRaggedArrayEqual(i, j) else: # scalar self.assertEqual(arr1, arr2) def test_single_item(self): rec = self.table1[0] self._check_one_row_df(rec) def test_single_item_nested(self): rec = self.table1.get(0, index=False) self._check_one_row_df_nested(rec) def test_single_item_no_df(self): rec = self.table1.get(0, df=False) self._check_one_row_no_df(rec) def test_slice(self): rec = self.table1[0:2] self._check_two_rows_df(rec) def test_slice_nested(self): rec = self.table1.get(slice(0, 2), index=False) self._check_two_rows_df_nested(rec) def test_slice_no_df(self): rec = self.table1.get(slice(0, 2), df=False) self._check_two_rows_no_df(rec) def test_slice_single(self): rec = self.table1[0:1] self._check_one_row_df(rec) def test_slice_single_nested(self): rec = self.table1.get(slice(0, 1), index=False) self._check_one_row_df_nested(rec) def test_slice_single_no_df(self): rec = self.table1.get(slice(0, 1), df=False) self._check_one_row_multiselect_no_df(rec) def test_list(self): rec = self.table1[[0, 1]] self._check_two_rows_df(rec) def test_list_nested(self): rec = self.table1.get([0, 1], index=False) self._check_two_rows_df_nested(rec) def test_list_no_df(self): rec = self.table1.get([0, 1], df=False) self._check_two_rows_no_df(rec) def test_list_single(self): rec = self.table1[[0]] self._check_one_row_df(rec) def test_list_single_nested(self): rec = 
self.table1.get([0], index=False) self._check_one_row_df_nested(rec) def test_list_single_no_df(self): rec = self.table1.get([0], df=False) self._check_one_row_multiselect_no_df(rec) def test_array(self): rec = self.table1[np.array([0, 1])] self._check_two_rows_df(rec) def test_array_nested(self): rec = self.table1.get(np.array([0, 1]), index=False) self._check_two_rows_df_nested(rec) def test_array_no_df(self): rec = self.table1.get(np.array([0, 1]), df=False) self._check_two_rows_no_df(rec) def test_array_single(self): rec = self.table1[np.array([0])] self._check_one_row_df(rec) def test_array_single_nested(self): rec = self.table1.get(np.array([0]), index=False) self._check_one_row_df_nested(rec) def test_array_single_no_df(self): rec = self.table1.get(np.array([0]), df=False) self._check_one_row_multiselect_no_df(rec) def test_to_dataframe_nested(self): rec = self.table1.to_dataframe() self._check_two_rows_df_nested(rec) def test_to_dataframe(self): rec = self.table1.to_dataframe(index=True) self._check_two_rows_df(rec) class TestSelectionArray(SelectionTestMixin, TestCase): def _wrap(self, my_list): return np.array(my_list) def _wrap_check(self, my_list): return self._wrap(my_list) class TestSelectionList(SelectionTestMixin, TestCase): def _wrap(self, my_list): return my_list def _wrap_check(self, my_list): return self._wrap(my_list) class TestSelectionH5Dataset(SelectionTestMixin, TestCase): def setUp(self): self.path = get_temp_filepath() self.file = h5py.File(self.path, 'w') self.dset_counter = 0 super().setUp() def tearDown(self): super().tearDown() self.file.close() if os.path.exists(self.path): os.remove(self.path) def _wrap(self, my_list): self.dset_counter = self.dset_counter + 1 kwargs = dict() if isinstance(my_list[0], str): kwargs['dtype'] = H5_TEXT elif isinstance(my_list[0], tuple): # compound dtype # normally for cpd dtype, __resolve_dtype__ takes a list of DtypeSpec objects cpd_type = [dict(name='cpd_float', dtype=np.dtype('float64')), dict(name='cpd_int', dtype=np.dtype('int32'))] kwargs['dtype'] = HDF5IO.__resolve_dtype__(cpd_type, my_list[0]) dset = self.file.create_dataset('dset%d' % self.dset_counter, data=np.array(my_list, **kwargs)) if H5PY_3 and isinstance(my_list[0], str): return StrDataset(dset, None) # return a wrapper to read data as str instead of bytes else: # NOTE: h5py.Dataset with compound dtype are read as numpy arrays with compound dtype, not tuples return dset def _wrap_check(self, my_list): # getitem on h5dataset backed data will return np.array kwargs = dict() if isinstance(my_list[0], str): kwargs['dtype'] = H5_TEXT elif isinstance(my_list[0], tuple): cpd_type = [dict(name='cpd_float', dtype=np.dtype('float64')), dict(name='cpd_int', dtype=np.dtype('int32'))] kwargs['dtype'] = np.dtype([(x['name'], x['dtype']) for x in cpd_type]) # compound dtypes with str are read as bytes, see https://github.com/h5py/h5py/issues/1751 return np.array(my_list, **kwargs) class TestVectorIndex(TestCase): def test_init_empty(self): foo = VectorData(name='foo', description='foo column') foo_ind = VectorIndex(name='foo_index', target=foo, data=list()) self.assertEqual(foo_ind.name, 'foo_index') self.assertEqual(foo_ind.description, "Index for VectorData 'foo'") self.assertIs(foo_ind.target, foo) self.assertListEqual(foo_ind.data, list()) def test_init_data(self): foo = VectorData(name='foo', description='foo column', data=['a', 'b', 'c']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3]) self.assertListEqual(foo_ind.data, [2, 3]) 
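# Sketch (illustrative, not part of the test above): the same ragged layout expressed
# through the DynamicTable convenience API, where index=True creates the VectorIndex
# behind the scenes; the table name 't' and column name 'foo' here are hypothetical.
from hdmf.common import DynamicTable

table = DynamicTable(name='t', description='ragged column example')
table.add_column('foo', 'a ragged column', index=True)
table.add_row(foo=['a', 'b'])
table.add_row(foo=['c'])
# table['foo'] is the VectorIndex; table['foo'][0] == ['a', 'b'] and table['foo'][1] == ['c'],
# mirroring the foo_ind assertions in this test.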
self.assertListEqual(foo_ind[0], ['a', 'b']) self.assertListEqual(foo_ind[1], ['c']) class TestDoubleIndex(TestCase): def test_index(self): # row 1 has three entries # the first entry has two sub-entries # the first sub-entry has two values, the second sub-entry has one value # the second entry has one sub-entry, which has one value foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) self.assertListEqual(foo_ind[0], ['a11', 'a12']) self.assertListEqual(foo_ind[1], ['a21']) self.assertListEqual(foo_ind[2], ['b11']) self.assertListEqual(foo_ind_ind[0], [['a11', 'a12'], ['a21']]) self.assertListEqual(foo_ind_ind[1], [['b11']]) def test_add_vector(self): # row 1 has three entries # the first entry has two sub-entries # the first sub-entry has two values, the second sub-entry has one value # the second entry has one sub-entry, which has one value foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) foo_ind_ind.add_vector([['c11', 'c12', 'c13'], ['c21', 'c22']]) self.assertListEqual(foo.data, ['a11', 'a12', 'a21', 'b11', 'c11', 'c12', 'c13', 'c21', 'c22']) self.assertListEqual(foo_ind.data, [2, 3, 4, 7, 9]) self.assertListEqual(foo_ind[3], ['c11', 'c12', 'c13']) self.assertListEqual(foo_ind[4], ['c21', 'c22']) self.assertListEqual(foo_ind_ind.data, [2, 3, 5]) self.assertListEqual(foo_ind_ind[2], [['c11', 'c12', 'c13'], ['c21', 'c22']]) class TestDTDoubleIndex(TestCase): def test_double_index(self): foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) table = DynamicTable('table0', 'an example table', columns=[foo, foo_ind, foo_ind_ind]) self.assertIs(table['foo'], foo_ind_ind) self.assertIs(table.foo, foo) self.assertListEqual(table['foo'][0], [['a11', 'a12'], ['a21']]) self.assertListEqual(table[0, 'foo'], [['a11', 'a12'], ['a21']]) self.assertListEqual(table[1, 'foo'], [['b11']]) def test_double_index_reverse(self): foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) table = DynamicTable('table0', 'an example table', columns=[foo_ind_ind, foo_ind, foo]) self.assertIs(table['foo'], foo_ind_ind) self.assertIs(table.foo, foo) self.assertListEqual(table['foo'][0], [['a11', 'a12'], ['a21']]) self.assertListEqual(table[0, 'foo'], [['a11', 'a12'], ['a21']]) self.assertListEqual(table[1, 'foo'], [['b11']]) def test_double_index_colnames(self): foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) bar = VectorData(name='bar', description='bar column', data=[1, 2]) table = DynamicTable('table0', 'an example table', columns=[foo, foo_ind, foo_ind_ind, bar], colnames=['foo', 'bar']) self.assertTupleEqual(table.columns, (foo_ind_ind, foo_ind, foo, bar)) def test_double_index_reverse_colnames(self): foo 
= VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) bar = VectorData(name='bar', description='bar column', data=[1, 2]) table = DynamicTable('table0', 'an example table', columns=[foo_ind_ind, foo_ind, foo, bar], colnames=['bar', 'foo']) self.assertTupleEqual(table.columns, (bar, foo_ind_ind, foo_ind, foo)) class TestDTDoubleIndexSkipMiddle(TestCase): def test_index(self): foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) msg = "Found VectorIndex 'foo_index_index' but not its target 'foo_index'" with self.assertRaisesWith(ValueError, msg): DynamicTable('table0', 'an example table', columns=[foo_ind_ind, foo]) class TestDynamicTableAddIndexRoundTrip(H5RoundTripMixin, TestCase): def setUpContainer(self): table = DynamicTable('table0', 'an example table') table.add_column('foo', 'an int column', index=True) table.add_row(foo=[1, 2, 3]) return table class TestDynamicTableAddEnumRoundTrip(H5RoundTripMixin, TestCase): def setUpContainer(self): table = DynamicTable('table0', 'an example table') table.add_column('bar', 'an enumerable column', enum=True) table.add_row(bar='a') table.add_row(bar='b') table.add_row(bar='a') table.add_row(bar='c') return table class TestDynamicTableAddEnum(TestCase): def test_enum(self): table = DynamicTable('table0', 'an example table') table.add_column('bar', 'an enumerable column', enum=True) table.add_row(bar='a') table.add_row(bar='b') table.add_row(bar='a') table.add_row(bar='c') rec = table.to_dataframe() exp = pd.DataFrame(data={'bar': ['a', 'b', 'a', 'c']}, index=pd.Series(name='id', data=[0, 1, 2, 3])) pd.testing.assert_frame_equal(exp, rec) def test_enum_index(self): table = DynamicTable('table0', 'an example table') table.add_column('bar', 'an indexed enumerable column', enum=True, index=True) table.add_row(bar=['a', 'a', 'a']) table.add_row(bar=['b', 'b', 'b', 'b']) table.add_row(bar=['c', 'c']) rec = table.to_dataframe() exp = pd.DataFrame(data={'bar': [['a', 'a', 'a'], ['b', 'b', 'b', 'b'], ['c', 'c']]}, index=pd.Series(name='id', data=[0, 1, 2])) pd.testing.assert_frame_equal(exp, rec) class TestDynamicTableInitIndexRoundTrip(H5RoundTripMixin, TestCase): def setUpContainer(self): foo = VectorData(name='foo', description='foo column', data=['a', 'b', 'c']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3]) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list table = DynamicTable('table0', 'an example table', columns=[foo_ind, foo]) return table class TestDoubleIndexRoundtrip(H5RoundTripMixin, TestCase): def setUpContainer(self): foo = VectorData(name='foo', description='foo column', data=['a11', 'a12', 'a21', 'b11']) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) foo_ind_ind = VectorIndex(name='foo_index_index', target=foo_ind, data=[2, 3]) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list table = DynamicTable('table0', 'an example table', columns=[foo_ind_ind, foo_ind, foo]) return table class TestDataIOColumns(H5RoundTripMixin, TestCase): def setUpContainer(self): 
self.chunked_data = H5DataIO( data=[i for i in range(10)], chunks=(3,), fillvalue=-1, ) self.compressed_data = H5DataIO( data=np.arange(10), compression=1, shuffle=True, fletcher32=True, allow_plugin_filters=True, ) foo = VectorData(name='foo', description='chunked column', data=self.chunked_data) bar = VectorData(name='bar', description='chunked column', data=self.compressed_data) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list table = DynamicTable('table0', 'an example table', columns=[foo, bar]) table.add_row(foo=1, bar=1) return table def test_roundtrip(self): super().test_roundtrip() with h5py.File(self.filename, 'r') as f: chunked_dset = f['foo'] self.assertTrue(np.all(chunked_dset[:] == self.chunked_data.data)) self.assertEqual(chunked_dset.chunks, (3,)) self.assertEqual(chunked_dset.fillvalue, -1) compressed_dset = f['bar'] self.assertTrue(np.all(compressed_dset[:] == self.compressed_data.data)) self.assertEqual(compressed_dset.compression, 'gzip') self.assertEqual(compressed_dset.shuffle, True) self.assertEqual(compressed_dset.fletcher32, True) class TestDataIOIndexedColumns(H5RoundTripMixin, TestCase): def setUpContainer(self): self.chunked_data = H5DataIO( data=np.arange(30).reshape(5, 2, 3), chunks=(1, 1, 3), fillvalue=-1, ) self.compressed_data = H5DataIO( data=np.arange(30).reshape(5, 2, 3), compression=1, shuffle=True, fletcher32=True, allow_plugin_filters=True, ) foo = VectorData(name='foo', description='chunked column', data=self.chunked_data) foo_ind = VectorIndex(name='foo_index', target=foo, data=[2, 3, 4]) bar = VectorData(name='bar', description='chunked column', data=self.compressed_data) bar_ind = VectorIndex(name='bar_index', target=bar, data=[2, 3, 4]) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list table = DynamicTable('table0', 'an example table', columns=[foo_ind, foo, bar_ind, bar]) # check for add_row table.add_row(foo=np.arange(30).reshape(5, 2, 3), bar=np.arange(30).reshape(5, 2, 3)) return table def test_roundtrip(self): super().test_roundtrip() with h5py.File(self.filename, 'r') as f: chunked_dset = f['foo'] self.assertTrue(np.all(chunked_dset[:] == self.chunked_data.data)) self.assertEqual(chunked_dset.chunks, (1, 1, 3)) self.assertEqual(chunked_dset.fillvalue, -1) compressed_dset = f['bar'] self.assertTrue(np.all(compressed_dset[:] == self.compressed_data.data)) self.assertEqual(compressed_dset.compression, 'gzip') self.assertEqual(compressed_dset.shuffle, True) self.assertEqual(compressed_dset.fletcher32, True) class TestDataIOIndex(H5RoundTripMixin, TestCase): def setUpContainer(self): self.chunked_data = H5DataIO( data=np.arange(30).reshape(5, 2, 3), chunks=(1, 1, 3), fillvalue=-1, maxshape=(None, 2, 3) ) self.chunked_index_data = H5DataIO( data=np.array([2, 3, 5], dtype=np.uint), chunks=(2, ), fillvalue=np.uint(10), maxshape=(None,) ) self.compressed_data = H5DataIO( data=np.arange(30).reshape(5, 2, 3), compression=1, shuffle=True, fletcher32=True, allow_plugin_filters=True, maxshape=(None, 2, 3) ) self.compressed_index_data = H5DataIO( data=np.array([2, 4, 5], dtype=np.uint), compression=1, shuffle=True, fletcher32=False, allow_plugin_filters=True, maxshape=(None,) ) foo = VectorData(name='foo', description='chunked column', data=self.chunked_data) foo_ind = VectorIndex(name='foo_index', target=foo, data=self.chunked_index_data) bar = 
VectorData(name='bar', description='chunked column', data=self.compressed_data) bar_ind = VectorIndex(name='bar_index', target=bar, data=self.compressed_index_data) # NOTE: on construct, columns are ordered such that indices go before data, so create the table that way # for proper comparison of the columns list table = DynamicTable('table0', 'an example table', columns=[foo_ind, foo, bar_ind, bar], id=H5DataIO(data=[0, 1, 2], chunks=True, maxshape=(None,))) # check for add_row table.add_row(foo=np.arange(30).reshape(5, 2, 3), bar=np.arange(30).reshape(5, 2, 3)) return table def test_append(self, cache_spec=False): """Write the container to an HDF5 file, read the container from the file, and append to it.""" with HDF5IO(self.filename, manager=get_manager(), mode='w') as write_io: write_io.write(self.container, cache_spec=cache_spec) self.reader = HDF5IO(self.filename, manager=get_manager(), mode='a') read_table = self.reader.read() data = np.arange(30, 60).reshape(5, 2, 3) read_table.add_row(foo=data, bar=data) np.testing.assert_array_equal(read_table['foo'][-1], data) class TestDTRReferences(TestCase): def setUp(self): self.filename = 'test_dtr_references.h5' def tearDown(self): remove_test_file(self.filename) def test_dtr_references(self): """Test roundtrip of a table with a ragged DTR to another table containing a column of references.""" group1 = Container('group1') group2 = Container('group2') table1 = DynamicTable( name='table1', description='test table 1' ) table1.add_column( name='x', description='test column of ints' ) table1.add_column( name='y', description='test column of reference' ) table1.add_row(id=101, x=1, y=group1) table1.add_row(id=102, x=2, y=group1) table1.add_row(id=103, x=3, y=group2) table2 = DynamicTable( name='table2', description='test table 2' ) # create a ragged column that references table1 # each row of table2 corresponds to one or more rows of table 1 table2.add_column( name='electrodes', description='column description', index=True, table=table1 ) table2.add_row(id=10, electrodes=[1, 2]) multi_container = SimpleMultiContainer('multi') multi_container.add_container(group1) multi_container.add_container(group2) multi_container.add_container(table1) multi_container.add_container(table2) with HDF5IO(self.filename, manager=get_manager(), mode='w') as io: io.write(multi_container) with HDF5IO(self.filename, manager=get_manager(), mode='r') as io: read_multi_container = io.read() self.assertContainerEqual(read_multi_container, multi_container, ignore_name=True) # test DTR access read_group1 = read_multi_container['group1'] read_group2 = read_multi_container['group2'] read_table = read_multi_container['table2'] ret = read_table[0, 'electrodes'] expected = pd.DataFrame({'x': np.array([2, 3]), 'y': [read_group1, read_group2]}, index=pd.Index(data=[102, 103], name='id')) pd.testing.assert_frame_equal(ret, expected) class TestVectorIndexDtype(TestCase): def set_up_array_index(self): data = VectorData(name='data', description='desc') index = VectorIndex(name='index', data=np.array([]), target=data) return index def set_up_list_index(self): data = VectorData(name='data', description='desc') index = VectorIndex(name='index', data=[], target=data) return index def test_array_inc_precision(self): index = self.set_up_array_index() index.add_vector(np.empty((255, ))) self.assertEqual(index.data[0], 255) self.assertEqual(index.data.dtype, np.uint8) def test_array_inc_precision_1step(self): index = self.set_up_array_index() index.add_vector(np.empty((65535, ))) 
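# Sketch of the dtype-selection rule these precision tests exercise (an illustrative
# helper, not hdmf's actual implementation): the index keeps its offsets in the smallest
# unsigned integer type that can hold the largest value, upgrading the existing data when
# a newly added vector pushes the offset past the current type's bound.
import numpy as np

def smallest_uint_dtype(max_offset):
    """Return the smallest numpy unsigned integer dtype that can represent max_offset."""
    for dt in (np.uint8, np.uint16, np.uint32, np.uint64):
        if max_offset <= np.iinfo(dt).max:
            return dt
    raise ValueError('offset too large for uint64')

assert smallest_uint_dtype(255) == np.uint8      # fits in uint8
assert smallest_uint_dtype(65535) == np.uint16   # one step up
assert smallest_uint_dtype(65536) == np.uint32   # two steps up from uint8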
self.assertEqual(index.data[0], 65535) self.assertEqual(index.data.dtype, np.uint16) def test_array_inc_precision_2steps(self): index = self.set_up_array_index() index.add_vector(np.empty((65536, ))) self.assertEqual(index.data[0], 65536) self.assertEqual(index.data.dtype, np.uint32) def test_array_prev_data_inc_precision_2steps(self): index = self.set_up_array_index() index.add_vector(np.empty((255, ))) # dtype is still uint8 index.add_vector(np.empty((65536, ))) self.assertEqual(index.data[0], 255) # make sure the 255 is upgraded self.assertEqual(index.data.dtype, np.uint32) def test_list_inc_precision(self): index = self.set_up_list_index() index.add_vector(list(range(255))) self.assertEqual(index.data[0], 255) self.assertEqual(type(index.data[0]), np.uint8) def test_list_inc_precision_1step(self): index = self.set_up_list_index() index.add_vector(list(range(65535))) self.assertEqual(index.data[0], 65535) self.assertEqual(type(index.data[0]), np.uint16) def test_list_inc_precision_2steps(self): index = self.set_up_list_index() index.add_vector(list(range(65536))) self.assertEqual(index.data[0], 65536) self.assertEqual(type(index.data[0]), np.uint32) def test_list_prev_data_inc_precision_2steps(self): index = self.set_up_list_index() index.add_vector(list(range(255))) index.add_vector(list(range(65536 - 255))) self.assertEqual(index.data[0], 255) # make sure the 255 is upgraded self.assertEqual(type(index.data[0]), np.uint32) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1846273 hdmf-3.1.1/tests/unit/spec_tests/0000755000655200065520000000000000000000000017056 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/__init__.py0000644000655200065520000000000000000000000021155 0ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test-ext.base.yaml0000644000655200065520000000061000000000000022425 0ustar00circlecicirclecidatasets: - my_data_type_def: TestExtData my_data_type_inc: TestData doc: An abstract data type for a dataset. groups: - my_data_type_def: TestExtContainer my_data_type_inc: Container doc: An abstract data type for a generic container storing collections of data and metadata. - my_data_type_def: TestExtTable my_data_type_inc: TestTable doc: An abstract data type for a table. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test-ext.namespace.yaml0000644000655200065520000000045500000000000023456 0ustar00circlecicirclecinamespaces: - name: test-ext doc: Test extension namespace author: - Test test contact: - test@test.com full_name: Test extension schema: - namespace: test - doc: This source module contains base data types. source: test-ext.base.yaml title: Base data types version: 0.1.0 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test.base.yaml0000644000655200065520000000057600000000000021642 0ustar00circlecicirclecidatasets: - my_data_type_def: TestData my_data_type_inc: Data doc: An abstract data type for a dataset. groups: - my_data_type_def: TestContainer my_data_type_inc: Container doc: An abstract data type for a generic container storing collections of data and metadata. - my_data_type_def: TestTable my_data_type_inc: DynamicTable doc: An abstract data type for a table. 
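# Illustrative sketch (a hypothetical snippet, not shipped with the package): schema and
# namespace YAML like the fixture files above are usually generated programmatically. A
# rough outline with hdmf's NamespaceBuilder, assuming the standard 'data_type_def' keys
# rather than the customized 'my_data_type_def' keys these fixtures use together with the
# test suite's CustomGroupSpec/CustomDatasetSpec/CustomSpecNamespace classes:
from hdmf.spec import GroupSpec, NamespaceBuilder

ns_builder = NamespaceBuilder(doc='Test namespace', name='test', version='0.1.0')
ns_builder.include_type('Container', namespace='hdmf-common')
ns_builder.add_spec('test.base.yaml',
                    GroupSpec(doc='An abstract data type for a generic container storing '
                                  'collections of data and metadata.',
                              data_type_def='TestContainer',
                              data_type_inc='Container'))
ns_builder.export('test.namespace.yaml')  # writes the namespace file plus test.base.yaml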
././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test.namespace.yaml0000644000655200065520000000053100000000000022653 0ustar00circlecicirclecinamespaces: - name: test doc: Test namespace author: - Test test contact: - test@test.com full_name: Test schema: - namespace: hdmf-common my_data_types: - Data - DynamicTable - Container - doc: This source module contains base data types. source: test.base.yaml title: Base data types version: 0.1.0 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test_attribute_spec.py0000644000655200065520000000731300000000000023510 0ustar00circlecicircleciimport json from hdmf.spec import AttributeSpec, RefSpec from hdmf.testing import TestCase class AttributeSpecTests(TestCase): def test_constructor(self): spec = AttributeSpec('attribute1', 'my first attribute', 'text') self.assertEqual(spec['name'], 'attribute1') self.assertEqual(spec['dtype'], 'text') self.assertEqual(spec['doc'], 'my first attribute') self.assertIsNone(spec.parent) json.dumps(spec) # to ensure there are no circular links def test_invalid_dtype(self): with self.assertRaises(ValueError): AttributeSpec(name='attribute1', doc='my first attribute', dtype='invalid' # <-- Invalid dtype must raise a ValueError ) def test_both_value_and_default_value_set(self): with self.assertRaises(ValueError): AttributeSpec(name='attribute1', doc='my first attribute', dtype='int', value=5, default_value=10 # <-- Default_value and value can't be set at the same time ) def test_colliding_shape_and_dims(self): with self.assertRaises(ValueError): AttributeSpec(name='attribute1', doc='my first attribute', dtype='int', dims=['test'], shape=[None, 2] # <-- Length of shape and dims do not match must raise a ValueError ) def test_default_value(self): spec = AttributeSpec('attribute1', 'my first attribute', 'text', default_value='some text') self.assertEqual(spec['default_value'], 'some text') self.assertEqual(spec.default_value, 'some text') def test_shape(self): shape = [None, 2] spec = AttributeSpec('attribute1', 'my first attribute', 'text', shape=shape) self.assertEqual(spec['shape'], shape) self.assertEqual(spec.shape, shape) def test_dims_without_shape(self): spec = AttributeSpec('attribute1', 'my first attribute', 'text', dims=['test']) self.assertEqual(spec.shape, (None, )) def test_build_spec(self): spec_dict = {'name': 'attribute1', 'doc': 'my first attribute', 'dtype': 'text', 'shape': [None], 'dims': ['dim1'], 'value': ['a', 'b']} ret = AttributeSpec.build_spec(spec_dict) self.assertTrue(isinstance(ret, AttributeSpec)) self.assertDictEqual(ret, spec_dict) def test_build_spec_reftype(self): spec_dict = {'name': 'attribute1', 'doc': 'my first attribute', 'dtype': {'target_type': 'AnotherType', 'reftype': 'object'}} expected = spec_dict.copy() expected['dtype'] = RefSpec(target_type='AnotherType', reftype='object') ret = AttributeSpec.build_spec(spec_dict) self.assertTrue(isinstance(ret, AttributeSpec)) self.assertDictEqual(ret, expected) def test_build_spec_no_doc(self): spec_dict = {'name': 'attribute1', 'dtype': 'text'} msg = "AttributeSpec.__init__: missing argument 'doc'" with self.assertRaisesWith(TypeError, msg): AttributeSpec.build_spec(spec_dict) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test_dataset_spec.py0000644000655200065520000002630200000000000023131 
0ustar00circlecicircleciimport json from hdmf.spec import GroupSpec, DatasetSpec, AttributeSpec, DtypeSpec, RefSpec from hdmf.testing import TestCase class DatasetSpecTests(TestCase): def setUp(self): self.attributes = [ AttributeSpec('attribute1', 'my first attribute', 'text'), AttributeSpec('attribute2', 'my second attribute', 'text') ] def test_constructor(self): spec = DatasetSpec('my first dataset', 'int', name='dataset1', attributes=self.attributes) self.assertEqual(spec['dtype'], 'int') self.assertEqual(spec['name'], 'dataset1') self.assertEqual(spec['doc'], 'my first dataset') self.assertNotIn('linkable', spec) self.assertNotIn('data_type_def', spec) self.assertListEqual(spec['attributes'], self.attributes) self.assertIs(spec, self.attributes[0].parent) self.assertIs(spec, self.attributes[1].parent) json.dumps(spec) def test_constructor_datatype(self): spec = DatasetSpec('my first dataset', 'int', name='dataset1', attributes=self.attributes, linkable=False, data_type_def='EphysData') self.assertEqual(spec['dtype'], 'int') self.assertEqual(spec['name'], 'dataset1') self.assertEqual(spec['doc'], 'my first dataset') self.assertEqual(spec['data_type_def'], 'EphysData') self.assertFalse(spec['linkable']) self.assertListEqual(spec['attributes'], self.attributes) self.assertIs(spec, self.attributes[0].parent) self.assertIs(spec, self.attributes[1].parent) def test_constructor_shape(self): shape = [None, 2] spec = DatasetSpec('my first dataset', 'int', name='dataset1', shape=shape, attributes=self.attributes) self.assertEqual(spec['shape'], shape) self.assertEqual(spec.shape, shape) def test_constructor_invalidate_dtype(self): with self.assertRaises(ValueError): DatasetSpec(doc='my first dataset', dtype='my bad dtype', # <-- Expect AssertionError due to bad type name='dataset1', dims=(None, None), attributes=self.attributes, linkable=False, data_type_def='EphysData') def test_constructor_ref_spec(self): dtype = RefSpec('TimeSeries', 'object') spec = DatasetSpec(doc='my first dataset', dtype=dtype, name='dataset1', dims=(None, None), attributes=self.attributes, linkable=False, data_type_def='EphysData') self.assertDictEqual(spec['dtype'], dtype) def test_datatype_extension(self): base = DatasetSpec('my first dataset', 'int', name='dataset1', attributes=self.attributes, linkable=False, data_type_def='EphysData') attributes = [AttributeSpec('attribute3', 'my first extending attribute', 'float')] ext = DatasetSpec('my first dataset extension', 'int', name='dataset1', attributes=attributes, linkable=False, data_type_inc=base, data_type_def='SpikeData') self.assertDictEqual(ext['attributes'][0], attributes[0]) self.assertDictEqual(ext['attributes'][1], self.attributes[0]) self.assertDictEqual(ext['attributes'][2], self.attributes[1]) ext_attrs = ext.attributes self.assertIs(ext, ext_attrs[0].parent) self.assertIs(ext, ext_attrs[1].parent) self.assertIs(ext, ext_attrs[2].parent) def test_datatype_extension_groupspec(self): '''Test to make sure DatasetSpec catches when a GroupSpec used as data_type_inc''' base = GroupSpec('a fake grop', data_type_def='EphysData') with self.assertRaises(TypeError): DatasetSpec('my first dataset extension', 'int', name='dataset1', data_type_inc=base, data_type_def='SpikeData') def test_constructor_table(self): dtype1 = DtypeSpec('column1', 'the first column', 'int') dtype2 = DtypeSpec('column2', 'the second column', 'float') spec = DatasetSpec('my first table', [dtype1, dtype2], name='table1', attributes=self.attributes) self.assertEqual(spec['dtype'], [dtype1, 
dtype2]) self.assertEqual(spec['name'], 'table1') self.assertEqual(spec['doc'], 'my first table') self.assertNotIn('linkable', spec) self.assertNotIn('data_type_def', spec) self.assertListEqual(spec['attributes'], self.attributes) self.assertIs(spec, self.attributes[0].parent) self.assertIs(spec, self.attributes[1].parent) json.dumps(spec) def test_constructor_invalid_table(self): with self.assertRaises(ValueError): DatasetSpec('my first table', [DtypeSpec('column1', 'the first column', 'int'), {} # <--- Bad compound type spec must raise an error ], name='table1', attributes=self.attributes) def test_constructor_default_value(self): spec = DatasetSpec(doc='test', default_value=5, dtype='int', data_type_def='test') self.assertEqual(spec.default_value, 5) def test_name_with_incompatible_quantity(self): # Check that we raise an error when the quantity allows more than one instance with a fixed name with self.assertRaises(ValueError): DatasetSpec(doc='my first dataset', dtype='int', name='ds1', quantity='zero_or_many') with self.assertRaises(ValueError): DatasetSpec(doc='my first dataset', dtype='int', name='ds1', quantity='one_or_many') def test_name_with_compatible_quantity(self): # Make sure compatible quantity flags pass when name is fixed DatasetSpec(doc='my first dataset', dtype='int', name='ds1', quantity='zero_or_one') DatasetSpec(doc='my first dataset', dtype='int', name='ds1', quantity=1) def test_datatype_table_extension(self): dtype1 = DtypeSpec('column1', 'the first column', 'int') dtype2 = DtypeSpec('column2', 'the second column', 'float') base = DatasetSpec('my first table', [dtype1, dtype2], attributes=self.attributes, data_type_def='SimpleTable') self.assertEqual(base['dtype'], [dtype1, dtype2]) self.assertEqual(base['doc'], 'my first table') dtype3 = DtypeSpec('column3', 'the third column', 'text') ext = DatasetSpec('my first table extension', [dtype3], data_type_inc=base, data_type_def='ExtendedTable') self.assertEqual(ext['dtype'], [dtype1, dtype2, dtype3]) self.assertEqual(ext['doc'], 'my first table extension') def test_datatype_table_extension_higher_precision(self): dtype1 = DtypeSpec('column1', 'the first column', 'int') dtype2 = DtypeSpec('column2', 'the second column', 'float32') base = DatasetSpec('my first table', [dtype1, dtype2], attributes=self.attributes, data_type_def='SimpleTable') self.assertEqual(base['dtype'], [dtype1, dtype2]) self.assertEqual(base['doc'], 'my first table') dtype3 = DtypeSpec('column2', 'the second column, with greater precision', 'float64') ext = DatasetSpec('my first table extension', [dtype3], data_type_inc=base, data_type_def='ExtendedTable') self.assertEqual(ext['dtype'], [dtype1, dtype3]) self.assertEqual(ext['doc'], 'my first table extension') def test_datatype_table_extension_lower_precision(self): dtype1 = DtypeSpec('column1', 'the first column', 'int') dtype2 = DtypeSpec('column2', 'the second column', 'float64') base = DatasetSpec('my first table', [dtype1, dtype2], attributes=self.attributes, data_type_def='SimpleTable') self.assertEqual(base['dtype'], [dtype1, dtype2]) self.assertEqual(base['doc'], 'my first table') dtype3 = DtypeSpec('column2', 'the second column, with greater precision', 'float32') with self.assertRaisesWith(ValueError, 'Cannot extend float64 to float32'): DatasetSpec('my first table extension', [dtype3], data_type_inc=base, data_type_def='ExtendedTable') def test_datatype_table_extension_diff_format(self): dtype1 = DtypeSpec('column1', 'the first column', 'int') dtype2 = DtypeSpec('column2', 'the second 
column', 'float64') base = DatasetSpec('my first table', [dtype1, dtype2], attributes=self.attributes, data_type_def='SimpleTable') self.assertEqual(base['dtype'], [dtype1, dtype2]) self.assertEqual(base['doc'], 'my first table') dtype3 = DtypeSpec('column2', 'the second column, with greater precision', 'int32') with self.assertRaisesWith(ValueError, 'Cannot extend float64 to int32'): DatasetSpec('my first table extension', [dtype3], data_type_inc=base, data_type_def='ExtendedTable') def test_data_type_property_value(self): """Test that the property data_type has the expected value""" test_cases = { ('Foo', 'Bar'): 'Bar', ('Foo', None): 'Foo', (None, 'Bar'): 'Bar', (None, None): None, } for (data_type_inc, data_type_def), data_type in test_cases.items(): with self.subTest(data_type_inc=data_type_inc, data_type_def=data_type_def, data_type=data_type): group = GroupSpec('A group', name='group', data_type_inc=data_type_inc, data_type_def=data_type_def) self.assertEqual(group.data_type, data_type) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test_dtype_spec.py0000644000655200065520000000644000000000000022632 0ustar00circlecicirclecifrom hdmf.spec import DtypeSpec, DtypeHelper, RefSpec from hdmf.testing import TestCase class DtypeSpecHelper(TestCase): def setUp(self): pass def test_recommended_dtypes(self): self.assertListEqual(DtypeHelper.recommended_primary_dtypes, list(DtypeHelper.primary_dtype_synonyms.keys())) def test_valid_primary_dtypes(self): a = set(list(DtypeHelper.primary_dtype_synonyms.keys()) + [vi for v in DtypeHelper.primary_dtype_synonyms.values() for vi in v]) self.assertSetEqual(a, DtypeHelper.valid_primary_dtypes) def test_simplify_cpd_type(self): compound_type = [DtypeSpec('test', 'test field', 'float'), DtypeSpec('test2', 'test field2', 'int')] expected_result = ['float', 'int'] result = DtypeHelper.simplify_cpd_type(compound_type) self.assertListEqual(result, expected_result) def test_simplify_cpd_type_ref(self): compound_type = [DtypeSpec('test', 'test field', 'float'), DtypeSpec('test2', 'test field2', RefSpec(target_type='MyType', reftype='object'))] expected_result = ['float', 'object'] result = DtypeHelper.simplify_cpd_type(compound_type) self.assertListEqual(result, expected_result) def test_check_dtype_ok(self): self.assertEqual('int', DtypeHelper.check_dtype('int')) def test_check_dtype_bad(self): msg = "dtype 'bad dtype' is not a valid primary data type." 
with self.assertRaisesRegex(ValueError, msg): DtypeHelper.check_dtype('bad dtype') def test_check_dtype_ref(self): refspec = RefSpec(target_type='target', reftype='object') self.assertIs(refspec, DtypeHelper.check_dtype(refspec)) class DtypeSpecTests(TestCase): def setUp(self): pass def test_constructor(self): spec = DtypeSpec('column1', 'an example column', 'int') self.assertEqual(spec.doc, 'an example column') self.assertEqual(spec.name, 'column1') self.assertEqual(spec.dtype, 'int') def test_build_spec(self): spec = DtypeSpec.build_spec({'doc': 'an example column', 'name': 'column1', 'dtype': 'int'}) self.assertEqual(spec.doc, 'an example column') self.assertEqual(spec.name, 'column1') self.assertEqual(spec.dtype, 'int') def test_invalid_refspec_dict(self): """Test missing or bad target key for RefSpec.""" msg = "'dtype' must have the key 'target_type'" with self.assertRaisesWith(ValueError, msg): DtypeSpec.assertValidDtype({'no target': 'test', 'reftype': 'object'}) def test_refspec_dtype(self): # just making sure this does not cause an error DtypeSpec('column1', 'an example column', RefSpec('TimeSeries', 'object')) def test_invalid_dtype(self): msg = "dtype 'bad dtype' is not a valid primary data type." with self.assertRaisesRegex(ValueError, msg): DtypeSpec('column1', 'an example column', dtype='bad dtype') def test_is_ref(self): spec = DtypeSpec('column1', 'an example column', RefSpec('TimeSeries', 'object')) self.assertTrue(DtypeSpec.is_ref(spec)) spec = DtypeSpec('column1', 'an example column', 'int') self.assertFalse(DtypeSpec.is_ref(spec)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test_group_spec.py0000644000655200065520000004173200000000000022644 0ustar00circlecicircleciimport json from hdmf.spec import GroupSpec, DatasetSpec, AttributeSpec from hdmf.testing import TestCase class GroupSpecTests(TestCase): def setUp(self): self.attributes = [ AttributeSpec('attribute1', 'my first attribute', 'text'), AttributeSpec('attribute2', 'my second attribute', 'text') ] self.dset1_attributes = [ AttributeSpec('attribute3', 'my third attribute', 'text'), AttributeSpec('attribute4', 'my fourth attribute', 'text') ] self.dset2_attributes = [ AttributeSpec('attribute5', 'my fifth attribute', 'text'), AttributeSpec('attribute6', 'my sixth attribute', 'text') ] self.datasets = [ DatasetSpec('my first dataset', 'int', name='dataset1', attributes=self.dset1_attributes, linkable=True), DatasetSpec('my second dataset', 'int', name='dataset2', attributes=self.dset2_attributes, linkable=True, data_type_def='VoltageArray') ] self.subgroups = [ GroupSpec('A test subgroup', name='subgroup1', linkable=False), GroupSpec('A test subgroup', name='subgroup2', linkable=False) ] self.ndt_attr_spec = AttributeSpec('data_type', 'the data type of this object', 'text', value='EphysData') self.ns_attr_spec = AttributeSpec('namespace', 'the namespace for the data type of this object', 'text', required=False) def test_constructor(self): spec = GroupSpec('A test group', name='root_constructor', groups=self.subgroups, datasets=self.datasets, attributes=self.attributes, linkable=False) self.assertFalse(spec['linkable']) self.assertListEqual(spec['attributes'], self.attributes) self.assertListEqual(spec['datasets'], self.datasets) self.assertNotIn('data_type_def', spec) self.assertIs(spec, self.subgroups[0].parent) self.assertIs(spec, self.subgroups[1].parent) self.assertIs(spec, self.attributes[0].parent) self.assertIs(spec, 
self.attributes[1].parent) self.assertIs(spec, self.datasets[0].parent) self.assertIs(spec, self.datasets[1].parent) json.dumps(spec) def test_constructor_datatype(self): spec = GroupSpec('A test group', name='root_constructor_datatype', datasets=self.datasets, attributes=self.attributes, linkable=False, data_type_def='EphysData') self.assertFalse(spec['linkable']) self.assertListEqual(spec['attributes'], self.attributes) self.assertListEqual(spec['datasets'], self.datasets) self.assertEqual(spec['data_type_def'], 'EphysData') self.assertIs(spec, self.attributes[0].parent) self.assertIs(spec, self.attributes[1].parent) self.assertIs(spec, self.datasets[0].parent) self.assertIs(spec, self.datasets[1].parent) self.assertEqual(spec.data_type_def, 'EphysData') self.assertIsNone(spec.data_type_inc) json.dumps(spec) def test_set_parent_exists(self): GroupSpec('A test group', name='root_constructor', groups=self.subgroups) msg = 'Cannot re-assign parent.' with self.assertRaisesWith(AttributeError, msg): self.subgroups[0].parent = self.subgroups[1] def test_set_dataset(self): spec = GroupSpec('A test group', name='root_test_set_dataset', linkable=False, data_type_def='EphysData') spec.set_dataset(self.datasets[0]) self.assertIs(spec, self.datasets[0].parent) def test_set_group(self): spec = GroupSpec('A test group', name='root_test_set_group', linkable=False, data_type_def='EphysData') spec.set_group(self.subgroups[0]) spec.set_group(self.subgroups[1]) self.assertListEqual(spec['groups'], self.subgroups) self.assertIs(spec, self.subgroups[0].parent) self.assertIs(spec, self.subgroups[1].parent) json.dumps(spec) def test_type_extension(self): spec = GroupSpec('A test group', name='parent_type', datasets=self.datasets, attributes=self.attributes, linkable=False, data_type_def='EphysData') dset1_attributes_ext = [ AttributeSpec('dset1_extra_attribute', 'an extra attribute for the first dataset', 'text') ] ext_datasets = [ DatasetSpec('my first dataset extension', 'int', name='dataset1', attributes=dset1_attributes_ext, linkable=True), ] ext_attributes = [ AttributeSpec('ext_extra_attribute', 'an extra attribute for the group', 'text'), ] ext = GroupSpec('A test group extension', name='child_type', datasets=ext_datasets, attributes=ext_attributes, linkable=False, data_type_inc=spec, data_type_def='SpikeData') ext_dset1 = ext.get_dataset('dataset1') ext_dset1_attrs = ext_dset1.attributes self.assertDictEqual(ext_dset1_attrs[0], dset1_attributes_ext[0]) self.assertDictEqual(ext_dset1_attrs[1], self.dset1_attributes[0]) self.assertDictEqual(ext_dset1_attrs[2], self.dset1_attributes[1]) self.assertEqual(ext.data_type_def, 'SpikeData') self.assertEqual(ext.data_type_inc, 'EphysData') ext_dset2 = ext.get_dataset('dataset2') self.maxDiff = None # this will suffice for now, assertDictEqual doesn't do deep equality checks self.assertEqual(str(ext_dset2), str(self.datasets[1])) self.assertAttributesEqual(ext_dset2, self.datasets[1]) # self.ns_attr_spec ndt_attr_spec = AttributeSpec('data_type', 'the data type of this object', # noqa: F841 'text', value='SpikeData') res_attrs = ext.attributes self.assertDictEqual(res_attrs[0], ext_attributes[0]) self.assertDictEqual(res_attrs[1], self.attributes[0]) self.assertDictEqual(res_attrs[2], self.attributes[1]) # test that inherited specs are tracked appropriate for d in self.datasets: with self.subTest(dataset=d.name): self.assertTrue(ext.is_inherited_spec(d)) self.assertFalse(spec.is_inherited_spec(d)) json.dumps(spec) def assertDatasetsEqual(self, spec1, spec2): 
spec1_dsets = spec1.datasets spec2_dsets = spec2.datasets if len(spec1_dsets) != len(spec2_dsets): raise AssertionError('different number of AttributeSpecs') else: for i in range(len(spec1_dsets)): self.assertAttributesEqual(spec1_dsets[i], spec2_dsets[i]) def assertAttributesEqual(self, spec1, spec2): spec1_attr = spec1.attributes spec2_attr = spec2.attributes if len(spec1_attr) != len(spec2_attr): raise AssertionError('different number of AttributeSpecs') else: for i in range(len(spec1_attr)): self.assertDictEqual(spec1_attr[i], spec2_attr[i]) def test_add_attribute(self): spec = GroupSpec('A test group', name='root_constructor', groups=self.subgroups, datasets=self.datasets, linkable=False) for attrspec in self.attributes: spec.add_attribute(**attrspec) self.assertListEqual(spec['attributes'], self.attributes) self.assertListEqual(spec['datasets'], self.datasets) self.assertNotIn('data_type_def', spec) self.assertIs(spec, self.subgroups[0].parent) self.assertIs(spec, self.subgroups[1].parent) self.assertIs(spec, spec.attributes[0].parent) self.assertIs(spec, spec.attributes[1].parent) self.assertIs(spec, self.datasets[0].parent) self.assertIs(spec, self.datasets[1].parent) json.dumps(spec) def test_update_attribute_spec(self): spec = GroupSpec('A test group', name='root_constructor', attributes=[AttributeSpec('attribute1', 'my first attribute', 'text'), AttributeSpec('attribute2', 'my second attribute', 'text')]) spec.set_attribute(AttributeSpec('attribute2', 'my second attribute', 'int', value=5)) res = spec.get_attribute('attribute2') self.assertEqual(res.value, 5) self.assertEqual(res.dtype, 'int') def test_path(self): GroupSpec('A test group', name='root_constructor', groups=self.subgroups, datasets=self.datasets, attributes=self.attributes, linkable=False) self.assertEqual(self.attributes[0].path, 'root_constructor/attribute1') self.assertEqual(self.datasets[0].path, 'root_constructor/dataset1') self.assertEqual(self.subgroups[0].path, 'root_constructor/subgroup1') def test_path_complicated(self): attribute = AttributeSpec('attribute1', 'my fifth attribute', 'text') dataset = DatasetSpec('my first dataset', 'int', name='dataset1', attributes=[attribute]) subgroup = GroupSpec('A subgroup', name='subgroup1', datasets=[dataset]) self.assertEqual(attribute.path, 'subgroup1/dataset1/attribute1') _ = GroupSpec('A test group', name='root', groups=[subgroup]) self.assertEqual(attribute.path, 'root/subgroup1/dataset1/attribute1') def test_path_no_name(self): attribute = AttributeSpec('attribute1', 'my fifth attribute', 'text') dataset = DatasetSpec('my first dataset', 'int', data_type_inc='DatasetType', attributes=[attribute]) subgroup = GroupSpec('A subgroup', data_type_def='GroupType', datasets=[dataset]) _ = GroupSpec('A test group', name='root', groups=[subgroup]) self.assertEqual(attribute.path, 'root/GroupType/DatasetType/attribute1') def test_data_type_property_value(self): """Test that the property data_type has the expected value""" test_cases = { ('Foo', 'Bar'): 'Bar', ('Foo', None): 'Foo', (None, 'Bar'): 'Bar', (None, None): None, } for (data_type_inc, data_type_def), data_type in test_cases.items(): with self.subTest(data_type_inc=data_type_inc, data_type_def=data_type_def, data_type=data_type): dataset = DatasetSpec('A dataset', 'int', name='dataset', data_type_inc=data_type_inc, data_type_def=data_type_def) self.assertEqual(dataset.data_type, data_type) def test_get_data_type_spec(self): expected = AttributeSpec('data_type', 'the data type of this object', 'text', 
value='MyType') self.assertDictEqual(GroupSpec.get_data_type_spec('MyType'), expected) def test_get_namespace_spec(self): expected = AttributeSpec('namespace', 'the namespace for the data type of this object', 'text', required=False) self.assertDictEqual(GroupSpec.get_namespace_spec(), expected) class TestNotAllowedConfig(TestCase): def test_no_name_no_def_no_inc(self): msg = ("Cannot create Group or Dataset spec with no name without specifying 'data_type_def' " "and/or 'data_type_inc'.") with self.assertRaisesWith(ValueError, msg): GroupSpec('A test group') def test_name_with_multiple(self): msg = ("Cannot give specific name to something that can exist multiple times: name='MyGroup', quantity='*'") with self.assertRaisesWith(ValueError, msg): GroupSpec('A test group', name='MyGroup', quantity='*') class TestResolveAttrs(TestCase): def setUp(self): self.def_group_spec = GroupSpec( doc='A test group', name='root', data_type_def='MyGroup', attributes=[AttributeSpec('attribute1', 'my first attribute', 'text'), AttributeSpec('attribute2', 'my second attribute', 'text')] ) self.inc_group_spec = GroupSpec( doc='A test group', name='root', data_type_inc='MyGroup', attributes=[AttributeSpec('attribute2', 'my second attribute', 'text', value='fixed'), AttributeSpec('attribute3', 'my third attribute', 'text', value='fixed')] ) self.inc_group_spec.resolve_spec(self.def_group_spec) def test_resolved(self): self.assertTupleEqual(self.inc_group_spec.attributes, ( AttributeSpec('attribute2', 'my second attribute', 'text', value='fixed'), AttributeSpec('attribute3', 'my third attribute', 'text', value='fixed'), AttributeSpec('attribute1', 'my first attribute', 'text') )) self.assertEqual(self.inc_group_spec.get_attribute('attribute1'), AttributeSpec('attribute1', 'my first attribute', 'text')) self.assertEqual(self.inc_group_spec.get_attribute('attribute2'), AttributeSpec('attribute2', 'my second attribute', 'text', value='fixed')) self.assertEqual(self.inc_group_spec.get_attribute('attribute3'), AttributeSpec('attribute3', 'my third attribute', 'text', value='fixed')) self.assertTrue(self.inc_group_spec.resolved) def test_is_inherited_spec(self): self.assertFalse(self.def_group_spec.is_inherited_spec('attribute1')) self.assertFalse(self.def_group_spec.is_inherited_spec('attribute2')) self.assertTrue(self.inc_group_spec.is_inherited_spec( AttributeSpec('attribute1', 'my first attribute', 'text') )) self.assertTrue(self.inc_group_spec.is_inherited_spec('attribute1')) self.assertTrue(self.inc_group_spec.is_inherited_spec('attribute2')) self.assertFalse(self.inc_group_spec.is_inherited_spec('attribute3')) self.assertFalse(self.inc_group_spec.is_inherited_spec('attribute4')) def test_is_overridden_spec(self): self.assertFalse(self.def_group_spec.is_overridden_spec('attribute1')) self.assertFalse(self.def_group_spec.is_overridden_spec('attribute2')) self.assertFalse(self.inc_group_spec.is_overridden_spec( AttributeSpec('attribute1', 'my first attribute', 'text') )) self.assertFalse(self.inc_group_spec.is_overridden_spec('attribute1')) self.assertTrue(self.inc_group_spec.is_overridden_spec('attribute2')) self.assertFalse(self.inc_group_spec.is_overridden_spec('attribute3')) self.assertFalse(self.inc_group_spec.is_overridden_spec('attribute4')) def test_is_inherited_attribute(self): self.assertFalse(self.def_group_spec.is_inherited_attribute('attribute1')) self.assertFalse(self.def_group_spec.is_inherited_attribute('attribute2')) self.assertTrue(self.inc_group_spec.is_inherited_attribute('attribute1')) 
self.assertTrue(self.inc_group_spec.is_inherited_attribute('attribute2')) self.assertFalse(self.inc_group_spec.is_inherited_attribute('attribute3')) with self.assertRaisesWith(ValueError, "Attribute 'attribute4' not found"): self.inc_group_spec.is_inherited_attribute('attribute4') def test_is_overridden_attribute(self): self.assertFalse(self.def_group_spec.is_overridden_attribute('attribute1')) self.assertFalse(self.def_group_spec.is_overridden_attribute('attribute2')) self.assertFalse(self.inc_group_spec.is_overridden_attribute('attribute1')) self.assertTrue(self.inc_group_spec.is_overridden_attribute('attribute2')) self.assertFalse(self.inc_group_spec.is_overridden_attribute('attribute3')) with self.assertRaisesWith(ValueError, "Attribute 'attribute4' not found"): self.inc_group_spec.is_overridden_attribute('attribute4') ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test_link_spec.py0000644000655200065520000000426300000000000022443 0ustar00circlecicircleciimport json from hdmf.spec import GroupSpec, LinkSpec from hdmf.testing import TestCase class LinkSpecTests(TestCase): def test_constructor(self): spec = LinkSpec( doc='A test link', target_type='Group1', quantity='+', name='Link1', ) self.assertEqual(spec.doc, 'A test link') self.assertEqual(spec.target_type, 'Group1') self.assertEqual(spec.data_type_inc, 'Group1') self.assertEqual(spec.quantity, '+') self.assertEqual(spec.name, 'Link1') json.dumps(spec) def test_constructor_target_spec_def(self): group_spec_def = GroupSpec( data_type_def='Group1', doc='A test group', ) spec = LinkSpec( doc='A test link', target_type=group_spec_def, ) self.assertEqual(spec.target_type, 'Group1') json.dumps(spec) def test_constructor_target_spec_inc(self): group_spec_inc = GroupSpec( data_type_inc='Group1', doc='A test group', ) msg = "'target_type' must be a string or a GroupSpec or DatasetSpec with a 'data_type_def' key." 
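# --- Illustrative aside (not part of the original hdmf test file) ----------
# A short sketch of LinkSpec usage as exercised by LinkSpecTests: the target
# type may be given as a string or as a Group/Dataset spec that carries a
# 'data_type_def'; 'target_type' is then exposed as that type name, and the
# quantity controls whether the link is required. Names are illustrative.
from hdmf.spec import GroupSpec, LinkSpec

group_def = GroupSpec(doc='A test group', data_type_def='Group1')
link = LinkSpec(doc='A test link', target_type=group_def, quantity='?')
assert link.target_type == 'Group1'
assert link.data_type_inc == 'Group1'
assert link.required is False      # quantity '?' marks the link as optional
# ---------------------------------------------------------------------------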
with self.assertRaisesWith(ValueError, msg): LinkSpec( doc='A test link', target_type=group_spec_inc, ) def test_constructor_defaults(self): spec = LinkSpec( doc='A test link', target_type='Group1', ) self.assertEqual(spec.quantity, 1) self.assertIsNone(spec.name) json.dumps(spec) def test_required_is_many(self): quantity_opts = ['?', 1, '*', '+'] is_required = [False, True, False, True] is_many = [False, False, True, True] for (quantity, req, many) in zip(quantity_opts, is_required, is_many): with self.subTest(quantity=quantity): spec = LinkSpec( doc='A test link', target_type='Group1', quantity=quantity, name='Link1', ) self.assertEqual(spec.required, req) self.assertEqual(spec.is_many(), many) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test_load_namespace.py0000644000655200065520000004033400000000000023426 0ustar00circlecicircleciimport json import os import ruamel.yaml as yaml from tempfile import gettempdir import warnings from hdmf.common import get_type_map from hdmf.spec import AttributeSpec, DatasetSpec, GroupSpec, SpecNamespace, NamespaceCatalog, NamespaceBuilder from hdmf.testing import TestCase, remove_test_file from tests.unit.utils import CustomGroupSpec, CustomDatasetSpec, CustomSpecNamespace class TestSpecLoad(TestCase): NS_NAME = 'test_ns' def setUp(self): self.attributes = [ AttributeSpec('attribute1', 'my first attribute', 'text'), AttributeSpec('attribute2', 'my second attribute', 'text') ] self.dset1_attributes = [ AttributeSpec('attribute3', 'my third attribute', 'text'), AttributeSpec('attribute4', 'my fourth attribute', 'text') ] self.dset2_attributes = [ AttributeSpec('attribute5', 'my fifth attribute', 'text'), AttributeSpec('attribute6', 'my sixth attribute', 'text') ] self.datasets = [ DatasetSpec('my first dataset', 'int', name='dataset1', attributes=self.dset1_attributes, linkable=True), DatasetSpec('my second dataset', 'int', name='dataset2', dims=(None, None), attributes=self.dset2_attributes, linkable=True, data_type_def='VoltageArray') ] self.spec = GroupSpec('A test group', name='root_constructor_datatype', datasets=self.datasets, attributes=self.attributes, linkable=False, data_type_def='EphysData') dset1_attributes_ext = [ AttributeSpec('dset1_extra_attribute', 'an extra attribute for the first dataset', 'text') ] self.ext_datasets = [ DatasetSpec('my first dataset extension', 'int', name='dataset1', attributes=dset1_attributes_ext, linkable=True), ] self.ext_attributes = [ AttributeSpec('ext_extra_attribute', 'an extra attribute for the group', 'text'), ] self.ext_spec = GroupSpec('A test group extension', name='root_constructor_datatype', datasets=self.ext_datasets, attributes=self.ext_attributes, linkable=False, data_type_inc='EphysData', data_type_def='SpikeData') to_dump = {'groups': [self.spec, self.ext_spec]} self.specs_path = 'test_load_namespace.specs.yaml' self.namespace_path = 'test_load_namespace.namespace.yaml' with open(self.specs_path, 'w') as tmp: yaml_obj = yaml.YAML(typ='safe', pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) ns_dict = { 'doc': 'a test namespace', 'name': self.NS_NAME, 'schema': [ {'source': self.specs_path} ], 'version': '0.1.0' } self.namespace = SpecNamespace.build_namespace(**ns_dict) to_dump = {'namespaces': [self.namespace]} with open(self.namespace_path, 'w') as tmp: yaml_obj = yaml.YAML(typ='safe', pure=True) yaml_obj.default_flow_style = False 
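# --- Illustrative aside (not part of the original hdmf test file) ----------
# A condensed sketch of the namespace round trip that TestSpecLoad.setUp
# performs: dump a spec and a namespace description to YAML, then load them
# back through a NamespaceCatalog. The file names 'example.specs.yaml' and
# 'example.namespace.yaml' are placeholders chosen for this sketch.
import json
import ruamel.yaml as yaml
from hdmf.spec import GroupSpec, NamespaceCatalog, SpecNamespace

spec = GroupSpec('A test group', data_type_def='EphysData')
with open('example.specs.yaml', 'w') as f:
    y = yaml.YAML(typ='safe', pure=True)
    y.default_flow_style = False
    y.dump(json.loads(json.dumps({'groups': [spec]})), f)

ns = SpecNamespace.build_namespace(
    doc='a test namespace', name='test_ns',
    schema=[{'source': 'example.specs.yaml'}], version='0.1.0')
with open('example.namespace.yaml', 'w') as f:
    y = yaml.YAML(typ='safe', pure=True)
    y.default_flow_style = False
    y.dump(json.loads(json.dumps({'namespaces': [ns]})), f)

catalog = NamespaceCatalog()
catalog.load_namespaces('example.namespace.yaml', resolve=True)
print(catalog.get_spec('test_ns', 'EphysData'))
# ---------------------------------------------------------------------------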
yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) self.ns_catalog = NamespaceCatalog() def tearDown(self): if os.path.exists(self.namespace_path): os.remove(self.namespace_path) if os.path.exists(self.specs_path): os.remove(self.specs_path) def test_inherited_attributes(self): self.ns_catalog.load_namespaces(self.namespace_path, resolve=True) ts_spec = self.ns_catalog.get_spec(self.NS_NAME, 'EphysData') es_spec = self.ns_catalog.get_spec(self.NS_NAME, 'SpikeData') ts_attrs = {s.name for s in ts_spec.attributes} es_attrs = {s.name for s in es_spec.attributes} for attr in ts_attrs: with self.subTest(attr=attr): self.assertIn(attr, es_attrs) # self.assertSetEqual(ts_attrs, es_attrs) ts_dsets = {s.name for s in ts_spec.datasets} es_dsets = {s.name for s in es_spec.datasets} for dset in ts_dsets: with self.subTest(dset=dset): self.assertIn(dset, es_dsets) # self.assertSetEqual(ts_dsets, es_dsets) def test_inherited_attributes_not_resolved(self): self.ns_catalog.load_namespaces(self.namespace_path, resolve=False) es_spec = self.ns_catalog.get_spec(self.NS_NAME, 'SpikeData') src_attrs = {s.name for s in self.ext_attributes} ext_attrs = {s.name for s in es_spec.attributes} self.assertSetEqual(src_attrs, ext_attrs) src_dsets = {s.name for s in self.ext_datasets} ext_dsets = {s.name for s in es_spec.datasets} self.assertSetEqual(src_dsets, ext_dsets) class TestSpecLoadEdgeCase(TestCase): def setUp(self): self.specs_path = 'test_load_namespace.specs.yaml' self.namespace_path = 'test_load_namespace.namespace.yaml' # write basically empty specs file to_dump = {'groups': []} with open(self.specs_path, 'w') as tmp: yaml_obj = yaml.YAML(typ='safe', pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) def tearDown(self): remove_test_file(self.namespace_path) remove_test_file(self.specs_path) def test_build_namespace_missing_version(self): """Test that building/creating a SpecNamespace without a version works but raises a warning.""" # create namespace without version key ns_dict = { 'doc': 'a test namespace', 'name': 'test_ns', 'schema': [ {'source': self.specs_path} ], } msg = ("Loaded namespace 'test_ns' is missing the required key 'version'. Version will be set to " "'%s'. Please notify the extension author." % SpecNamespace.UNVERSIONED) with self.assertWarnsWith(UserWarning, msg): namespace = SpecNamespace.build_namespace(**ns_dict) self.assertEqual(namespace.version, SpecNamespace.UNVERSIONED) def test_load_namespace_none_version(self): """Test that reading a namespace file without a version works but raises a warning.""" # create namespace with version key (remove it later) ns_dict = { 'doc': 'a test namespace', 'name': 'test_ns', 'schema': [ {'source': self.specs_path} ], 'version': '0.0.1' } namespace = SpecNamespace.build_namespace(**ns_dict) namespace['version'] = None # work around lack of setter to remove version key # write the namespace to file without version key to_dump = {'namespaces': [namespace]} with open(self.namespace_path, 'w') as tmp: yaml_obj = yaml.YAML(typ='safe', pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) # load the namespace from file ns_catalog = NamespaceCatalog() msg = ("Loaded namespace 'test_ns' is missing the required key 'version'. Version will be set to " "'%s'. Please notify the extension author." 
% SpecNamespace.UNVERSIONED) with self.assertWarnsWith(UserWarning, msg): ns_catalog.load_namespaces(self.namespace_path) self.assertEqual(ns_catalog.get_namespace('test_ns').version, SpecNamespace.UNVERSIONED) def test_load_namespace_unversioned_version(self): """Test that reading a namespace file with version=unversioned string works but raises a warning.""" # create namespace with version key (remove it later) ns_dict = { 'doc': 'a test namespace', 'name': 'test_ns', 'schema': [ {'source': self.specs_path} ], 'version': '0.0.1' } namespace = SpecNamespace.build_namespace(**ns_dict) namespace['version'] = str(SpecNamespace.UNVERSIONED) # work around lack of setter to remove version key # write the namespace to file without version key to_dump = {'namespaces': [namespace]} with open(self.namespace_path, 'w') as tmp: yaml_obj = yaml.YAML(typ='safe', pure=True) yaml_obj.default_flow_style = False yaml_obj.dump(json.loads(json.dumps(to_dump)), tmp) # load the namespace from file ns_catalog = NamespaceCatalog() msg = "Loaded namespace 'test_ns' is unversioned. Please notify the extension author." with self.assertWarnsWith(UserWarning, msg): ns_catalog.load_namespaces(self.namespace_path) self.assertEqual(ns_catalog.get_namespace('test_ns').version, SpecNamespace.UNVERSIONED) def test_missing_version_string(self): """Test that the constant variable representing a missing version has not changed.""" self.assertIsNone(SpecNamespace.UNVERSIONED) def test_get_namespace_missing_version(self): """Test that SpecNamespace.version returns the constant for a missing version if version gets removed.""" # create namespace with version key (remove it later) ns_dict = { 'doc': 'a test namespace', 'name': 'test_ns', 'schema': [ {'source': self.specs_path} ], 'version': '0.0.1' } namespace = SpecNamespace.build_namespace(**ns_dict) namespace['version'] = None # work around lack of setter to remove version key self.assertEqual(namespace.version, SpecNamespace.UNVERSIONED) class TestCatchDupNS(TestCase): def setUp(self): self.tempdir = gettempdir() self.ext_source1 = 'extension1.yaml' self.ns_path1 = 'namespace1.yaml' self.ext_source2 = 'extension2.yaml' self.ns_path2 = 'namespace2.yaml' def tearDown(self): for f in (self.ext_source1, self.ns_path1, self.ext_source2, self.ns_path2): remove_test_file(os.path.join(self.tempdir, f)) def test_catch_dup_name(self): ns_builder1 = NamespaceBuilder('Extension doc', "test_ext", version='0.1.0') ns_builder1.add_spec(self.ext_source1, GroupSpec('doc', data_type_def='MyType')) ns_builder1.export(self.ns_path1, outdir=self.tempdir) ns_builder2 = NamespaceBuilder('Extension doc', "test_ext", version='0.2.0') ns_builder2.add_spec(self.ext_source2, GroupSpec('doc', data_type_def='MyType')) ns_builder2.export(self.ns_path2, outdir=self.tempdir) ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(os.path.join(self.tempdir, self.ns_path1)) msg = "Ignoring cached namespace 'test_ext' version 0.2.0 because version 0.1.0 is already loaded." 
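# --- Illustrative aside (not part of the original hdmf test file) ----------
# A minimal sketch of the unversioned-namespace fallback checked by the
# edge-case tests above: building a SpecNamespace without a 'version' key
# emits a UserWarning and sets the version to SpecNamespace.UNVERSIONED
# (None). The schema source path below is a placeholder for this sketch.
from hdmf.spec import SpecNamespace

ns = SpecNamespace.build_namespace(          # warns: missing required key 'version'
    doc='a test namespace', name='test_ns',
    schema=[{'source': 'example.specs.yaml'}])
assert ns.version is SpecNamespace.UNVERSIONED
# ---------------------------------------------------------------------------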
with self.assertWarnsRegex(UserWarning, msg): ns_catalog.load_namespaces(os.path.join(self.tempdir, self.ns_path2)) def test_catch_dup_name_same_version(self): ns_builder1 = NamespaceBuilder('Extension doc', "test_ext", version='0.1.0') ns_builder1.add_spec(self.ext_source1, GroupSpec('doc', data_type_def='MyType')) ns_builder1.export(self.ns_path1, outdir=self.tempdir) ns_builder2 = NamespaceBuilder('Extension doc', "test_ext", version='0.1.0') ns_builder2.add_spec(self.ext_source2, GroupSpec('doc', data_type_def='MyType')) ns_builder2.export(self.ns_path2, outdir=self.tempdir) ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(os.path.join(self.tempdir, self.ns_path1)) # no warning should be raised (but don't just check for 0 warnings -- warnings can come from other sources) msg = "Ignoring cached namespace 'test_ext' version 0.1.0 because version 0.1.0 is already loaded." with warnings.catch_warnings(record=True) as ws: ns_catalog.load_namespaces(os.path.join(self.tempdir, self.ns_path2)) for w in ws: self.assertTrue(str(w) != msg) class TestCustomSpecClasses(TestCase): def setUp(self): # noqa: C901 self.ns_catalog = NamespaceCatalog(CustomGroupSpec, CustomDatasetSpec, CustomSpecNamespace) hdmf_typemap = get_type_map() self.ns_catalog.merge(hdmf_typemap.namespace_catalog) def test_constructor_getters(self): self.assertEqual(self.ns_catalog.dataset_spec_cls, CustomDatasetSpec) self.assertEqual(self.ns_catalog.group_spec_cls, CustomGroupSpec) self.assertEqual(self.ns_catalog.spec_namespace_cls, CustomSpecNamespace) def test_load_namespaces(self): namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test.namespace.yaml') namespace_deps = self.ns_catalog.load_namespaces(namespace_path) # test that the dependencies are correct, including dependencies of the dependencies expected = set(['Data', 'Container', 'DynamicTable', 'ElementIdentifiers', 'VectorData']) self.assertSetEqual(set(namespace_deps['test']['hdmf-common']), expected) # test that the types are loaded types = self.ns_catalog.get_types('test.base.yaml') expected = ('TestData', 'TestContainer', 'TestTable') self.assertTupleEqual(types, expected) # test that the namespace is correct and the types_key is updated for test ns test_namespace = self.ns_catalog.get_namespace('test') expected = {'doc': 'Test namespace', 'schema': [{'namespace': 'hdmf-common', 'my_data_types': ['Data', 'DynamicTable', 'Container']}, {'doc': 'This source module contains base data types.', 'source': 'test.base.yaml', 'title': 'Base data types'}], 'name': 'test', 'full_name': 'Test', 'version': '0.1.0', 'author': ['Test test'], 'contact': ['test@test.com']} self.assertDictEqual(test_namespace, expected) # test that the def_key is updated for test ns test_data_spec = self.ns_catalog.get_spec('test', 'TestData') self.assertTrue('my_data_type_def' in test_data_spec) self.assertTrue('my_data_type_inc' in test_data_spec) # test that the def_key is maintained for hdmf-common data_spec = self.ns_catalog.get_spec('hdmf-common', 'Data') self.assertTrue('data_type_def' in data_spec) def test_load_namespaces_ext(self): namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test.namespace.yaml') self.ns_catalog.load_namespaces(namespace_path) ext_namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test-ext.namespace.yaml') ext_namespace_deps = self.ns_catalog.load_namespaces(ext_namespace_path) # test that the dependencies are correct, including dependencies of the dependencies expected_deps = 
set(['TestData', 'TestContainer', 'TestTable', 'Container', 'Data', 'DynamicTable', 'ElementIdentifiers', 'VectorData']) self.assertSetEqual(set(ext_namespace_deps['test-ext']['test']), expected_deps) def test_load_namespaces_bad_path(self): namespace_path = 'test.namespace.yaml' msg = "namespace file 'test.namespace.yaml' not found" with self.assertRaisesWith(IOError, msg): self.ns_catalog.load_namespaces(namespace_path) def test_load_namespaces_twice(self): namespace_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'test.namespace.yaml') namespace_deps1 = self.ns_catalog.load_namespaces(namespace_path) namespace_deps2 = self.ns_catalog.load_namespaces(namespace_path) self.assertDictEqual(namespace_deps1, namespace_deps2) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test_ref_spec.py0000644000655200065520000000127600000000000022263 0ustar00circlecicircleciimport json from hdmf.spec import RefSpec from hdmf.testing import TestCase class RefSpecTests(TestCase): def test_constructor(self): spec = RefSpec('TimeSeries', 'object') self.assertEqual(spec.target_type, 'TimeSeries') self.assertEqual(spec.reftype, 'object') json.dumps(spec) # to ensure there are no circular links def test_wrong_reference_type(self): with self.assertRaises(ValueError): RefSpec('TimeSeries', 'unknownreftype') def test_isregion(self): spec = RefSpec('TimeSeries', 'object') self.assertFalse(spec.is_region()) spec = RefSpec('Data', 'region') self.assertTrue(spec.is_region()) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test_spec_catalog.py0000644000655200065520000002421200000000000023114 0ustar00circlecicircleciimport copy from hdmf.spec import GroupSpec, DatasetSpec, AttributeSpec, SpecCatalog from hdmf.testing import TestCase class SpecCatalogTest(TestCase): def setUp(self): self.catalog = SpecCatalog() self.attributes = [ AttributeSpec('attribute1', 'my first attribute', 'text'), AttributeSpec('attribute2', 'my second attribute', 'text') ] self.spec = DatasetSpec('my first dataset', 'int', name='dataset1', dims=(None, None), attributes=self.attributes, linkable=False, data_type_def='EphysData') def test_register_spec(self): self.catalog.register_spec(self.spec, 'test.yaml') result = self.catalog.get_spec('EphysData') self.assertIs(result, self.spec) def test_hierarchy(self): spikes_spec = DatasetSpec('my extending dataset', 'int', data_type_inc='EphysData', data_type_def='SpikeData') lfp_spec = DatasetSpec('my second extending dataset', 'int', data_type_inc='EphysData', data_type_def='LFPData') self.catalog.register_spec(self.spec, 'test.yaml') self.catalog.register_spec(spikes_spec, 'test.yaml') self.catalog.register_spec(lfp_spec, 'test.yaml') spike_hierarchy = self.catalog.get_hierarchy('SpikeData') lfp_hierarchy = self.catalog.get_hierarchy('LFPData') ephys_hierarchy = self.catalog.get_hierarchy('EphysData') self.assertTupleEqual(spike_hierarchy, ('SpikeData', 'EphysData')) self.assertTupleEqual(lfp_hierarchy, ('LFPData', 'EphysData')) self.assertTupleEqual(ephys_hierarchy, ('EphysData',)) def test_subtypes(self): """ -BaseContainer--+-->AContainer--->ADContainer | +-->BContainer """ base_spec = GroupSpec(doc='Base container', data_type_def='BaseContainer') acontainer = GroupSpec(doc='AContainer', data_type_inc='BaseContainer', data_type_def='AContainer') adcontainer = GroupSpec(doc='ADContainer', data_type_inc='AContainer', 
data_type_def='ADContainer') bcontainer = GroupSpec(doc='BContainer', data_type_inc='BaseContainer', data_type_def='BContainer') self.catalog.register_spec(base_spec, 'test.yaml') self.catalog.register_spec(acontainer, 'test.yaml') self.catalog.register_spec(adcontainer, 'test.yaml') self.catalog.register_spec(bcontainer, 'test.yaml') base_spec_subtypes = self.catalog.get_subtypes('BaseContainer') base_spec_subtypes = tuple(sorted(base_spec_subtypes)) # Sort so we have a guaranteed order for comparison acontainer_subtypes = self.catalog.get_subtypes('AContainer') bcontainer_substypes = self.catalog.get_subtypes('BContainer') adcontainer_subtypes = self.catalog.get_subtypes('ADContainer') self.assertTupleEqual(adcontainer_subtypes, ()) self.assertTupleEqual(bcontainer_substypes, ()) self.assertTupleEqual(acontainer_subtypes, ('ADContainer',)) self.assertTupleEqual(base_spec_subtypes, ('AContainer', 'ADContainer', 'BContainer')) def test_subtypes_norecursion(self): """ -BaseContainer--+-->AContainer--->ADContainer | +-->BContainer """ base_spec = GroupSpec(doc='Base container', data_type_def='BaseContainer') acontainer = GroupSpec(doc='AContainer', data_type_inc='BaseContainer', data_type_def='AContainer') adcontainer = GroupSpec(doc='ADContainer', data_type_inc='AContainer', data_type_def='ADContainer') bcontainer = GroupSpec(doc='BContainer', data_type_inc='BaseContainer', data_type_def='BContainer') self.catalog.register_spec(base_spec, 'test.yaml') self.catalog.register_spec(acontainer, 'test.yaml') self.catalog.register_spec(adcontainer, 'test.yaml') self.catalog.register_spec(bcontainer, 'test.yaml') base_spec_subtypes = self.catalog.get_subtypes('BaseContainer', recursive=False) base_spec_subtypes = tuple(sorted(base_spec_subtypes)) # Sort so we have a guaranteed order for comparison acontainer_subtypes = self.catalog.get_subtypes('AContainer', recursive=False) bcontainer_substypes = self.catalog.get_subtypes('BContainer', recursive=False) adcontainer_subtypes = self.catalog.get_subtypes('ADContainer', recursive=False) self.assertTupleEqual(adcontainer_subtypes, ()) self.assertTupleEqual(bcontainer_substypes, ()) self.assertTupleEqual(acontainer_subtypes, ('ADContainer',)) self.assertTupleEqual(base_spec_subtypes, ('AContainer', 'BContainer')) def test_subtypes_unknown_type(self): subtypes_of_bad_type = self.catalog.get_subtypes('UnknownType') self.assertTupleEqual(subtypes_of_bad_type, ()) def test_get_spec_source_file(self): spikes_spec = GroupSpec('test group', data_type_def='SpikeData') source_file_path = '/test/myt/test.yaml' self.catalog.auto_register(spikes_spec, source_file_path) recorded_source_file_path = self.catalog.get_spec_source_file('SpikeData') self.assertEqual(recorded_source_file_path, source_file_path) def test_get_full_hierarchy(self): """ BaseContainer--+-->AContainer--->ADContainer | +-->BContainer Expected output: >> print(json.dumps(full_hierarchy, indent=4)) >> { >> "BaseContainer": { >> "AContainer": { >> "ADContainer": {} >> }, >> "BContainer": {} >> } """ base_spec = GroupSpec(doc='Base container', data_type_def='BaseContainer') acontainer = GroupSpec(doc='AContainer', data_type_inc='BaseContainer', data_type_def='AContainer') adcontainer = GroupSpec(doc='ADContainer', data_type_inc='AContainer', data_type_def='ADContainer') bcontainer = GroupSpec(doc='BContainer', data_type_inc='BaseContainer', data_type_def='BContainer') self.catalog.register_spec(base_spec, 'test.yaml') self.catalog.register_spec(acontainer, 'test.yaml') 
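# --- Illustrative aside (not part of the original hdmf test file) ----------
# A compact sketch of the SpecCatalog type-hierarchy behaviour exercised by
# SpecCatalogTest: registered specs can be queried for their inheritance
# chain and for their subtypes. Type names are illustrative.
from hdmf.spec import GroupSpec, SpecCatalog

catalog = SpecCatalog()
base = GroupSpec(doc='Base container', data_type_def='BaseContainer')
child = GroupSpec(doc='AContainer', data_type_inc='BaseContainer',
                  data_type_def='AContainer')
catalog.register_spec(base, 'test.yaml')
catalog.register_spec(child, 'test.yaml')
assert catalog.get_hierarchy('AContainer') == ('AContainer', 'BaseContainer')
assert catalog.get_subtypes('BaseContainer') == ('AContainer',)
# ---------------------------------------------------------------------------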
self.catalog.register_spec(adcontainer, 'test.yaml') self.catalog.register_spec(bcontainer, 'test.yaml') full_hierarchy = self.catalog.get_full_hierarchy() expected_hierarchy = { "BaseContainer": { "AContainer": { "ADContainer": {} }, "BContainer": {} } } self.assertDictEqual(full_hierarchy, expected_hierarchy) def test_copy_spec_catalog(self): # Register the spec first self.catalog.register_spec(self.spec, 'test.yaml') result = self.catalog.get_spec('EphysData') self.assertIs(result, self.spec) # Now test the copy re = copy.copy(self.catalog) self.assertTupleEqual(self.catalog.get_registered_types(), re.get_registered_types()) def test_deepcopy_spec_catalog(self): # Register the spec first self.catalog.register_spec(self.spec, 'test.yaml') result = self.catalog.get_spec('EphysData') self.assertIs(result, self.spec) # Now test the copy re = copy.deepcopy(self.catalog) self.assertTupleEqual(self.catalog.get_registered_types(), re.get_registered_types()) def test_catch_duplicate_spec_nested(self): spec1 = GroupSpec( data_type_def='Group1', doc='This is my new group 1', ) spec2 = GroupSpec( data_type_def='Group2', doc='This is my new group 2', groups=[spec1], # nested definition ) source = 'test_extension.yaml' self.catalog.register_spec(spec1, source) self.catalog.register_spec(spec2, source) # this is OK because Group1 is the same spec ret = self.catalog.get_spec('Group1') self.assertIs(ret, spec1) def test_catch_duplicate_spec_different(self): spec1 = GroupSpec( data_type_def='Group1', doc='This is my new group 1', ) spec2 = GroupSpec( data_type_def='Group1', doc='This is my other group 1', ) source = 'test_extension.yaml' self.catalog.register_spec(spec1, source) msg = "'Group1' - cannot overwrite existing specification" with self.assertRaisesWith(ValueError, msg): self.catalog.register_spec(spec2, source) def test_catch_duplicate_spec_different_source(self): spec1 = GroupSpec( data_type_def='Group1', doc='This is my new group 1', ) spec2 = GroupSpec( data_type_def='Group1', doc='This is my new group 1', ) source1 = 'test_extension1.yaml' source2 = 'test_extension2.yaml' self.catalog.register_spec(spec1, source1) msg = "'Group1' - cannot overwrite existing specification" with self.assertRaisesWith(ValueError, msg): self.catalog.register_spec(spec2, source2) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/spec_tests/test_spec_write.py0000644000655200065520000001772100000000000022643 0ustar00circlecicircleciimport datetime import os from hdmf.spec.namespace import SpecNamespace, NamespaceCatalog from hdmf.spec.spec import GroupSpec from hdmf.spec.write import NamespaceBuilder, YAMLSpecWriter, export_spec from hdmf.testing import TestCase class TestSpec(TestCase): def setUp(self): # create a builder for the namespace self.ns_name = "mylab" self.date = datetime.datetime.now() self.ns_builder = NamespaceBuilder(doc="mydoc", name=self.ns_name, full_name="My Laboratory", version="0.0.1", author="foo", contact="foo@bar.com", namespace_cls=SpecNamespace, date=self.date) # create extensions ext1 = GroupSpec('A custom DataSeries interface', attributes=[], datasets=[], groups=[], data_type_inc=None, data_type_def='MyDataSeries') ext2 = GroupSpec('An extension of a DataSeries interface', attributes=[], datasets=[], groups=[], data_type_inc='MyDataSeries', data_type_def='MyExtendedMyDataSeries') ext2.add_dataset(doc='test', dtype='float', name='testdata') self.data_types = [ext1, ext2] # add the extension self.ext_source_path = 
'mylab.extensions.yaml' self.namespace_path = 'mylab.namespace.yaml' def _test_extensions_file(self): with open(self.ext_source_path, 'r') as file: match_str = \ """groups: - data_type_def: MyDataSeries doc: A custom DataSeries interface - data_type_def: MyExtendedMyDataSeries data_type_inc: MyDataSeries doc: An extension of a DataSeries interface datasets: - name: testdata dtype: float doc: test """ # noqa: E122 nsstr = file.read() self.assertEqual(nsstr, match_str) def _test_namespace_file(self): with open(self.namespace_path, 'r') as file: match_str = \ """namespaces: - author: foo contact: foo@bar.com date: '%s' doc: mydoc full_name: My Laboratory name: mylab schema: - doc: Extensions for my lab source: mylab.extensions.yaml title: Extensions for my lab version: 0.0.1 """ % self.date.isoformat() # noqa: E122 nsstr = file.read() self.assertEqual(nsstr, match_str) class TestNamespaceBuilder(TestSpec): NS_NAME = 'test_ns' def setUp(self): super().setUp() for data_type in self.data_types: self.ns_builder.add_spec(source=self.ext_source_path, spec=data_type) self.ns_builder.add_source(source=self.ext_source_path, doc='Extensions for my lab', title='My lab extensions') self.ns_builder.export(self.namespace_path) def tearDown(self): if os.path.exists(self.ext_source_path): os.remove(self.ext_source_path) if os.path.exists(self.namespace_path): os.remove(self.namespace_path) def test_export_namespace(self): self._test_namespace_file() self._test_extensions_file() def test_read_namespace(self): ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(self.namespace_path, resolve=True) loaded_ns = ns_catalog.get_namespace(self.ns_name) self.assertEqual(loaded_ns.doc, "mydoc") self.assertEqual(loaded_ns.author, "foo") self.assertEqual(loaded_ns.contact, "foo@bar.com") self.assertEqual(loaded_ns.full_name, "My Laboratory") self.assertEqual(loaded_ns.name, "mylab") self.assertEqual(loaded_ns.date, self.date.isoformat()) self.assertDictEqual(loaded_ns.schema[0], {'doc': 'Extensions for my lab', 'source': 'mylab.extensions.yaml', 'title': 'Extensions for my lab'}) self.assertEqual(loaded_ns.version, "0.0.1") def test_get_source_files(self): ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(self.namespace_path, resolve=True) loaded_ns = ns_catalog.get_namespace(self.ns_name) self.assertListEqual(loaded_ns.get_source_files(), ['mylab.extensions.yaml']) def test_get_source_description(self): ns_catalog = NamespaceCatalog() ns_catalog.load_namespaces(self.namespace_path, resolve=True) loaded_ns = ns_catalog.get_namespace(self.ns_name) descr = loaded_ns.get_source_description('mylab.extensions.yaml') self.assertDictEqual(descr, {'doc': 'Extensions for my lab', 'source': 'mylab.extensions.yaml', 'title': 'Extensions for my lab'}) def test_missing_version(self): """Test that creating a namespace builder without a version raises an error.""" msg = "Namespace '%s' missing key 'version'. Please specify a version for the extension." 
% self.ns_name with self.assertRaisesWith(ValueError, msg): self.ns_builder = NamespaceBuilder(doc="mydoc", name=self.ns_name, full_name="My Laboratory", author="foo", contact="foo@bar.com", namespace_cls=SpecNamespace, date=self.date) class TestYAMLSpecWrite(TestSpec): def setUp(self): super().setUp() for data_type in self.data_types: self.ns_builder.add_spec(source=self.ext_source_path, spec=data_type) self.ns_builder.add_source(source=self.ext_source_path, doc='Extensions for my lab', title='My lab extensions') def tearDown(self): if os.path.exists(self.ext_source_path): os.remove(self.ext_source_path) if os.path.exists(self.namespace_path): os.remove(self.namespace_path) def test_init(self): temp = YAMLSpecWriter('.') self.assertEqual(temp._YAMLSpecWriter__outdir, '.') def test_write_namespace(self): temp = YAMLSpecWriter() self.ns_builder.export(self.namespace_path, writer=temp) self._test_namespace_file() self._test_extensions_file() def test_get_name(self): self.assertEqual(self.ns_name, self.ns_builder.name) class TestExportSpec(TestSpec): def test_export(self): """Test that export_spec writes the correct files.""" export_spec(self.ns_builder, self.data_types, '.') self._test_namespace_file() self._test_extensions_file() def tearDown(self): if os.path.exists(self.ext_source_path): os.remove(self.ext_source_path) if os.path.exists(self.namespace_path): os.remove(self.namespace_path) def _test_namespace_file(self): with open(self.namespace_path, 'r') as file: match_str = \ """namespaces: - author: foo contact: foo@bar.com date: '%s' doc: mydoc full_name: My Laboratory name: mylab schema: - source: mylab.extensions.yaml version: 0.0.1 """ % self.date.isoformat() # noqa: E122 nsstr = file.read() self.assertEqual(nsstr, match_str) def test_missing_data_types(self): """Test that calling export_spec on a namespace builder without data types raises a warning.""" with self.assertWarnsWith(UserWarning, 'No data types specified. 
Exiting.'): export_spec(self.ns_builder, [], '.') ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/test_container.py0000644000655200065520000005075700000000000020313 0ustar00circlecicircleciimport numpy as np from hdmf.container import AbstractContainer, Container, Data from hdmf.testing import TestCase from hdmf.utils import docval class Subcontainer(Container): pass class TestContainer(TestCase): def test_constructor(self): """Test that constructor properly sets parent and both child and parent have an object_id """ parent_obj = Container('obj1') child_obj = Container.__new__(Container, parent=parent_obj) self.assertIs(child_obj.parent, parent_obj) self.assertIs(parent_obj.children[0], child_obj) self.assertIsNotNone(parent_obj.object_id) self.assertIsNotNone(child_obj.object_id) def test_constructor_object_id_none(self): """Test that setting object_id to None in __new__ is OK and the object ID is set on get """ parent_obj = Container('obj1') child_obj = Container.__new__(Container, parent=parent_obj, object_id=None) self.assertIsNotNone(child_obj.object_id) def test_set_parent(self): """Test that parent setter properly sets parent """ parent_obj = Container('obj1') child_obj = Container('obj2') child_obj.parent = parent_obj self.assertIs(child_obj.parent, parent_obj) self.assertIs(parent_obj.children[0], child_obj) def test_set_parent_overwrite(self): """Test that parent setter properly blocks overwriting """ parent_obj = Container('obj1') child_obj = Container('obj2') child_obj.parent = parent_obj self.assertIs(parent_obj.children[0], child_obj) another_obj = Container('obj3') with self.assertRaisesWith(ValueError, 'Cannot reassign parent to Container: %s. Parent is already: %s.' % (repr(child_obj), repr(child_obj.parent))): child_obj.parent = another_obj self.assertIs(child_obj.parent, parent_obj) self.assertIs(parent_obj.children[0], child_obj) def test_set_parent_overwrite_proxy(self): """Test that parent setter properly blocks overwriting with proxy/object """ child_obj = Container('obj2') child_obj.parent = object() with self.assertRaisesRegex(ValueError, r"Got None for parent of '[^/]+' - cannot overwrite Proxy with NoneType"): child_obj.parent = None def test_slash_restriction(self): self.assertRaises(ValueError, Container, 'bad/name') def test_set_modified_parent(self): """Test that set modified properly sets parent modified """ parent_obj = Container('obj1') child_obj = Container('obj2') child_obj.parent = parent_obj parent_obj.set_modified(False) child_obj.set_modified(False) self.assertFalse(child_obj.parent.modified) child_obj.set_modified() self.assertTrue(child_obj.parent.modified) def test_add_child(self): """Test that add child creates deprecation warning and also properly sets child's parent and modified """ parent_obj = Container('obj1') child_obj = Container('obj2') parent_obj.set_modified(False) with self.assertWarnsWith(DeprecationWarning, 'add_child is deprecated. 
Set the parent attribute instead.'): parent_obj.add_child(child_obj) self.assertIs(child_obj.parent, parent_obj) self.assertTrue(parent_obj.modified) self.assertIs(parent_obj.children[0], child_obj) def test_set_parent_exists(self): """Test that setting a parent a second time does nothing """ parent_obj = Container('obj1') child_obj = Container('obj2') child_obj3 = Container('obj3') child_obj.parent = parent_obj child_obj.parent = parent_obj child_obj3.parent = parent_obj self.assertEqual(len(parent_obj.children), 2) self.assertIs(parent_obj.children[0], child_obj) self.assertIs(parent_obj.children[1], child_obj3) def test_reassign_container_source(self): """Test that reassign container source throws error """ parent_obj = Container('obj1') parent_obj.container_source = 'a source' with self.assertRaisesWith(Exception, 'cannot reassign container_source'): parent_obj.container_source = 'some other source' def test_repr(self): parent_obj = Container('obj1') self.assertRegex(str(parent_obj), r"obj1 hdmf.container.Container at 0x\d+") def test_type_hierarchy(self): self.assertEqual(Container.type_hierarchy(), (Container, AbstractContainer, object)) self.assertEqual(Subcontainer.type_hierarchy(), (Subcontainer, Container, AbstractContainer, object)) def test_generate_new_id_parent(self): """Test that generate_new_id sets a new ID on the container and its children and sets modified on all.""" parent_obj = Container('obj1') child_obj = Container('obj2') child_obj.parent = parent_obj old_parent_id = parent_obj.object_id old_child_id = child_obj.object_id parent_obj.set_modified(False) child_obj.set_modified(False) parent_obj.generate_new_id() self.assertNotEqual(old_parent_id, parent_obj.object_id) self.assertNotEqual(old_child_id, child_obj.object_id) self.assertTrue(parent_obj.modified) self.assertTrue(child_obj.modified) def test_generate_new_id_child(self): """Test that generate_new_id sets a new ID on the container and not its parent and sets modified on both.""" parent_obj = Container('obj1') child_obj = Container('obj2') child_obj.parent = parent_obj old_parent_id = parent_obj.object_id old_child_id = child_obj.object_id parent_obj.set_modified(False) child_obj.set_modified(False) child_obj.generate_new_id() self.assertEqual(old_parent_id, parent_obj.object_id) self.assertNotEqual(old_child_id, child_obj.object_id) self.assertTrue(parent_obj.modified) self.assertTrue(child_obj.modified) def test_generate_new_id_parent_no_recurse(self): """Test that generate_new_id(recurse=False) sets a new ID on the container and not its children.""" parent_obj = Container('obj1') child_obj = Container('obj2') child_obj.parent = parent_obj old_parent_id = parent_obj.object_id old_child_id = child_obj.object_id parent_obj.set_modified(False) child_obj.set_modified(False) parent_obj.generate_new_id(recurse=False) self.assertNotEqual(old_parent_id, parent_obj.object_id) self.assertEqual(old_child_id, child_obj.object_id) self.assertTrue(parent_obj.modified) self.assertFalse(child_obj.modified) def test_remove_child(self): """Test that removing a child removes only the child. """ parent_obj = Container('obj1') child_obj = Container('obj2') child_obj3 = Container('obj3') child_obj.parent = parent_obj child_obj3.parent = parent_obj parent_obj._remove_child(child_obj) self.assertTupleEqual(parent_obj.children, (child_obj3, )) self.assertTrue(parent_obj.modified) self.assertTrue(child_obj.modified) def test_remove_child_noncontainer(self): """Test that removing a non-Container child raises an error. 
""" msg = "Cannot remove non-AbstractContainer object from children." with self.assertRaisesWith(ValueError, msg): Container('obj1')._remove_child(object()) def test_remove_child_nonchild(self): """Test that removing a non-Container child raises an error. """ msg = "Container 'dummy' is not a child of Container 'obj1'." with self.assertRaisesWith(ValueError, msg): Container('obj1')._remove_child(Container('dummy')) class TestData(TestCase): def test_constructor_scalar(self): """Test that constructor works correctly on scalar data """ data_obj = Data('my_data', 'foobar') self.assertEqual(data_obj.data, 'foobar') def test_bool_true(self): """Test that __bool__ method works correctly on data with len """ data_obj = Data('my_data', [1, 2, 3, 4, 5]) self.assertTrue(data_obj) def test_bool_false(self): """Test that __bool__ method works correctly on empty data """ data_obj = Data('my_data', []) self.assertFalse(data_obj) def test_shape_nparray(self): """ Test that shape works for np.array """ data_obj = Data('my_data', np.arange(10).reshape(2, 5)) self.assertTupleEqual(data_obj.shape, (2, 5)) def test_shape_list(self): """ Test that shape works for np.array """ data_obj = Data('my_data', [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]) self.assertTupleEqual(data_obj.shape, (2, 5)) class TestAbstractContainerFieldsConf(TestCase): def test_bad_fields_type(self): msg = "'__fields__' must be of type tuple" with self.assertRaisesWith(TypeError, msg): class BadFieldsType(AbstractContainer): __fields__ = {'name': 'field1'} def test_bad_field_conf_key(self): msg = "Unrecognized key 'child' in __fields__ config 'field1' on BadFieldConfKey" with self.assertRaisesWith(ValueError, msg): class BadFieldConfKey(AbstractContainer): __fields__ = ({'name': 'field1', 'child': True}, ) def test_bad_field_missing_name(self): msg = "must specify 'name' if using dict in __fields__" with self.assertRaisesWith(ValueError, msg): class BadFieldConfKey(AbstractContainer): __fields__ = ({'child': True}, ) @staticmethod def find_all_properties(klass): return [attr for attr in dir(klass) if isinstance(getattr(klass, attr, None), property)] def test_empty_fields(self): class EmptyFields(AbstractContainer): __fields__ = tuple() self.assertTupleEqual(EmptyFields.__fields__, tuple()) self.assertTupleEqual(EmptyFields._get_fields(), tuple()) self.assertTupleEqual(EmptyFields.get_fields_conf(), tuple()) props = TestAbstractContainerFieldsConf.find_all_properties(EmptyFields) expected = ['children', 'container_source', 'fields', 'modified', 'name', 'object_id', 'parent'] self.assertListEqual(props, expected) def test_named_fields(self): class NamedFields(AbstractContainer): __fields__ = ('field1', 'field2') @docval({'name': 'field2', 'doc': 'field2 doc', 'type': str}) def __init__(self, **kwargs): super().__init__('test name') self.field2 = kwargs['field2'] self.assertTupleEqual(NamedFields.__fields__, ('field1', 'field2')) self.assertIs(NamedFields._get_fields(), NamedFields.__fields__) expected = ({'doc': None, 'name': 'field1'}, {'doc': 'field2 doc', 'name': 'field2'}) self.assertTupleEqual(NamedFields.get_fields_conf(), expected) props = TestAbstractContainerFieldsConf.find_all_properties(NamedFields) expected = ['children', 'container_source', 'field1', 'field2', 'fields', 'modified', 'name', 'object_id', 'parent'] self.assertListEqual(props, expected) f1_doc = getattr(NamedFields, 'field1').__doc__ self.assertIsNone(f1_doc) f2_doc = getattr(NamedFields, 'field2').__doc__ self.assertEqual(f2_doc, 'field2 doc') obj = NamedFields('field2 
value') self.assertIsNone(obj.field1) self.assertEqual(obj.field2, 'field2 value') obj.field1 = 'field1 value' msg = "can't set attribute 'field2' -- already set" with self.assertRaisesWith(AttributeError, msg): obj.field2 = 'field2 value' obj.field2 = None # None value does nothing self.assertEqual(obj.field2, 'field2 value') def test_with_doc(self): """Test that __fields__ related attributes are set correctly. Also test that the docstring for fields are not overridden by the docstring in the docval of __init__ if a doc is provided in cls.__fields__. """ class NamedFieldsWithDoc(AbstractContainer): __fields__ = ({'name': 'field1', 'doc': 'field1 orig doc'}, {'name': 'field2', 'doc': 'field2 orig doc'}) @docval({'name': 'field2', 'doc': 'field2 doc', 'type': str}) def __init__(self, **kwargs): super().__init__('test name') self.field2 = kwargs['field2'] expected = ({'doc': 'field1 orig doc', 'name': 'field1'}, {'doc': 'field2 orig doc', 'name': 'field2'}) self.assertTupleEqual(NamedFieldsWithDoc.get_fields_conf(), expected) f1_doc = getattr(NamedFieldsWithDoc, 'field1').__doc__ self.assertEqual(f1_doc, 'field1 orig doc') f2_doc = getattr(NamedFieldsWithDoc, 'field2').__doc__ self.assertEqual(f2_doc, 'field2 orig doc') def test_not_settable(self): """Test that __fields__ related attributes are set correctly. Also test that the docstring for fields are not overridden by the docstring in the docval of __init__ if a doc is provided in cls.__fields__. """ class NamedFieldsNotSettable(AbstractContainer): __fields__ = ({'name': 'field1', 'settable': True}, {'name': 'field2', 'settable': False}) expected = ({'doc': None, 'name': 'field1', 'settable': True}, {'doc': None, 'name': 'field2', 'settable': False}) self.assertTupleEqual(NamedFieldsNotSettable.get_fields_conf(), expected) obj = NamedFieldsNotSettable('test name') obj.field1 = 'field1 value' with self.assertRaisesWith(AttributeError, "can't set attribute"): obj.field2 = 'field2 value' def test_inheritance(self): class NamedFields(AbstractContainer): __fields__ = ({'name': 'field1', 'doc': 'field1 doc', 'settable': False}, ) class NamedFieldsChild(NamedFields): __fields__ = ({'name': 'field2'}, ) self.assertTupleEqual(NamedFieldsChild.__fields__, ('field1', 'field2')) self.assertIs(NamedFieldsChild._get_fields(), NamedFieldsChild.__fields__) expected = ({'doc': 'field1 doc', 'name': 'field1', 'settable': False}, {'doc': None, 'name': 'field2'}) self.assertTupleEqual(NamedFieldsChild.get_fields_conf(), expected) props = TestAbstractContainerFieldsConf.find_all_properties(NamedFieldsChild) expected = ['children', 'container_source', 'field1', 'field2', 'fields', 'modified', 'name', 'object_id', 'parent'] self.assertListEqual(props, expected) def test_inheritance_override(self): class NamedFields(AbstractContainer): __fields__ = ({'name': 'field1'}, ) class NamedFieldsChild(NamedFields): __fields__ = ({'name': 'field1', 'doc': 'overridden field', 'settable': False}, ) self.assertEqual(NamedFieldsChild._get_fields(), ('field1', )) ret = NamedFieldsChild.get_fields_conf() self.assertEqual(ret[0], {'name': 'field1', 'doc': 'overridden field', 'settable': False}) # obj = NamedFieldsChild('test name') # with self.assertRaisesWith(AttributeError, "can't set attribute"): # obj.field1 = 'field1 value' def test_mult_inheritance_base_mixin(self): class NamedFields(AbstractContainer): __fields__ = ({'name': 'field1', 'doc': 'field1 doc', 'settable': False}, ) class BlankMixin: pass class NamedFieldsChild(NamedFields, BlankMixin): __fields__ = ({'name': 
'field2'}, ) self.assertTupleEqual(NamedFieldsChild.__fields__, ('field1', 'field2')) self.assertIs(NamedFieldsChild._get_fields(), NamedFieldsChild.__fields__) def test_mult_inheritance_base_container(self): class NamedFields(AbstractContainer): __fields__ = ({'name': 'field1', 'doc': 'field1 doc', 'settable': False}, ) class BlankMixin: pass class NamedFieldsChild(BlankMixin, NamedFields): __fields__ = ({'name': 'field2'}, ) self.assertTupleEqual(NamedFieldsChild.__fields__, ('field1', 'field2')) self.assertIs(NamedFieldsChild._get_fields(), NamedFieldsChild.__fields__) class TestContainerFieldsConf(TestCase): def test_required_name(self): class ContainerRequiredName(Container): __fields__ = ({'name': 'field1', 'required_name': 'field1 value'}, ) @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}) def __init__(self, **kwargs): super().__init__('test name') self.field1 = kwargs['field1'] msg = ("Field 'field1' on ContainerRequiredName has a required name and must be a subclass of " "AbstractContainer.") with self.assertRaisesWith(ValueError, msg): ContainerRequiredName('field1 value') obj1 = Container('test container') msg = "Field 'field1' on ContainerRequiredName must be named 'field1 value'." with self.assertRaisesWith(ValueError, msg): ContainerRequiredName(obj1) obj2 = Container('field1 value') obj3 = ContainerRequiredName(obj2) self.assertIs(obj3.field1, obj2) obj4 = ContainerRequiredName() self.assertIsNone(obj4.field1) def test_child(self): class ContainerWithChild(Container): __fields__ = ({'name': 'field1', 'child': True}, ) @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}) def __init__(self, **kwargs): super().__init__('test name') self.field1 = kwargs['field1'] child_obj1 = Container('test child 1') obj1 = ContainerWithChild(child_obj1) self.assertIs(child_obj1.parent, obj1) child_obj2 = Container('test child 2') obj3 = ContainerWithChild((child_obj1, child_obj2)) self.assertIs(child_obj1.parent, obj1) # child1 parent is already set self.assertIs(child_obj2.parent, obj3) # child1 parent is already set child_obj3 = Container('test child 3') obj4 = ContainerWithChild({'test child 3': child_obj3}) self.assertIs(child_obj3.parent, obj4) obj2 = ContainerWithChild() self.assertIsNone(obj2.field1) class TestChangeFieldsName(TestCase): def test_fields(self): class ContainerNewFields(Container): _fieldsname = '__newfields__' __newfields__ = ({'name': 'field1', 'doc': 'field1 doc'}, ) @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}) def __init__(self, **kwargs): super().__init__('test name') self.field1 = kwargs['field1'] self.assertTupleEqual(ContainerNewFields.__newfields__, ('field1', )) self.assertIs(ContainerNewFields._get_fields(), ContainerNewFields.__newfields__) expected = ({'doc': 'field1 doc', 'name': 'field1'}, ) self.assertTupleEqual(ContainerNewFields.get_fields_conf(), expected) def test_fields_inheritance(self): class ContainerOldFields(Container): __fields__ = ({'name': 'field1', 'doc': 'field1 doc'}, ) @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}) def __init__(self, **kwargs): super().__init__('test name') self.field1 = kwargs['field1'] class ContainerNewFields(ContainerOldFields): _fieldsname = '__newfields__' __newfields__ = ({'name': 'field2', 'doc': 'field2 doc'}, ) @docval({'name': 'field1', 'doc': 'field1 doc', 'type': None, 'default': None}, {'name': 'field2', 'doc': 'field2 doc', 'type': None, 'default': None}) def __init__(self, **kwargs): 
super().__init__(kwargs['field1']) self.field2 = kwargs['field2'] self.assertTupleEqual(ContainerNewFields.__newfields__, ('field1', 'field2')) self.assertIs(ContainerNewFields._get_fields(), ContainerNewFields.__newfields__) expected = ({'doc': 'field1 doc', 'name': 'field1'}, {'doc': 'field2 doc', 'name': 'field2'}, ) self.assertTupleEqual(ContainerNewFields.get_fields_conf(), expected) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/test_io_hdf5.py0000644000655200065520000002142400000000000017633 0ustar00circlecicircleciimport json import os from numbers import Number import numpy as np from h5py import File, Dataset, Reference from hdmf.backends.hdf5 import HDF5IO from hdmf.build import GroupBuilder, DatasetBuilder, LinkBuilder from hdmf.testing import TestCase from hdmf.utils import get_data_shape from tests.unit.test_io_hdf5_h5tools import _get_manager from tests.unit.utils import Foo class HDF5Encoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, Dataset): ret = None for t in (list, str): try: ret = t(obj) break except: # noqa: E722 pass if ret is None: return obj else: return ret elif isinstance(obj, np.int64): return int(obj) elif isinstance(obj, bytes): return str(obj) return json.JSONEncoder.default(self, obj) class GroupBuilderTestCase(TestCase): ''' A TestCase class for comparing GroupBuilders. ''' def __is_scalar(self, obj): if hasattr(obj, 'shape'): return len(obj.shape) == 0 else: if any(isinstance(obj, t) for t in (int, str, float, bytes, str)): return True return False def __convert_h5_scalar(self, obj): if isinstance(obj, Dataset): return obj[...] return obj def __compare_attr_dicts(self, a, b): reasons = list() b_keys = set(b.keys()) for k in a: if k not in b: reasons.append("'%s' attribute missing from second dataset" % k) else: if a[k] != b[k]: reasons.append("'%s' attribute on datasets not equal" % k) b_keys.remove(k) for k in b_keys: reasons.append("'%s' attribute missing from first dataset" % k) return reasons def __compare_dataset(self, a, b): reasons = self.__compare_attr_dicts(a.attributes, b.attributes) if not self.__compare_data(a.data, b.data): reasons.append("dataset '%s' not equal" % a.name) return reasons def __compare_data(self, a, b): if isinstance(a, Number) and isinstance(b, Number): return a == b elif isinstance(a, Number) != isinstance(b, Number): return False else: a_scalar = self.__is_scalar(a) b_scalar = self.__is_scalar(b) if a_scalar and b_scalar: return self.__convert_h5_scalar(a_scalar) == self.__convert_h5_scalar(b_scalar) elif a_scalar != b_scalar: return False if len(a) == len(b): for i in range(len(a)): if not self.__compare_data(a[i], b[i]): return False else: return False return True def __fmt(self, val): return "%s (%s)" % (val, type(val)) def __assert_helper(self, a, b): reasons = list() b_keys = set(b.keys()) for k, a_sub in a.items(): if k in b: b_sub = b[k] b_keys.remove(k) if isinstance(a_sub, LinkBuilder) and isinstance(a_sub, LinkBuilder): a_sub = a_sub['builder'] b_sub = b_sub['builder'] elif isinstance(a_sub, LinkBuilder) != isinstance(a_sub, LinkBuilder): reasons.append('%s != %s' % (a_sub, b_sub)) if isinstance(a_sub, DatasetBuilder) and isinstance(a_sub, DatasetBuilder): # if not self.__compare_dataset(a_sub, b_sub): # reasons.append('%s != %s' % (a_sub, b_sub)) reasons.extend(self.__compare_dataset(a_sub, b_sub)) elif isinstance(a_sub, GroupBuilder) and isinstance(a_sub, GroupBuilder): reasons.extend(self.__assert_helper(a_sub, b_sub)) else: 
equal = None a_array = isinstance(a_sub, np.ndarray) b_array = isinstance(b_sub, np.ndarray) if a_array and b_array: equal = np.array_equal(a_sub, b_sub) elif a_array or b_array: # if strings, convert before comparing if b_array: if b_sub.dtype.char in ('S', 'U'): a_sub = [np.string_(s) for s in a_sub] else: if a_sub.dtype.char in ('S', 'U'): b_sub = [np.string_(s) for s in b_sub] equal = np.array_equal(a_sub, b_sub) else: equal = a_sub == b_sub if not equal: reasons.append('%s != %s' % (self.__fmt(a_sub), self.__fmt(b_sub))) else: reasons.append("'%s' not in both" % k) for k in b_keys: reasons.append("'%s' not in both" % k) return reasons def assertBuilderEqual(self, a, b): ''' Tests that two GroupBuilders are equal ''' reasons = self.__assert_helper(a, b) if len(reasons): raise AssertionError(', '.join(reasons)) return True class TestHDF5Writer(GroupBuilderTestCase): def setUp(self): self.manager = _get_manager() self.path = "test_io_hdf5.h5" self.foo_builder = GroupBuilder('foo1', attributes={'data_type': 'Foo', 'namespace': 'test_core', 'attr1': "bar", 'object_id': -1}, datasets={'my_data': DatasetBuilder('my_data', list(range(100, 200, 10)), attributes={'attr2': 17})}) self.foo = Foo('foo1', list(range(100, 200, 10)), attr1="bar", attr2=17, attr3=3.14) self.manager.prebuilt(self.foo, self.foo_builder) self.builder = GroupBuilder( 'root', source=self.path, groups={'test_bucket': GroupBuilder('test_bucket', groups={'foo_holder': GroupBuilder('foo_holder', groups={'foo1': self.foo_builder})})}, attributes={'data_type': 'FooFile'}) def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def check_fields(self): f = File(self.path, 'r') self.assertIn('test_bucket', f) bucket = f.get('test_bucket') self.assertIn('foo_holder', bucket) holder = bucket.get('foo_holder') self.assertIn('foo1', holder) return f def test_write_builder(self): writer = HDF5IO(self.path, manager=self.manager, mode='a') writer.write_builder(self.builder) writer.close() self.check_fields() def test_write_attribute_reference_container(self): writer = HDF5IO(self.path, manager=self.manager, mode='a') self.builder.set_attribute('ref_attribute', self.foo) writer.write_builder(self.builder) writer.close() f = self.check_fields() self.assertIsInstance(f.attrs['ref_attribute'], Reference) self.assertEqual(f['test_bucket/foo_holder/foo1'], f[f.attrs['ref_attribute']]) def test_write_attribute_reference_builder(self): writer = HDF5IO(self.path, manager=self.manager, mode='a') self.builder.set_attribute('ref_attribute', self.foo_builder) writer.write_builder(self.builder) writer.close() f = self.check_fields() self.assertIsInstance(f.attrs['ref_attribute'], Reference) self.assertEqual(f['test_bucket/foo_holder/foo1'], f[f.attrs['ref_attribute']]) def test_write_context_manager(self): with HDF5IO(self.path, manager=self.manager, mode='a') as writer: writer.write_builder(self.builder) self.check_fields() def test_read_builder(self): self.maxDiff = None io = HDF5IO(self.path, manager=self.manager, mode='a') io.write_builder(self.builder) builder = io.read_builder() self.assertBuilderEqual(builder, self.builder) io.close() def test_dataset_shape(self): self.maxDiff = None io = HDF5IO(self.path, manager=self.manager, mode='a') io.write_builder(self.builder) builder = io.read_builder() dset = builder['test_bucket']['foo_holder']['foo1']['my_data'].data self.assertEqual(get_data_shape(dset), (10,)) io.close() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 
hdmf-3.1.1/tests/unit/test_io_hdf5_h5tools.py0000644000655200065520000041730700000000000021321 0ustar00circlecicircleciimport os import unittest import warnings from io import BytesIO from pathlib import Path import h5py import numpy as np from h5py import SoftLink, HardLink, ExternalLink, File from h5py import filters as h5py_filters from hdmf.backends.hdf5 import H5DataIO from hdmf.backends.hdf5.h5tools import HDF5IO, ROOT_NAME, SPEC_LOC_ATTR, H5PY_3 from hdmf.backends.io import HDMFIO, UnsupportedOperation from hdmf.backends.warnings import BrokenLinkWarning from hdmf.build import (GroupBuilder, DatasetBuilder, BuildManager, TypeMap, ObjectMapper, OrphanContainerBuildError, LinkBuilder) from hdmf.container import Container, Data from hdmf.data_utils import DataChunkIterator, InvalidDataIOError from hdmf.spec.catalog import SpecCatalog from hdmf.spec.namespace import NamespaceCatalog from hdmf.spec.namespace import SpecNamespace from hdmf.spec.spec import (AttributeSpec, DatasetSpec, GroupSpec, LinkSpec, ZERO_OR_MANY, ONE_OR_MANY, ZERO_OR_ONE, RefSpec, DtypeSpec) from hdmf.testing import TestCase from hdmf.utils import docval, getargs from tests.unit.utils import (Foo, FooBucket, CORE_NAMESPACE, get_temp_filepath, CustomGroupSpec, CustomDatasetSpec, CustomSpecNamespace) class FooFile(Container): @docval({'name': 'buckets', 'type': list, 'doc': 'the FooBuckets in this file', 'default': list()}, {'name': 'foo_link', 'type': Foo, 'doc': 'an optional linked Foo', 'default': None}, {'name': 'foofile_data', 'type': 'array_data', 'doc': 'an optional dataset', 'default': None}, {'name': 'foo_ref_attr', 'type': Foo, 'doc': 'a reference Foo', 'default': None}, ) def __init__(self, **kwargs): buckets, foo_link, foofile_data, foo_ref_attr = getargs('buckets', 'foo_link', 'foofile_data', 'foo_ref_attr', kwargs) super().__init__(name=ROOT_NAME) # name is not used - FooFile should be the root container self.__buckets = {b.name: b for b in buckets} # note: collections of groups are unordered in HDF5 for f in buckets: f.parent = self self.__foo_link = foo_link self.__foofile_data = foofile_data self.__foo_ref_attr = foo_ref_attr def __eq__(self, other): return (self.buckets == other.buckets and self.foo_link == other.foo_link and self.foofile_data == other.foofile_data) def __str__(self): return ('buckets=%s, foo_link=%s, foofile_data=%s' % (self.buckets, self.foo_link, self.foofile_data)) @property def buckets(self): return self.__buckets def add_bucket(self, bucket): self.__buckets[bucket.name] = bucket bucket.parent = self def remove_bucket(self, bucket_name): bucket = self.__buckets.pop(bucket_name) if bucket.parent is self: self._remove_child(bucket) return bucket @property def foo_link(self): return self.__foo_link @foo_link.setter def foo_link(self, value): if self.__foo_link is None: self.__foo_link = value else: raise ValueError("can't reset foo_link attribute") @property def foofile_data(self): return self.__foofile_data @foofile_data.setter def foofile_data(self, value): if self.__foofile_data is None: self.__foofile_data = value else: raise ValueError("can't reset foofile_data attribute") @property def foo_ref_attr(self): return self.__foo_ref_attr @foo_ref_attr.setter def foo_ref_attr(self, value): if self.__foo_ref_attr is None: self.__foo_ref_attr = value else: raise ValueError("can't reset foo_ref_attr attribute") class H5IOTest(TestCase): """Tests for h5tools IO tools""" def setUp(self): self.path = get_temp_filepath() self.io = HDF5IO(self.path, mode='a') self.f = self.io._file def 
tearDown(self): self.io.close() os.remove(self.path) ########################################## # __chunked_iter_fill__(...) tests ########################################## def test__chunked_iter_fill(self): """Matrix test of HDF5IO.__chunked_iter_fill__ using a DataChunkIterator with different parameters""" data_opts = {'iterator': range(10), 'numpy': np.arange(30).reshape(5, 2, 3), 'list': np.arange(30).reshape(5, 2, 3).tolist(), 'sparselist1': [1, 2, 3, None, None, None, None, 8, 9, 10], 'sparselist2': [None, None, 3], 'sparselist3': [1, 2, 3, None, None], # note: cannot process None in ndarray 'nanlist': [[[1, 2, 3, np.nan, np.nan, 6], [np.nan, np.nan, 3, 4, np.nan, np.nan]], [[10, 20, 30, 40, np.nan, np.nan], [np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]]]} buffer_size_opts = [1, 2, 3, 4] # data is divisible by some of these, some not for data_type, data in data_opts.items(): iter_axis_opts = [0, 1, 2] if data_type == 'iterator' or data_type.startswith('sparselist'): iter_axis_opts = [0] # only one dimension for iter_axis in iter_axis_opts: for buffer_size in buffer_size_opts: with self.subTest(data_type=data_type, iter_axis=iter_axis, buffer_size=buffer_size): with warnings.catch_warnings(record=True) as w: dci = DataChunkIterator(data=data, buffer_size=buffer_size, iter_axis=iter_axis) if len(w) <= 1: # init may throw UserWarning for iterating over not-first dim of a list. ignore here pass dset_name = '%s, %d, %d' % (data_type, iter_axis, buffer_size) my_dset = HDF5IO.__chunked_iter_fill__(self.f, dset_name, dci) if data_type == 'iterator': self.assertListEqual(my_dset[:].tolist(), list(data)) elif data_type == 'numpy': self.assertTrue(np.all(my_dset[:] == data)) self.assertTupleEqual(my_dset.shape, data.shape) elif data_type == 'list' or data_type == 'nanlist': data_np = np.array(data) np.testing.assert_array_equal(my_dset[:], data_np) self.assertTupleEqual(my_dset.shape, data_np.shape) elif data_type.startswith('sparselist'): # replace None in original data with default hdf5 fillvalue 0 data_zeros = np.where(np.equal(np.array(data), None), 0, data) np.testing.assert_array_equal(my_dset[:], data_zeros) self.assertTupleEqual(my_dset.shape, data_zeros.shape) ########################################## # write_dataset tests: scalars ########################################## def test_write_dataset_scalar(self): a = 10 self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTupleEqual(dset.shape, ()) self.assertEqual(dset[()], a) def test_write_dataset_string(self): a = 'test string' self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTupleEqual(dset.shape, ()) # self.assertEqual(dset[()].decode('utf-8'), a) read_a = dset[()] if isinstance(read_a, bytes): read_a = read_a.decode('utf-8') self.assertEqual(read_a, a) ########################################## # write_dataset tests: lists ########################################## def test_write_dataset_list(self): a = np.arange(30).reshape(5, 2, 3) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a.tolist(), attributes={})) dset = self.f['test_dataset'] self.assertTrue(np.all(dset[:] == a)) def test_write_dataset_list_compress_gzip(self): a = H5DataIO(np.arange(30).reshape(5, 2, 3), compression='gzip', compression_opts=5, shuffle=True, fletcher32=True) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTrue(np.all(dset[:] 
== a.data)) self.assertEqual(dset.compression, 'gzip') self.assertEqual(dset.compression_opts, 5) self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) @unittest.skipIf("lzf" not in h5py_filters.encode, "LZF compression not supported in this h5py library install") def test_write_dataset_list_compress_lzf(self): warn_msg = ("lzf compression may not be available on all installations of HDF5. Use of gzip is " "recommended to ensure portability of the generated HDF5 files.") with self.assertWarnsWith(UserWarning, warn_msg): a = H5DataIO(np.arange(30).reshape(5, 2, 3), compression='lzf', shuffle=True, fletcher32=True) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.compression, 'lzf') self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) @unittest.skipIf("szip" not in h5py_filters.encode, "SZIP compression not supported in this h5py library install") def test_write_dataset_list_compress_szip(self): warn_msg = ("szip compression may not be available on all installations of HDF5. Use of gzip is " "recommended to ensure portability of the generated HDF5 files.") with self.assertWarnsWith(UserWarning, warn_msg): a = H5DataIO(np.arange(30).reshape(5, 2, 3), compression='szip', compression_opts=('ec', 16), shuffle=True, fletcher32=True) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.compression, 'szip') self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) def test_write_dataset_list_compress_available_int_filters(self): a = H5DataIO(np.arange(30).reshape(5, 2, 3), compression=1, shuffle=True, fletcher32=True, allow_plugin_filters=True) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.compression, 'gzip') self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) def test_write_dataset_list_enable_default_compress(self): a = H5DataIO(np.arange(30).reshape(5, 2, 3), compression=True) self.assertEqual(a.io_settings['compression'], 'gzip') self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.compression, 'gzip') def test_write_dataset_list_disable_default_compress(self): with warnings.catch_warnings(record=True) as w: a = H5DataIO(np.arange(30).reshape(5, 2, 3), compression=False, compression_opts=5) self.assertEqual(len(w), 1) # We expect a warning that compression options are being ignored self.assertFalse('compression_ops' in a.io_settings) self.assertFalse('compression' in a.io_settings) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.compression, None) def test_write_dataset_list_chunked(self): a = H5DataIO(np.arange(30).reshape(5, 2, 3), chunks=(1, 1, 3)) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.chunks, (1, 1, 3)) def test_write_dataset_list_fillvalue(self): a = H5DataIO(np.arange(20).reshape(5, 4), fillvalue=-1) self.io.write_dataset(self.f, 
DatasetBuilder('test_dataset', a, attributes={})) dset = self.f['test_dataset'] self.assertTrue(np.all(dset[:] == a.data)) self.assertEqual(dset.fillvalue, -1) ########################################## # write_dataset tests: tables ########################################## def test_write_table(self): cmpd_dt = np.dtype([('a', np.int32), ('b', np.float64)]) data = np.zeros(10, dtype=cmpd_dt) data['a'][1] = 101 data['b'][1] = 0.1 dt = [{'name': 'a', 'dtype': 'int32', 'doc': 'a column'}, {'name': 'b', 'dtype': 'float64', 'doc': 'b column'}] self.io.write_dataset(self.f, DatasetBuilder('test_dataset', data, attributes={}, dtype=dt)) dset = self.f['test_dataset'] self.assertEqual(dset['a'].tolist(), data['a'].tolist()) self.assertEqual(dset['b'].tolist(), data['b'].tolist()) def test_write_table_nested(self): b_cmpd_dt = np.dtype([('c', np.int32), ('d', np.float64)]) cmpd_dt = np.dtype([('a', np.int32), ('b', b_cmpd_dt)]) data = np.zeros(10, dtype=cmpd_dt) data['a'][1] = 101 data['b']['c'] = 202 data['b']['d'] = 10.1 b_dt = [{'name': 'c', 'dtype': 'int32', 'doc': 'c column'}, {'name': 'd', 'dtype': 'float64', 'doc': 'd column'}] dt = [{'name': 'a', 'dtype': 'int32', 'doc': 'a column'}, {'name': 'b', 'dtype': b_dt, 'doc': 'b column'}] self.io.write_dataset(self.f, DatasetBuilder('test_dataset', data, attributes={}, dtype=dt)) dset = self.f['test_dataset'] self.assertEqual(dset['a'].tolist(), data['a'].tolist()) self.assertEqual(dset['b'].tolist(), data['b'].tolist()) ########################################## # write_dataset tests: Iterable ########################################## def test_write_dataset_iterable(self): self.io.write_dataset(self.f, DatasetBuilder('test_dataset', range(10), attributes={})) dset = self.f['test_dataset'] self.assertListEqual(dset[:].tolist(), list(range(10))) def test_write_dataset_iterable_multidimensional_array(self): a = np.arange(30).reshape(5, 2, 3) aiter = iter(a) daiter = DataChunkIterator.from_iterable(aiter, buffer_size=2) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', daiter, attributes={})) dset = self.f['test_dataset'] self.assertListEqual(dset[:].tolist(), a.tolist()) def test_write_multi_dci_oaat(self): """ Test writing multiple DataChunkIterators, one at a time """ a = np.arange(30).reshape(5, 2, 3) b = np.arange(30, 60).reshape(5, 2, 3) aiter = iter(a) biter = iter(b) daiter1 = DataChunkIterator.from_iterable(aiter, buffer_size=2) daiter2 = DataChunkIterator.from_iterable(biter, buffer_size=2) builder = GroupBuilder("root") dataset1 = DatasetBuilder('test_dataset1', daiter1) dataset2 = DatasetBuilder('test_dataset2', daiter2) builder.set_dataset(dataset1) builder.set_dataset(dataset2) self.io.write_builder(builder) dset1 = self.f['test_dataset1'] self.assertListEqual(dset1[:].tolist(), a.tolist()) dset2 = self.f['test_dataset2'] self.assertListEqual(dset2[:].tolist(), b.tolist()) def test_write_multi_dci_conc(self): """ Test writing multiple DataChunkIterators, concurrently """ a = np.arange(30).reshape(5, 2, 3) b = np.arange(30, 60).reshape(5, 2, 3) aiter = iter(a) biter = iter(b) daiter1 = DataChunkIterator.from_iterable(aiter, buffer_size=2) daiter2 = DataChunkIterator.from_iterable(biter, buffer_size=2) builder = GroupBuilder("root") dataset1 = DatasetBuilder('test_dataset1', daiter1) dataset2 = DatasetBuilder('test_dataset2', daiter2) builder.set_dataset(dataset1) builder.set_dataset(dataset2) self.io.write_builder(builder) dset1 = self.f['test_dataset1'] self.assertListEqual(dset1[:].tolist(), a.tolist()) dset2 = 
self.f['test_dataset2'] self.assertListEqual(dset2[:].tolist(), b.tolist()) def test_write_dataset_iterable_multidimensional_array_compression(self): a = np.arange(30).reshape(5, 2, 3) aiter = iter(a) daiter = DataChunkIterator.from_iterable(aiter, buffer_size=2) wrapped_daiter = H5DataIO(data=daiter, compression='gzip', compression_opts=5, shuffle=True, fletcher32=True) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', wrapped_daiter, attributes={})) dset = self.f['test_dataset'] self.assertEqual(dset.shape, a.shape) self.assertListEqual(dset[:].tolist(), a.tolist()) self.assertEqual(dset.compression, 'gzip') self.assertEqual(dset.compression_opts, 5) self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) ############################################# # write_dataset tests: data chunk iterator ############################################# def test_write_dataset_data_chunk_iterator(self): dci = DataChunkIterator(data=np.arange(10), buffer_size=2) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', dci, attributes={}, dtype=dci.dtype)) dset = self.f['test_dataset'] self.assertListEqual(dset[:].tolist(), list(range(10))) self.assertEqual(dset[:].dtype, dci.dtype) def test_write_dataset_data_chunk_iterator_with_compression(self): dci = DataChunkIterator(data=np.arange(10), buffer_size=2) wrapped_dci = H5DataIO(data=dci, compression='gzip', compression_opts=5, shuffle=True, fletcher32=True, chunks=(2,)) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', wrapped_dci, attributes={})) dset = self.f['test_dataset'] self.assertListEqual(dset[:].tolist(), list(range(10))) self.assertEqual(dset.compression, 'gzip') self.assertEqual(dset.compression_opts, 5) self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) self.assertEqual(dset.chunks, (2,)) def test_pass_through_of_recommended_chunks(self): class DC(DataChunkIterator): def recommended_chunk_shape(self): return (5, 1, 1) dci = DC(data=np.arange(30).reshape(5, 2, 3)) wrapped_dci = H5DataIO(data=dci, compression='gzip', compression_opts=5, shuffle=True, fletcher32=True) self.io.write_dataset(self.f, DatasetBuilder('test_dataset', wrapped_dci, attributes={})) dset = self.f['test_dataset'] self.assertEqual(dset.chunks, (5, 1, 1)) self.assertEqual(dset.compression, 'gzip') self.assertEqual(dset.compression_opts, 5) self.assertEqual(dset.shuffle, True) self.assertEqual(dset.fletcher32, True) def test_dci_h5dataset(self): data = np.arange(30).reshape(5, 2, 3) dci1 = DataChunkIterator(data=data, buffer_size=1, iter_axis=0) HDF5IO.__chunked_iter_fill__(self.f, 'test_dataset', dci1) dset = self.f['test_dataset'] dci2 = DataChunkIterator(data=dset, buffer_size=2, iter_axis=2) chunk = dci2.next() self.assertTupleEqual(chunk.shape, (5, 2, 2)) chunk = dci2.next() self.assertTupleEqual(chunk.shape, (5, 2, 1)) # TODO test chunk data, shape, selection self.assertTupleEqual(dci2.recommended_data_shape(), data.shape) self.assertIsNone(dci2.recommended_chunk_shape()) def test_dci_h5dataset_sparse_matched(self): data = [1, 2, 3, None, None, None, None, 8, 9, 10] dci1 = DataChunkIterator(data=data, buffer_size=3) HDF5IO.__chunked_iter_fill__(self.f, 'test_dataset', dci1) dset = self.f['test_dataset'] dci2 = DataChunkIterator(data=dset, buffer_size=2) # dataset is read such that Nones in original data were not written, but are read as 0s self.assertTupleEqual(dci2.maxshape, (10,)) self.assertEqual(dci2.dtype, np.dtype(int)) count = 0 for chunk in dci2: self.assertEqual(len(chunk.selection), 1) if 
count == 0: self.assertListEqual(chunk.data.tolist(), [1, 2]) self.assertEqual(chunk.selection[0], slice(0, 2)) elif count == 1: self.assertListEqual(chunk.data.tolist(), [3, 0]) self.assertEqual(chunk.selection[0], slice(2, 4)) elif count == 2: self.assertListEqual(chunk.data.tolist(), [0, 0]) self.assertEqual(chunk.selection[0], slice(4, 6)) elif count == 3: self.assertListEqual(chunk.data.tolist(), [0, 8]) self.assertEqual(chunk.selection[0], slice(6, 8)) elif count == 4: self.assertListEqual(chunk.data.tolist(), [9, 10]) self.assertEqual(chunk.selection[0], slice(8, 10)) count += 1 self.assertEqual(count, 5) self.assertTupleEqual(dci2.recommended_data_shape(), (10,)) self.assertIsNone(dci2.recommended_chunk_shape()) def test_dci_h5dataset_sparse_unmatched(self): data = [1, 2, 3, None, None, None, None, 8, 9, 10] dci1 = DataChunkIterator(data=data, buffer_size=3) HDF5IO.__chunked_iter_fill__(self.f, 'test_dataset', dci1) dset = self.f['test_dataset'] dci2 = DataChunkIterator(data=dset, buffer_size=4) # dataset is read such that Nones in original data were not written, but are read as 0s self.assertTupleEqual(dci2.maxshape, (10,)) self.assertEqual(dci2.dtype, np.dtype(int)) count = 0 for chunk in dci2: self.assertEqual(len(chunk.selection), 1) if count == 0: self.assertListEqual(chunk.data.tolist(), [1, 2, 3, 0]) self.assertEqual(chunk.selection[0], slice(0, 4)) elif count == 1: self.assertListEqual(chunk.data.tolist(), [0, 0, 0, 8]) self.assertEqual(chunk.selection[0], slice(4, 8)) elif count == 2: self.assertListEqual(chunk.data.tolist(), [9, 10]) self.assertEqual(chunk.selection[0], slice(8, 10)) count += 1 self.assertEqual(count, 3) self.assertTupleEqual(dci2.recommended_data_shape(), (10,)) self.assertIsNone(dci2.recommended_chunk_shape()) def test_dci_h5dataset_scalar(self): data = [1] dci1 = DataChunkIterator(data=data, buffer_size=3) HDF5IO.__chunked_iter_fill__(self.f, 'test_dataset', dci1) dset = self.f['test_dataset'] dci2 = DataChunkIterator(data=dset, buffer_size=4) # dataset is read such that Nones in original data were not written, but are read as 0s self.assertTupleEqual(dci2.maxshape, (1,)) self.assertEqual(dci2.dtype, np.dtype(int)) count = 0 for chunk in dci2: self.assertEqual(len(chunk.selection), 1) if count == 0: self.assertListEqual(chunk.data.tolist(), [1]) self.assertEqual(chunk.selection[0], slice(0, 1)) count += 1 self.assertEqual(count, 1) self.assertTupleEqual(dci2.recommended_data_shape(), (1,)) self.assertIsNone(dci2.recommended_chunk_shape()) ############################################# # H5DataIO general ############################################# def test_warning_on_non_gzip_compression(self): # Make sure no warning is issued when using gzip with warnings.catch_warnings(record=True) as w: dset = H5DataIO(np.arange(30), compression='gzip') self.assertEqual(len(w), 0) self.assertEqual(dset.io_settings['compression'], 'gzip') # Make sure a warning is issued when using szip (even if installed) warn_msg = ("szip compression may not be available on all installations of HDF5. 
Use of gzip is " "recommended to ensure portability of the generated HDF5 files.") if "szip" in h5py_filters.encode: with self.assertWarnsWith(UserWarning, warn_msg): dset = H5DataIO(np.arange(30), compression='szip', compression_opts=('ec', 16)) self.assertEqual(dset.io_settings['compression'], 'szip') else: with self.assertRaises(ValueError): with self.assertWarnsWith(UserWarning, warn_msg): dset = H5DataIO(np.arange(30), compression='szip', compression_opts=('ec', 16)) self.assertEqual(dset.io_settings['compression'], 'szip') # Make sure a warning is issued when using lzf compression warn_msg = ("lzf compression may not be available on all installations of HDF5. Use of gzip is " "recommended to ensure portability of the generated HDF5 files.") with self.assertWarnsWith(UserWarning, warn_msg): dset = H5DataIO(np.arange(30), compression='lzf') self.assertEqual(dset.io_settings['compression'], 'lzf') def test_error_on_unsupported_compression_filter(self): # Make sure gzip does not raise an error try: H5DataIO(np.arange(30), compression='gzip', compression_opts=5) except ValueError: self.fail("Using gzip compression raised a ValueError when it should not") # Make sure szip raises an error if not installed (or does not raise an error if installed) warn_msg = ("szip compression may not be available on all installations of HDF5. Use of gzip is " "recommended to ensure portability of the generated HDF5 files.") if "szip" not in h5py_filters.encode: with self.assertRaises(ValueError): with self.assertWarnsWith(UserWarning, warn_msg): H5DataIO(np.arange(30), compression='szip', compression_opts=('ec', 16)) else: try: with self.assertWarnsWith(UserWarning, warn_msg): H5DataIO(np.arange(30), compression='szip', compression_opts=('ec', 16)) except ValueError: self.fail("SZIP is installed but H5DataIO still raises an error") # Test error on illegal (i.e., a made-up compressor) with self.assertRaises(ValueError): warn_msg = ("unknown compression may not be available on all installations of HDF5. Use of gzip is " "recommended to ensure portability of the generated HDF5 files.") with self.assertWarnsWith(UserWarning, warn_msg): H5DataIO(np.arange(30), compression="unknown") # Make sure passing int compression filter raise an error if not installed if not h5py_filters.h5z.filter_avail(h5py_filters.h5z.FILTER_MAX): with self.assertRaises(ValueError): warn_msg = ("%i compression may not be available on all installations of HDF5. Use of gzip is " "recommended to ensure portability of the generated HDF5 files." 
                            % h5py_filters.h5z.FILTER_MAX)
                with self.assertWarnsWith(UserWarning, warn_msg):
                    H5DataIO(np.arange(30),
                             compression=h5py_filters.h5z.FILTER_MAX,
                             allow_plugin_filters=True)
        # Make sure available int compression filters raise an error without passing allow_plugin_filters=True
        with self.assertRaises(ValueError):
            H5DataIO(np.arange(30), compression=h5py_filters.h5z.FILTER_DEFLATE)

    def test_value_error_on_incompatible_compression_opts(self):
        # Make sure a ValueError is raised when gzip is used with szip compression options
        with self.assertRaises(ValueError):
            H5DataIO(np.arange(30), compression='gzip', compression_opts=('ec', 16))
        # Make sure a ValueError is raised if gzip is used with a too-high aggression level
        with self.assertRaises(ValueError):
            H5DataIO(np.arange(30), compression='gzip', compression_opts=100)
        # Make sure a ValueError is raised if lzf is used with a gzip compression option
        with self.assertRaises(ValueError):
            H5DataIO(np.arange(30), compression='lzf', compression_opts=5)
        # Make sure a ValueError is raised if lzf is used with szip compression options
        with self.assertRaises(ValueError):
            H5DataIO(np.arange(30), compression='lzf', compression_opts=('ec', 16))
        # Make sure a ValueError is raised if szip is used with a gzip compression option
        with self.assertRaises(ValueError):
            H5DataIO(np.arange(30), compression='szip', compression_opts=4)
        # Make sure szip raises a ValueError if bad options are used (odd compression option)
        with self.assertRaises(ValueError):
            H5DataIO(np.arange(30), compression='szip', compression_opts=('ec', 3))
        # Make sure szip raises a ValueError if bad options are used (bad method)
        with self.assertRaises(ValueError):
            H5DataIO(np.arange(30), compression='szip', compression_opts=('bad_method', 16))

    def test_warning_on_linking_of_regular_array(self):
        with warnings.catch_warnings(record=True) as w:
            dset = H5DataIO(np.arange(30), link_data=True)
            self.assertEqual(len(w), 1)
            self.assertEqual(dset.link_data, False)

    def test_warning_on_setting_io_options_on_h5dataset_input(self):
        self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={}))
        with warnings.catch_warnings(record=True) as w:
            H5DataIO(self.f['test_dataset'],
                     compression='gzip',
                     compression_opts=4,
                     fletcher32=True,
                     shuffle=True,
                     maxshape=(10, 20),
                     chunks=(10,),
                     fillvalue=100)
            self.assertEqual(len(w), 7)

    def test_h5dataio_array_conversion_numpy(self):
        # Test that H5DataIO.__array__ is working when wrapping an ndarray
        test_speed = np.array([10., 20.])
        data = H5DataIO((test_speed))
        self.assertTrue(np.all(np.isfinite(data)))  # Force call of H5DataIO.__array__

    def test_h5dataio_array_conversion_list(self):
        # Test that H5DataIO.__array__ is working when wrapping a python list
        test_speed = [10., 20.]
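        # As in the ndarray case above, the np.isfinite(data) call below triggers
        # H5DataIO.__array__, which exposes the wrapped list as an array without
        # writing anything to disk.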
data = H5DataIO(test_speed) self.assertTrue(np.all(np.isfinite(data))) # Force call of H5DataIO.__array__ def test_h5dataio_array_conversion_datachunkiterator(self): # Test that H5DataIO.__array__ is working when wrapping a python list test_speed = DataChunkIterator(data=[10., 20.]) data = H5DataIO(test_speed) with self.assertRaises(NotImplementedError): np.isfinite(data) # Force call of H5DataIO.__array__ ############################################# # Copy/Link h5py.Dataset object ############################################# def test_link_h5py_dataset_input(self): self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={})) self.io.write_dataset(self.f, DatasetBuilder('test_softlink', self.f['test_dataset'], attributes={})) self.assertTrue(isinstance(self.f.get('test_softlink', getlink=True), SoftLink)) def test_copy_h5py_dataset_input(self): self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={})) self.io.write_dataset(self.f, DatasetBuilder('test_copy', self.f['test_dataset'], attributes={}), link_data=False) self.assertTrue(isinstance(self.f.get('test_copy', getlink=True), HardLink)) self.assertListEqual(self.f['test_dataset'][:].tolist(), self.f['test_copy'][:].tolist()) def test_link_h5py_dataset_h5dataio_input(self): self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={})) self.io.write_dataset(self.f, DatasetBuilder('test_softlink', H5DataIO(data=self.f['test_dataset'], link_data=True), attributes={})) self.assertTrue(isinstance(self.f.get('test_softlink', getlink=True), SoftLink)) def test_copy_h5py_dataset_h5dataio_input(self): self.io.write_dataset(self.f, DatasetBuilder('test_dataset', np.arange(10), attributes={})) self.io.write_dataset(self.f, DatasetBuilder('test_copy', H5DataIO(data=self.f['test_dataset'], link_data=False), # Force dataset copy attributes={})) # Make sure the default behavior is set to link the data self.assertTrue(isinstance(self.f.get('test_copy', getlink=True), HardLink)) self.assertListEqual(self.f['test_dataset'][:].tolist(), self.f['test_copy'][:].tolist()) def test_list_fill_empty(self): dset = self.io.__list_fill__(self.f, 'empty_dataset', [], options={'dtype': int, 'io_settings': {}}) self.assertTupleEqual(dset.shape, (0,)) def test_list_fill_empty_no_dtype(self): with self.assertRaisesRegex(Exception, r"cannot add \S+ to [/\S]+ - could not determine type"): self.io.__list_fill__(self.f, 'empty_dataset', []) def test_read_str(self): a = ['a', 'bb', 'ccc', 'dddd', 'e'] attr = 'foobar' self.io.write_dataset(self.f, DatasetBuilder('test_dataset', a, attributes={'test_attr': attr}, dtype='text')) self.io.close() with HDF5IO(self.path, 'r') as io: bldr = io.read_builder() np.array_equal(bldr['test_dataset'].data[:], ['a', 'bb', 'ccc', 'dddd', 'e']) np.array_equal(bldr['test_dataset'].attributes['test_attr'], attr) if H5PY_3: self.assertEqual(str(bldr['test_dataset'].data), '') else: self.assertEqual(str(bldr['test_dataset'].data), '') def _get_manager(): foo_spec = GroupSpec('A test group specification with a data type', data_type_def='Foo', datasets=[DatasetSpec('an example dataset', 'int', name='my_data', attributes=[AttributeSpec('attr2', 'an example integer attribute', 'int')])], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text'), AttributeSpec('attr3', 'an example float attribute', 'float')]) tmp_spec = GroupSpec('A subgroup for Foos', name='foo_holder', groups=[GroupSpec('the Foos in this bucket', data_type_inc='Foo', 
quantity=ZERO_OR_MANY)]) bucket_spec = GroupSpec('A test group specification for a data type containing data type', data_type_def='FooBucket', groups=[tmp_spec]) class FooMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) my_data_spec = spec.get_dataset('my_data') self.map_spec('attr2', my_data_spec.get_attribute('attr2')) class BucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) foo_holder_spec = spec.get_group('foo_holder') self.unmap(foo_holder_spec) foo_spec = foo_holder_spec.get_data_type('Foo') self.map_spec('foos', foo_spec) file_links_spec = GroupSpec('Foo link group', name='links', links=[LinkSpec('Foo link', name='foo_link', target_type='Foo', quantity=ZERO_OR_ONE)] ) file_spec = GroupSpec("A file of Foos contained in FooBuckets", data_type_def='FooFile', groups=[GroupSpec('Holds the FooBuckets', name='buckets', groups=[GroupSpec("One or more FooBuckets", data_type_inc='FooBucket', quantity=ZERO_OR_MANY)]), file_links_spec], datasets=[DatasetSpec('Foo data', name='foofile_data', dtype='int', quantity=ZERO_OR_ONE)], attributes=[AttributeSpec(doc='Foo ref attr', name='foo_ref_attr', dtype=RefSpec('Foo', 'object'), required=False)], ) class FileMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) bucket_spec = spec.get_group('buckets').get_data_type('FooBucket') self.map_spec('buckets', bucket_spec) self.unmap(spec.get_group('links')) foo_link_spec = spec.get_group('links').get_link('foo_link') self.map_spec('foo_link', foo_link_spec) spec_catalog = SpecCatalog() spec_catalog.register_spec(foo_spec, 'test.yaml') spec_catalog.register_spec(bucket_spec, 'test.yaml') spec_catalog.register_spec(file_spec, 'test.yaml') namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) type_map.register_container_type(CORE_NAMESPACE, 'Foo', Foo) type_map.register_container_type(CORE_NAMESPACE, 'FooBucket', FooBucket) type_map.register_container_type(CORE_NAMESPACE, 'FooFile', FooFile) type_map.register_map(Foo, FooMapper) type_map.register_map(FooBucket, BucketMapper) type_map.register_map(FooFile, FileMapper) manager = BuildManager(type_map) return manager class TestRoundTrip(TestCase): def setUp(self): self.manager = _get_manager() self.path = get_temp_filepath() def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def test_roundtrip_basic(self): # Setup all the data we need foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.path, manager=self.manager, mode='w') as io: io.write(foofile) with HDF5IO(self.path, manager=self.manager, mode='r') as io: read_foofile = io.read() self.assertListEqual(foofile.buckets['bucket1'].foos['foo1'].my_data, read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) def test_roundtrip_empty_dataset(self): foo1 = Foo('foo1', [], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.path, manager=self.manager, mode='w') as io: io.write(foofile) with HDF5IO(self.path, manager=self.manager, mode='r') as io: read_foofile = io.read() self.assertListEqual([], read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) def test_roundtrip_empty_group(self): foobucket = FooBucket('bucket1', []) foofile = FooFile([foobucket]) 
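        # An empty FooBucket should round-trip to an empty 'foos' mapping on read,
        # which the assertDictEqual below verifies.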
with HDF5IO(self.path, manager=self.manager, mode='w') as io: io.write(foofile) with HDF5IO(self.path, manager=self.manager, mode='r') as io: read_foofile = io.read() self.assertDictEqual({}, read_foofile.buckets['bucket1'].foos) def test_roundtrip_pathlib_path(self): pathlib_path = Path(self.path) foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(pathlib_path, manager=self.manager, mode='w') as io: io.write(foofile) with HDF5IO(pathlib_path, manager=self.manager, mode='r') as io: read_foofile = io.read() self.assertListEqual(foofile.buckets['bucket1'].foos['foo1'].my_data, read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) class TestHDF5IO(TestCase): def setUp(self): self.manager = _get_manager() self.path = get_temp_filepath() foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) self.foofile = FooFile([foobucket]) self.file_obj = None def tearDown(self): if os.path.exists(self.path): os.remove(self.path) if self.file_obj is not None: fn = self.file_obj.filename self.file_obj.close() if os.path.exists(fn): os.remove(fn) def test_constructor(self): with HDF5IO(self.path, manager=self.manager, mode='w') as io: self.assertEqual(io.manager, self.manager) self.assertEqual(io.source, self.path) def test_set_file_mismatch(self): self.file_obj = File(get_temp_filepath(), 'w') err_msg = ("You argued %s as this object's path, but supplied a file with filename: %s" % (self.path, self.file_obj.filename)) with self.assertRaisesWith(ValueError, err_msg): HDF5IO(self.path, manager=self.manager, mode='w', file=self.file_obj) def test_pathlib_path(self): pathlib_path = Path(self.path) with HDF5IO(pathlib_path, mode='w') as io: self.assertEqual(io.source, self.path) class TestCacheSpec(TestCase): def setUp(self): self.manager = _get_manager() self.path = get_temp_filepath() def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def test_cache_spec(self): foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) foo2 = Foo('foo2', [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) foobucket = FooBucket('bucket1', [foo1, foo2]) foofile = FooFile([foobucket]) with HDF5IO(self.path, manager=self.manager, mode='w') as io: io.write(foofile) ns_catalog = NamespaceCatalog() HDF5IO.load_namespaces(ns_catalog, self.path) self.assertEqual(ns_catalog.namespaces, (CORE_NAMESPACE,)) source_types = self.__get_types(io.manager.namespace_catalog) read_types = self.__get_types(ns_catalog) self.assertSetEqual(source_types, read_types) def __get_types(self, catalog): types = set() for ns_name in catalog.namespaces: ns = catalog.get_namespace(ns_name) for source in ns['schema']: types.update(catalog.get_types(source['source'])) return types class TestNoCacheSpec(TestCase): def setUp(self): self.manager = _get_manager() self.path = get_temp_filepath() def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def test_no_cache_spec(self): # Setup all the data we need foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) foo2 = Foo('foo2', [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) foobucket = FooBucket('bucket1', [foo1, foo2]) foofile = FooFile([foobucket]) with HDF5IO(self.path, manager=self.manager, mode='w') as io: io.write(foofile, cache_spec=False) with File(self.path, 'r') as f: self.assertNotIn('specifications', f) class TestMultiWrite(TestCase): def setUp(self): self.path = get_temp_filepath() foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 
3.14) foo2 = Foo('foo2', [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) foobucket = FooBucket('bucket1', [foo1, foo2]) self.foofile = FooFile([foobucket]) def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def test_double_write_new_manager(self): """Test writing to a container in write mode twice using a new manager without changing the container.""" with HDF5IO(self.path, manager=_get_manager(), mode='w') as io: io.write(self.foofile) with HDF5IO(self.path, manager=_get_manager(), mode='w') as io: io.write(self.foofile) # check that new bucket was written with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: read_foofile = io.read() self.assertContainerEqual(read_foofile, self.foofile) def test_double_write_same_manager(self): """Test writing to a container in write mode twice using the same manager without changing the container.""" manager = _get_manager() with HDF5IO(self.path, manager=manager, mode='w') as io: io.write(self.foofile) with HDF5IO(self.path, manager=manager, mode='w') as io: io.write(self.foofile) # check that new bucket was written with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: read_foofile = io.read() self.assertContainerEqual(read_foofile, self.foofile) @unittest.skip('Functionality not yet supported') def test_double_append_new_manager(self): """Test writing to a container in append mode twice using a new manager without changing the container.""" with HDF5IO(self.path, manager=_get_manager(), mode='a') as io: io.write(self.foofile) with HDF5IO(self.path, manager=_get_manager(), mode='a') as io: io.write(self.foofile) # check that new bucket was written with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: read_foofile = io.read() self.assertContainerEqual(read_foofile, self.foofile) @unittest.skip('Functionality not yet supported') def test_double_append_same_manager(self): """Test writing to a container in append mode twice using the same manager without changing the container.""" manager = _get_manager() with HDF5IO(self.path, manager=manager, mode='a') as io: io.write(self.foofile) with HDF5IO(self.path, manager=manager, mode='a') as io: io.write(self.foofile) # check that new bucket was written with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: read_foofile = io.read() self.assertContainerEqual(read_foofile, self.foofile) def test_write_add_write(self): """Test writing a container, adding to the in-memory container, then overwriting the same file.""" manager = _get_manager() with HDF5IO(self.path, manager=manager, mode='w') as io: io.write(self.foofile) # append new container to in-memory container foo3 = Foo('foo3', [10, 20], "I am foo3", 2, 0.1) new_bucket1 = FooBucket('new_bucket1', [foo3]) self.foofile.add_bucket(new_bucket1) # write to same file with same manager, overwriting existing file with HDF5IO(self.path, manager=manager, mode='w') as io: io.write(self.foofile) # check that new bucket was written with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: read_foofile = io.read() self.assertEqual(len(read_foofile.buckets), 2) self.assertContainerEqual(read_foofile.buckets['new_bucket1'], new_bucket1) def test_write_add_append_bucket(self): """Test appending a container to a file.""" manager = _get_manager() with HDF5IO(self.path, manager=manager, mode='w') as io: io.write(self.foofile) foo3 = Foo('foo3', [10, 20], "I am foo3", 2, 0.1) new_bucket1 = FooBucket('new_bucket1', [foo3]) # append to same file with same manager, overwriting existing file with HDF5IO(self.path, 
manager=manager, mode='a') as io: read_foofile = io.read() # append to read container and call write read_foofile.add_bucket(new_bucket1) io.write(read_foofile) # check that new bucket was written with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: read_foofile = io.read() self.assertEqual(len(read_foofile.buckets), 2) self.assertContainerEqual(read_foofile.buckets['new_bucket1'], new_bucket1) def test_write_add_append_double_write(self): """Test using the same IO object to append a container to a file twice.""" manager = _get_manager() with HDF5IO(self.path, manager=manager, mode='w') as io: io.write(self.foofile) foo3 = Foo('foo3', [10, 20], "I am foo3", 2, 0.1) new_bucket1 = FooBucket('new_bucket1', [foo3]) foo4 = Foo('foo4', [10, 20], "I am foo4", 2, 0.1) new_bucket2 = FooBucket('new_bucket2', [foo4]) # append to same file with same manager, overwriting existing file with HDF5IO(self.path, manager=manager, mode='a') as io: read_foofile = io.read() # append to read container and call write read_foofile.add_bucket(new_bucket1) io.write(read_foofile) # append to read container again and call write again read_foofile.add_bucket(new_bucket2) io.write(read_foofile) # check that both new buckets were written with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: read_foofile = io.read() self.assertEqual(len(read_foofile.buckets), 3) self.assertContainerEqual(read_foofile.buckets['new_bucket1'], new_bucket1) self.assertContainerEqual(read_foofile.buckets['new_bucket2'], new_bucket2) class HDF5IOMultiFileTest(TestCase): """Tests for h5tools IO tools""" def setUp(self): numfiles = 3 self.paths = [get_temp_filepath() for i in range(numfiles)] # On Windows h5py cannot truncate an open file in write mode. # The temp file will be closed before h5py truncates it # and will be removed during the tearDown step. 
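        # Each temp path gets its own HDF5IO (and underlying h5py.File) below so
        # the external-link test in this class can write to and read across
        # several files that are open at the same time.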
self.io = [HDF5IO(i, mode='a', manager=_get_manager()) for i in self.paths] self.f = [i._file for i in self.io] def tearDown(self): # Close all the files for i in self.io: i.close() del(i) self.io = None self.f = None # Make sure the files have been deleted for tf in self.paths: try: os.remove(tf) except OSError: pass def test_copy_file_with_external_links(self): # Create the first file foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) bucket1 = FooBucket('bucket1', [foo1]) foofile1 = FooFile(buckets=[bucket1]) # Write the first file self.io[0].write(foofile1) # Create the second file read_foofile1 = self.io[0].read() foo2 = Foo('foo2', read_foofile1.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 34, 6.28) bucket2 = FooBucket('bucket2', [foo2]) foofile2 = FooFile(buckets=[bucket2]) # Write the second file self.io[1].write(foofile2) self.io[1].close() self.io[0].close() # Don't forget to close the first file too # Copy the file self.io[2].close() with self.assertWarns(DeprecationWarning): HDF5IO.copy_file(source_filename=self.paths[1], dest_filename=self.paths[2], expand_external=True, expand_soft=False, expand_refs=False) # Test that everything is working as expected # Confirm that our original data file is correct f1 = File(self.paths[0], 'r') self.assertIsInstance(f1.get('/buckets/bucket1/foo_holder/foo1/my_data', getlink=True), HardLink) # Confirm that we successfully created and External Link in our second file f2 = File(self.paths[1], 'r') self.assertIsInstance(f2.get('/buckets/bucket2/foo_holder/foo2/my_data', getlink=True), ExternalLink) # Confirm that we successfully resolved the External Link when we copied our second file f3 = File(self.paths[2], 'r') self.assertIsInstance(f3.get('/buckets/bucket2/foo_holder/foo2/my_data', getlink=True), HardLink) class TestCloseLinks(TestCase): def setUp(self): self.path1 = get_temp_filepath() self.path2 = get_temp_filepath() def tearDown(self): if self.path1 is not None: os.remove(self.path1) # linked file may not be closed if self.path2 is not None: os.remove(self.path2) def test_close_file_with_links(self): # Create the first file foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) bucket1 = FooBucket('bucket1', [foo1]) foofile1 = FooFile(buckets=[bucket1]) # Write the first file with HDF5IO(self.path1, mode='w', manager=_get_manager()) as io: io.write(foofile1) # Create the second file manager = _get_manager() # use the same manager for read and write so that links work with HDF5IO(self.path1, mode='r', manager=manager) as read_io: read_foofile1 = read_io.read() foofile2 = FooFile(foo_link=read_foofile1.buckets['bucket1'].foos['foo1']) # cross-file link # Write the second file with HDF5IO(self.path2, mode='w', manager=manager) as write_io: write_io.write(foofile2) with HDF5IO(self.path2, mode='a', manager=_get_manager()) as new_io1: read_foofile2 = new_io1.read() # keep reference to container in memory self.assertTrue(read_foofile2.foo_link.my_data) new_io1.close_linked_files() self.assertFalse(read_foofile2.foo_link.my_data) # should be able to reopen both files with HDF5IO(self.path1, mode='a', manager=_get_manager()) as new_io3: new_io3.read() def test_double_close_file_with_links(self): # Create the first file foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) bucket1 = FooBucket('bucket1', [foo1]) foofile1 = FooFile(buckets=[bucket1]) # Write the first file with HDF5IO(self.path1, mode='w', manager=_get_manager()) as io: io.write(foofile1) # Create the second file manager = _get_manager() # use the same 
manager for read and write so that links work with HDF5IO(self.path1, mode='r', manager=manager) as read_io: read_foofile1 = read_io.read() foofile2 = FooFile(foo_link=read_foofile1.buckets['bucket1'].foos['foo1']) # cross-file link # Write the second file with HDF5IO(self.path2, mode='w', manager=manager) as write_io: write_io.write(foofile2) with HDF5IO(self.path2, mode='a', manager=_get_manager()) as new_io1: read_foofile2 = new_io1.read() # keep reference to container in memory read_foofile2.foo_link.my_data.file.close() # explicitly close the file from the h5dataset self.assertFalse(read_foofile2.foo_link.my_data) new_io1.close_linked_files() # make sure this does not fail because the linked-to file is already closed class HDF5IOInitNoFileTest(TestCase): """ Test if file does not exist, init with mode (r, r+) throws error, all others succeed """ def test_init_no_file_r(self): self.path = "test_init_nofile_r.h5" with self.assertRaisesWith(UnsupportedOperation, "Unable to open file %s in 'r' mode. File does not exist." % self.path): HDF5IO(self.path, mode='r') def test_init_no_file_rplus(self): self.path = "test_init_nofile_rplus.h5" with self.assertRaisesWith(UnsupportedOperation, "Unable to open file %s in 'r+' mode. File does not exist." % self.path): HDF5IO(self.path, mode='r+') def test_init_no_file_ok(self): # test that no errors are thrown modes = ('w', 'w-', 'x', 'a') for m in modes: self.path = "test_init_nofile.h5" with HDF5IO(self.path, mode=m): pass if os.path.exists(self.path): os.remove(self.path) class HDF5IOInitFileExistsTest(TestCase): """ Test if file exists, init with mode w-/x throws error, all others succeed """ def setUp(self): self.path = get_temp_filepath() temp_io = HDF5IO(self.path, mode='w') temp_io.close() self.io = None def tearDown(self): if self.io is not None: self.io.close() del(self.io) if os.path.exists(self.path): os.remove(self.path) def test_init_wminus_file_exists(self): with self.assertRaisesWith(UnsupportedOperation, "Unable to open file %s in 'w-' mode. File already exists." % self.path): self.io = HDF5IO(self.path, mode='w-') def test_init_x_file_exists(self): with self.assertRaisesWith(UnsupportedOperation, "Unable to open file %s in 'x' mode. File already exists." % self.path): self.io = HDF5IO(self.path, mode='x') def test_init_file_exists_ok(self): # test that no errors are thrown modes = ('r', 'r+', 'w', 'a') for m in modes: with HDF5IO(self.path, mode=m): pass class HDF5IOReadNoDataTest(TestCase): """ Test if file exists and there is no data, read with mode (r, r+, a) throws error """ def setUp(self): self.path = get_temp_filepath() temp_io = HDF5IO(self.path, mode='w') temp_io.close() self.io = None def tearDown(self): if self.io is not None: self.io.close() del(self.io) if os.path.exists(self.path): os.remove(self.path) def test_read_no_data_r(self): self.io = HDF5IO(self.path, mode='r') with self.assertRaisesWith(UnsupportedOperation, "Cannot read data from file %s in mode 'r'. There are no values." % self.path): self.io.read() def test_read_no_data_rplus(self): self.io = HDF5IO(self.path, mode='r+') with self.assertRaisesWith(UnsupportedOperation, "Cannot read data from file %s in mode 'r+'. There are no values." % self.path): self.io.read() def test_read_no_data_a(self): self.io = HDF5IO(self.path, mode='a') with self.assertRaisesWith(UnsupportedOperation, "Cannot read data from file %s in mode 'a'. There are no values." 
% self.path): self.io.read() class HDF5IOReadData(TestCase): """ Test if file exists and there is no data, read in mode (r, r+, a) is ok and read in mode w throws error """ def setUp(self): self.path = get_temp_filepath() foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) bucket1 = FooBucket('bucket1', [foo1]) self.foofile1 = FooFile(buckets=[bucket1]) with HDF5IO(self.path, manager=_get_manager(), mode='w') as temp_io: temp_io.write(self.foofile1) self.io = None def tearDown(self): if self.io is not None: self.io.close() del(self.io) if os.path.exists(self.path): os.remove(self.path) def test_read_file_ok(self): modes = ('r', 'r+', 'a') for m in modes: with HDF5IO(self.path, manager=_get_manager(), mode=m) as io: io.read() def test_read_file_w(self): with HDF5IO(self.path, manager=_get_manager(), mode='w') as io: with self.assertRaisesWith(UnsupportedOperation, "Cannot read from file %s in mode 'w'. Please use mode 'r', 'r+', or 'a'." % self.path): read_foofile1 = io.read() self.assertListEqual(self.foofile1.buckets['bucket1'].foos['foo1'].my_data, read_foofile1.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) class HDF5IOReadBuilderClosed(TestCase): """Test if file exists but is closed, then read_builder raises an error. """ def setUp(self): self.path = get_temp_filepath() temp_io = HDF5IO(self.path, mode='w') temp_io.close() self.io = None def tearDown(self): if self.io is not None: self.io.close() del(self.io) if os.path.exists(self.path): os.remove(self.path) def test_read_closed(self): self.io = HDF5IO(self.path, mode='r') self.io.close() msg = "Cannot read data from closed HDF5 file '%s'" % self.path with self.assertRaisesWith(UnsupportedOperation, msg): self.io.read_builder() class HDF5IOWriteNoFile(TestCase): """ Test if file does not exist, write in mode (w, w-, x, a) is ok """ def setUp(self): foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) bucket1 = FooBucket('bucket1', [foo1]) self.foofile1 = FooFile(buckets=[bucket1]) self.path = 'test_write_nofile.h5' def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def test_write_no_file_w_ok(self): self.__write_file('w') def test_write_no_file_wminus_ok(self): self.__write_file('w-') def test_write_no_file_x_ok(self): self.__write_file('x') def test_write_no_file_a_ok(self): self.__write_file('a') def __write_file(self, mode): with HDF5IO(self.path, manager=_get_manager(), mode=mode) as io: io.write(self.foofile1) with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: read_foofile = io.read() self.assertListEqual(self.foofile1.buckets['bucket1'].foos['foo1'].my_data, read_foofile.buckets['bucket1'].foos['foo1'].my_data[:].tolist()) class HDF5IOWriteFileExists(TestCase): """ Test if file exists, write in mode (r+, w, a) is ok and write in mode r throws error """ def setUp(self): self.path = get_temp_filepath() foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) bucket1 = FooBucket('bucket1', [foo1]) self.foofile1 = FooFile(buckets=[bucket1]) foo2 = Foo('foo2', [0, 1, 2, 3, 4], "I am foo2", 17, 3.14) bucket2 = FooBucket('bucket2', [foo2]) self.foofile2 = FooFile(buckets=[bucket2]) with HDF5IO(self.path, manager=_get_manager(), mode='w') as io: io.write(self.foofile1) self.io = None def tearDown(self): if self.io is not None: self.io.close() del(self.io) if os.path.exists(self.path): os.remove(self.path) def test_write_rplus(self): with HDF5IO(self.path, manager=_get_manager(), mode='r+') as io: # even though foofile1 and foofile2 have different names, writing a # root object into a 
file that already has a root object, in r+ mode # should throw an error with self.assertRaisesWith(ValueError, "Unable to create group (name already exists)"): io.write(self.foofile2) def test_write_a(self): with HDF5IO(self.path, manager=_get_manager(), mode='a') as io: # even though foofile1 and foofile2 have different names, writing a # root object into a file that already has a root object, in a mode # should throw an error with self.assertRaisesWith(ValueError, "Unable to create group (name already exists)"): io.write(self.foofile2) def test_write_w(self): # mode 'w' should overwrite contents of file with HDF5IO(self.path, manager=_get_manager(), mode='w') as io: io.write(self.foofile2) with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: read_foofile = io.read() self.assertListEqual(self.foofile2.buckets['bucket2'].foos['foo2'].my_data, read_foofile.buckets['bucket2'].foos['foo2'].my_data[:].tolist()) def test_write_r(self): with HDF5IO(self.path, manager=_get_manager(), mode='r') as io: with self.assertRaisesWith(UnsupportedOperation, ("Cannot write to file %s in mode 'r'. " "Please use mode 'r+', 'w', 'w-', 'x', or 'a'") % self.path): io.write(self.foofile2) class TestWritten(TestCase): def setUp(self): self.manager = _get_manager() self.path = get_temp_filepath() foo1 = Foo('foo1', [0, 1, 2, 3, 4], "I am foo1", 17, 3.14) foo2 = Foo('foo2', [5, 6, 7, 8, 9], "I am foo2", 34, 6.28) foobucket = FooBucket('bucket1', [foo1, foo2]) self.foofile = FooFile([foobucket]) def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def test_set_written_on_write(self): """Test that write_builder changes the written flag of the builder and its children from False to True.""" with HDF5IO(self.path, manager=self.manager, mode='w') as io: builder = self.manager.build(container=self.foofile, source=self.path) self.assertFalse(io.get_written(builder)) self._check_written_children(io, builder, False) io.write_builder(builder) self.assertTrue(io.get_written(builder)) self._check_written_children(io, builder, True) def _check_written_children(self, io, builder, val): """Test whether the io object has the written flag of the child builders set to val.""" for group_bldr in builder.groups.values(): self.assertEqual(io.get_written(group_bldr), val) self._check_written_children(io, group_bldr, val) for dset_bldr in builder.datasets.values(): self.assertEqual(io.get_written(dset_bldr), val) for link_bldr in builder.links.values(): self.assertEqual(io.get_written(link_bldr), val) class H5DataIOValid(TestCase): def setUp(self): self.paths = [get_temp_filepath(), ] self.foo1 = Foo('foo1', H5DataIO([1, 2, 3, 4, 5]), "I am foo1", 17, 3.14) bucket1 = FooBucket('bucket1', [self.foo1]) foofile1 = FooFile(buckets=[bucket1]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as io: io.write(foofile1) def tearDown(self): for path in self.paths: if os.path.exists(path): os.remove(path) def test_valid(self): self.assertTrue(self.foo1.my_data.valid) def test_read_valid(self): """Test that h5py.H5Dataset.id.valid works as expected""" with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as io: read_foofile1 = io.read() self.assertTrue(read_foofile1.buckets['bucket1'].foos['foo1'].my_data.id.valid) self.assertFalse(read_foofile1.buckets['bucket1'].foos['foo1'].my_data.id.valid) def test_link(self): """Test that wrapping of linked data within H5DataIO """ with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as io: read_foofile1 = io.read() self.foo2 = Foo('foo2', 
H5DataIO(data=read_foofile1.buckets['bucket1'].foos['foo1'].my_data), "I am foo2", 17, 3.14) bucket2 = FooBucket('bucket2', [self.foo2]) foofile2 = FooFile(buckets=[bucket2]) self.paths.append(get_temp_filepath()) with HDF5IO(self.paths[1], manager=_get_manager(), mode='w') as io: io.write(foofile2) self.assertTrue(self.foo2.my_data.valid) # test valid self.assertEqual(len(self.foo2.my_data), 5) # test len self.assertEqual(self.foo2.my_data.shape, (5,)) # test getattr with shape self.assertTrue(np.array_equal(np.array(self.foo2.my_data), [1, 2, 3, 4, 5])) # test array conversion # test loop through iterable match = [1, 2, 3, 4, 5] for (i, j) in zip(self.foo2.my_data, match): self.assertEqual(i, j) # test iterator my_iter = iter(self.foo2.my_data) self.assertEqual(next(my_iter), 1) # foo2.my_data dataset is now closed self.assertFalse(self.foo2.my_data.valid) with self.assertRaisesWith(InvalidDataIOError, "Cannot get length of data. Data is not valid."): len(self.foo2.my_data) with self.assertRaisesWith(InvalidDataIOError, "Cannot get attribute 'shape' of data. Data is not valid."): self.foo2.my_data.shape with self.assertRaisesWith(InvalidDataIOError, "Cannot convert data to array. Data is not valid."): np.array(self.foo2.my_data) with self.assertRaisesWith(InvalidDataIOError, "Cannot iterate on data. Data is not valid."): for i in self.foo2.my_data: pass with self.assertRaisesWith(InvalidDataIOError, "Cannot iterate on data. Data is not valid."): iter(self.foo2.my_data) # re-open the file with the data linking to other file (still closed) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as io: read_foofile2 = io.read() read_foo2 = read_foofile2.buckets['bucket2'].foos['foo2'] # note that read_foo2 dataset does not have an attribute 'valid' self.assertEqual(len(read_foo2.my_data), 5) # test len self.assertEqual(read_foo2.my_data.shape, (5,)) # test getattr with shape self.assertTrue(np.array_equal(np.array(read_foo2.my_data), [1, 2, 3, 4, 5])) # test array conversion # test loop through iterable match = [1, 2, 3, 4, 5] for (i, j) in zip(read_foo2.my_data, match): self.assertEqual(i, j) # test iterator my_iter = iter(read_foo2.my_data) self.assertEqual(next(my_iter), 1) class TestReadLink(TestCase): def setUp(self): self.target_path = get_temp_filepath() self.link_path = get_temp_filepath() root1 = GroupBuilder(name='root') subgroup = GroupBuilder(name='test_group') root1.set_group(subgroup) dataset = DatasetBuilder('test_dataset', data=[1, 2, 3, 4]) subgroup.set_dataset(dataset) root2 = GroupBuilder(name='root') link_group = LinkBuilder(subgroup, 'link_to_test_group') root2.set_link(link_group) link_dataset = LinkBuilder(dataset, 'link_to_test_dataset') root2.set_link(link_dataset) with HDF5IO(self.target_path, manager=_get_manager(), mode='w') as io: io.write_builder(root1) root1.source = self.target_path with HDF5IO(self.link_path, manager=_get_manager(), mode='w') as io: io.write_builder(root2) root2.source = self.link_path self.ios = [] def tearDown(self): for io in self.ios: io.close_linked_files() if os.path.exists(self.target_path): os.remove(self.target_path) if os.path.exists(self.link_path): os.remove(self.link_path) def test_set_link_loc(self): """ Test that Builder location is set when it is read as a link """ read_io = HDF5IO(self.link_path, manager=_get_manager(), mode='r') self.ios.append(read_io) # store IO object for closing in tearDown bldr = read_io.read_builder() self.assertEqual(bldr['link_to_test_group'].builder.location, '/') 
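        # builder.location records the in-file path of the linked builder's parent:
        # 'test_group' sits directly under the root ('/'), while 'test_dataset'
        # sits under '/test_group', as the next assertion checks.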
self.assertEqual(bldr['link_to_test_dataset'].builder.location, '/test_group') read_io.close() def test_link_to_link(self): """ Test that link to link gets written and read properly """ link_to_link_path = get_temp_filepath() read_io1 = HDF5IO(self.link_path, manager=_get_manager(), mode='r') self.ios.append(read_io1) # store IO object for closing in tearDown bldr1 = read_io1.read_builder() root3 = GroupBuilder(name='root') link = LinkBuilder(bldr1['link_to_test_group'].builder, 'link_to_link') root3.set_link(link) with HDF5IO(link_to_link_path, manager=_get_manager(), mode='w') as io: io.write_builder(root3) read_io1.close() read_io2 = HDF5IO(link_to_link_path, manager=_get_manager(), mode='r') self.ios.append(read_io2) bldr2 = read_io2.read_builder() self.assertEqual(bldr2['link_to_link'].builder.source, self.target_path) read_io2.close() def test_broken_link(self): """Test that opening a file with a broken link raises a warning but is still readable.""" os.remove(self.target_path) # with self.assertWarnsWith(BrokenLinkWarning, '/link_to_test_dataset'): # can't check both warnings with self.assertWarnsWith(BrokenLinkWarning, '/link_to_test_group'): with HDF5IO(self.link_path, manager=_get_manager(), mode='r') as read_io: bldr = read_io.read_builder() self.assertDictEqual(bldr.links, {}) def test_broken_linked_data(self): """Test that opening a file with a broken link raises a warning but is still readable.""" manager = _get_manager() with HDF5IO(self.target_path, manager=manager, mode='r') as read_io: read_root = read_io.read_builder() read_dataset_data = read_root.groups['test_group'].datasets['test_dataset'].data with HDF5IO(self.link_path, manager=manager, mode='w') as write_io: root2 = GroupBuilder(name='root') dataset = DatasetBuilder(name='link_to_test_dataset', data=read_dataset_data) root2.set_dataset(dataset) write_io.write_builder(root2, link_data=True) os.remove(self.target_path) with self.assertWarnsWith(BrokenLinkWarning, '/link_to_test_dataset'): with HDF5IO(self.link_path, manager=_get_manager(), mode='r') as read_io: bldr = read_io.read_builder() self.assertDictEqual(bldr.links, {}) class TestBuildWriteLinkToLink(TestCase): def setUp(self): self.paths = [ get_temp_filepath(), get_temp_filepath(), get_temp_filepath() ] self.ios = [] def tearDown(self): for io in self.ios: io.close_linked_files() for p in self.paths: if os.path.exists(p): os.remove(p) def test_external_link_to_external_link(self): """Test writing a file with external links to external links.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) manager = _get_manager() with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: read_foofile = read_io.read() # make external link to existing group foofile2 = FooFile(foo_link=read_foofile.buckets['bucket1'].foos['foo1']) with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: write_io.write(foofile2) manager = _get_manager() with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() foofile3 = FooFile(foo_link=read_foofile2.foo_link) # make external link to external link with HDF5IO(self.paths[2], manager=manager, mode='w') as write_io: write_io.write(foofile3) with HDF5IO(self.paths[2], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects 
for tearDown read_foofile3 = read_io.read() self.assertEqual(read_foofile3.foo_link.container_source, self.paths[0]) def test_external_link_to_soft_link(self): """Test writing a file with external links to soft links.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket], foo_link=foo1) # create soft link with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) manager = _get_manager() with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: read_foofile = read_io.read() foofile2 = FooFile(foo_link=read_foofile.foo_link) # make external link to existing soft link with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: write_io.write(foofile2) manager = _get_manager() with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() foofile3 = FooFile(foo_link=read_foofile2.foo_link) # make external link to external link with HDF5IO(self.paths[2], manager=manager, mode='w') as write_io: write_io.write(foofile3) with HDF5IO(self.paths[2], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile3 = read_io.read() self.assertEqual(read_foofile3.foo_link.container_source, self.paths[0]) class TestLinkData(TestCase): def setUp(self): self.target_path = get_temp_filepath() self.link_path = get_temp_filepath() root1 = GroupBuilder(name='root') subgroup = GroupBuilder(name='test_group') root1.set_group(subgroup) dataset = DatasetBuilder('test_dataset', data=[1, 2, 3, 4]) subgroup.set_dataset(dataset) with HDF5IO(self.target_path, manager=_get_manager(), mode='w') as io: io.write_builder(root1) def tearDown(self): if os.path.exists(self.target_path): os.remove(self.target_path) if os.path.exists(self.link_path): os.remove(self.link_path) def test_link_data_true(self): """Test that the argument link_data=True for write_builder creates an external link.""" manager = _get_manager() with HDF5IO(self.target_path, manager=manager, mode='r') as read_io: read_root = read_io.read_builder() read_dataset_data = read_root.groups['test_group'].datasets['test_dataset'].data with HDF5IO(self.link_path, manager=manager, mode='w') as write_io: root2 = GroupBuilder(name='root') dataset = DatasetBuilder(name='link_to_test_dataset', data=read_dataset_data) root2.set_dataset(dataset) write_io.write_builder(root2, link_data=True) with File(self.link_path, mode='r') as f: self.assertIsInstance(f.get('link_to_test_dataset', getlink=True), ExternalLink) def test_link_data_false(self): """Test that the argument link_data=False for write_builder copies the data.""" manager = _get_manager() with HDF5IO(self.target_path, manager=manager, mode='r') as read_io: read_root = read_io.read_builder() read_dataset_data = read_root.groups['test_group'].datasets['test_dataset'].data with HDF5IO(self.link_path, manager=manager, mode='w') as write_io: root2 = GroupBuilder(name='root') dataset = DatasetBuilder(name='link_to_test_dataset', data=read_dataset_data) root2.set_dataset(dataset) write_io.write_builder(root2, link_data=False) with File(self.link_path, mode='r') as f: self.assertFalse(isinstance(f.get('link_to_test_dataset', getlink=True), ExternalLink)) self.assertListEqual(f.get('link_to_test_dataset')[:].tolist(), [1, 2, 3, 4]) class TestLoadNamespaces(TestCase): def setUp(self): self.manager = _get_manager() self.path = get_temp_filepath()
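        # The tests in this class exercise HDF5IO.load_namespaces, which reads the namespace
        # specifications cached under /specifications of the file written by this setUp and
        # returns a dict mapping each namespace name to its dependencies. A minimal sketch of
        # the typical call pattern (the helper name '_load_cached_namespaces' is hypothetical;
        # NamespaceCatalog and HDF5IO.load_namespaces are used exactly as in the tests below):
        def _load_cached_namespaces(path):
            ns_catalog = NamespaceCatalog()
            # e.g. returns {'test_core': {}} when the cached 'test_core' namespace has no dependencies
            return HDF5IO.load_namespaces(ns_catalog, path)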
container = FooFile() with HDF5IO(self.path, manager=self.manager, mode='w') as io: io.write(container) def tearDown(self): if os.path.exists(self.path): os.remove(self.path) def test_load_namespaces_none_version(self): """Test that reading a file with a cached namespace and None version works but raises a warning.""" # make the file have group name "None" instead of "0.1.0" (namespace version is used as group name) # and set the version key to "None" with h5py.File(self.path, mode='r+') as f: # rename the group f.move('/specifications/test_core/0.1.0', '/specifications/test_core/None') # replace the namespace dataset with a serialized dict with the version key set to 'None' new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core",' '"version":"None"}]}') f['/specifications/test_core/None/namespace'][()] = new_ns # load the namespace from file ns_catalog = NamespaceCatalog() msg = "Loaded namespace '%s' is unversioned. Please notify the extension author." % CORE_NAMESPACE with self.assertWarnsWith(UserWarning, msg): HDF5IO.load_namespaces(ns_catalog, self.path) def test_load_namespaces_unversioned(self): """Test that reading a file with a cached, unversioned version works but raises a warning.""" # make the file have group name "unversioned" instead of "0.1.0" (namespace version is used as group name) # and remove the version key with h5py.File(self.path, mode='r+') as f: # rename the group f.move('/specifications/test_core/0.1.0', '/specifications/test_core/unversioned') # replace the namespace dataset with a serialized dict without the version key new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core"}]}') f['/specifications/test_core/unversioned/namespace'][()] = new_ns # load the namespace from file ns_catalog = NamespaceCatalog() msg = ("Loaded namespace '%s' is missing the required key 'version'. Version will be set to " "'%s'. Please notify the extension author." % (CORE_NAMESPACE, SpecNamespace.UNVERSIONED)) with self.assertWarnsWith(UserWarning, msg): HDF5IO.load_namespaces(ns_catalog, self.path) def test_load_namespaces_path(self): """Test that loading namespaces given a path is OK and returns the correct dictionary.""" ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, self.path) self.assertEqual(d, {'test_core': {}}) # test_core has no dependencies def test_load_namespaces_no_path_no_file(self): """Test that loading namespaces without a path or file raises an error.""" ns_catalog = NamespaceCatalog() msg = "Either the 'path' or 'file' argument must be supplied." with self.assertRaisesWith(ValueError, msg): HDF5IO.load_namespaces(ns_catalog) def test_load_namespaces_file_no_path(self): """ Test that loading namespaces from an h5py.File not backed by a file on disk is OK and does not close the file. 
""" with open(self.path, 'rb') as raw_file: buffer = BytesIO(raw_file.read()) file_obj = h5py.File(buffer, 'r') ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, file=file_obj) self.assertTrue(file_obj.__bool__()) # check file object is still open self.assertEqual(d, {'test_core': {}}) file_obj.close() def test_load_namespaces_file_path_matched(self): """Test that loading namespaces given an h5py.File and path is OK and does not close the file.""" with h5py.File(self.path, 'r') as file_obj: ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, path=self.path, file=file_obj) self.assertTrue(file_obj.__bool__()) # check file object is still open self.assertEqual(d, {'test_core': {}}) def test_load_namespaces_file_path_mismatched(self): """Test that loading namespaces given an h5py.File and path that are mismatched raises an error.""" with h5py.File(self.path, 'r') as file_obj: ns_catalog = NamespaceCatalog() msg = "You argued 'different_path' as this object's path, but supplied a file with filename: %s" % self.path with self.assertRaisesWith(ValueError, msg): HDF5IO.load_namespaces(ns_catalog, path='different_path', file=file_obj) def test_load_namespaces_with_pathlib_path(self): """Test that loading a namespace using a valid pathlib Path is OK and returns the correct dictionary.""" pathlib_path = Path(self.path) ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, pathlib_path) self.assertEqual(d, {'test_core': {}}) # test_core has no dependencies def test_load_namespaces_with_dependencies(self): """Test loading namespaces where one includes another.""" class MyFoo(Container): pass myfoo_spec = GroupSpec(doc="A MyFoo", data_type_def='MyFoo', data_type_inc='Foo') spec_catalog = SpecCatalog() name = 'test_core2' namespace = SpecNamespace( doc='a test namespace', name=name, schema=[{'source': 'test2.yaml', 'namespace': 'test_core'}], # depends on test_core version='0.1.0', catalog=spec_catalog ) spec_catalog.register_spec(myfoo_spec, 'test2.yaml') namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(name, namespace) type_map = TypeMap(namespace_catalog) type_map.register_container_type(name, 'MyFoo', MyFoo) type_map.merge(self.manager.type_map, ns_catalog=True) manager = BuildManager(type_map) container = MyFoo(name='myfoo') with HDF5IO(self.path, manager=manager, mode='a') as io: # append to file io.write(container) ns_catalog = NamespaceCatalog() d = HDF5IO.load_namespaces(ns_catalog, self.path) self.assertEqual(d, {'test_core': {}, 'test_core2': {'test_core': ('Foo', 'FooBucket', 'FooFile')}}) def test_load_namespaces_no_specloc(self): """Test loading namespaces where the file does not contain a SPEC_LOC_ATTR.""" # delete the spec location attribute from the file with h5py.File(self.path, mode='r+') as f: del f.attrs[SPEC_LOC_ATTR] # load the namespace from file ns_catalog = NamespaceCatalog() msg = "No cached namespaces found in %s" % self.path with self.assertWarnsWith(UserWarning, msg): ret = HDF5IO.load_namespaces(ns_catalog, self.path) self.assertDictEqual(ret, {}) def test_load_namespaces_resolve_custom_deps(self): """Test that reading a file with a cached namespace and different def/inc keys works.""" # Setup all the data we need foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.path, manager=self.manager, mode='w') as io: io.write(foofile) with h5py.File(self.path, mode='r+') as f: # add two types where one 
extends the other and overrides an attribute # check that the inherited attribute resolves correctly despite having a different def/inc key than those # used in the namespace catalog added_types = (',{"data_type_def":"BigFoo","data_type_inc":"Foo","doc":"doc","attributes":[' '{"name":"my_attr","dtype":"text","doc":"an attr"}]},' '{"data_type_def":"BiggerFoo","data_type_inc":"BigFoo","doc":"doc"}]}') old_test_source = f['/specifications/test_core/0.1.0/test'] # strip the ]} from end, then add to groups if H5PY_3: # string datasets are returned as bytes old_test_source[()] = old_test_source[()][0:-2].decode('utf-8') + added_types else: old_test_source[()] = old_test_source[()][0:-2] + added_types new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[' '{"namespace":"test_core","my_data_types":["Foo"]},' '{"source":"test-ext.extensions"}' '],"name":"test-ext","version":"0.1.0"}]}') f.create_dataset('/specifications/test-ext/0.1.0/namespace', data=new_ns) new_ext = '{"groups":[{"my_data_type_def":"FooExt","my_data_type_inc":"Foo","doc":"doc"}]}' f.create_dataset('/specifications/test-ext/0.1.0/test-ext.extensions', data=new_ext) # load the namespace from file ns_catalog = NamespaceCatalog(CustomGroupSpec, CustomDatasetSpec, CustomSpecNamespace) namespace_deps = HDF5IO.load_namespaces(ns_catalog, self.path) # test that the dependencies are correct expected = ('Foo',) self.assertTupleEqual((namespace_deps['test-ext']['test_core']), expected) # test that the types are loaded types = ns_catalog.get_types('test-ext.extensions') expected = ('FooExt',) self.assertTupleEqual(types, expected) # test that the def_key is updated for test-ext ns foo_ext_spec = ns_catalog.get_spec('test-ext', 'FooExt') self.assertTrue('my_data_type_def' in foo_ext_spec) self.assertTrue('my_data_type_inc' in foo_ext_spec) # test that the data_type_def is replaced with my_data_type_def for test_core ns bigger_foo_spec = ns_catalog.get_spec('test_core', 'BiggerFoo') self.assertTrue('my_data_type_def' in bigger_foo_spec) self.assertTrue('my_data_type_inc' in bigger_foo_spec) # test that my_attr is properly inherited in BiggerFoo from BigFoo and attr1, attr3 are inherited from Foo self.assertTrue(len(bigger_foo_spec.attributes) == 3) class TestGetNamespaces(TestCase): def create_test_namespace(self, name, version): file_spec = GroupSpec(doc="A FooFile", data_type_def='FooFile') spec_catalog = SpecCatalog() namespace = SpecNamespace( doc='a test namespace', name=name, schema=[{'source': 'test.yaml'}], version=version, catalog=spec_catalog ) spec_catalog.register_spec(file_spec, 'test.yaml') return namespace def write_test_file(self, name, version, mode): namespace = self.create_test_namespace(name, version) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(name, namespace) type_map = TypeMap(namespace_catalog) type_map.register_container_type(name, 'FooFile', FooFile) manager = BuildManager(type_map) with HDF5IO(self.path, manager=manager, mode=mode) as io: io.write(self.container) def setUp(self): self.path = get_temp_filepath() self.container = FooFile() def tearDown(self): if os.path.exists(self.path): os.remove(self.path) # see other tests for path & file match/mismatch testing in TestLoadNamespaces def test_get_namespaces_with_path(self): """Test getting namespaces given a path.""" self.write_test_file('test_core', '0.1.0', 'w') ret = HDF5IO.get_namespaces(path=self.path) self.assertEqual(ret, {'test_core': '0.1.0'}) def test_get_namespaces_with_file(self): """Test getting namespaces given a 
file object.""" self.write_test_file('test_core', '0.1.0', 'w') with File(self.path, 'r') as f: ret = HDF5IO.get_namespaces(file=f) self.assertEqual(ret, {'test_core': '0.1.0'}) self.assertTrue(f.__bool__()) # check file object is still open def test_get_namespaces_different_versions(self): """Test getting namespaces with multiple versions given a path.""" # write file with spec with smaller version string self.write_test_file('test_core', '0.0.10', 'w') # append to file with spec with larger version string self.write_test_file('test_core', '0.1.0', 'a') ret = HDF5IO.get_namespaces(path=self.path) self.assertEqual(ret, {'test_core': '0.1.0'}) def test_get_namespaces_multiple_namespaces(self): """Test getting multiple namespaces given a path.""" self.write_test_file('test_core1', '0.0.10', 'w') self.write_test_file('test_core2', '0.1.0', 'a') ret = HDF5IO.get_namespaces(path=self.path) self.assertEqual(ret, {'test_core1': '0.0.10', 'test_core2': '0.1.0'}) def test_get_namespaces_none_version(self): """Test getting namespaces where file has one None-versioned namespace.""" self.write_test_file('test_core', '0.1.0', 'w') # make the file have group name "None" instead of "0.1.0" (namespace version is used as group name) # and set the version key to "None" with h5py.File(self.path, mode='r+') as f: # rename the group f.move('/specifications/test_core/0.1.0', '/specifications/test_core/None') # replace the namespace dataset with a serialized dict with the version key set to 'None' new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core",' '"version":"None"}]}') f['/specifications/test_core/None/namespace'][()] = new_ns ret = HDF5IO.get_namespaces(path=self.path) self.assertEqual(ret, {'test_core': 'None'}) def test_get_namespaces_none_and_other_version(self): """Test getting namespaces file has a namespace with a normal version and an 'None" version.""" self.write_test_file('test_core', '0.1.0', 'w') # make the file have group name "None" instead of "0.1.0" (namespace version is used as group name) # and set the version key to "None" with h5py.File(self.path, mode='r+') as f: # rename the group f.move('/specifications/test_core/0.1.0', '/specifications/test_core/None') # replace the namespace dataset with a serialized dict with the version key set to 'None' new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core",' '"version":"None"}]}') f['/specifications/test_core/None/namespace'][()] = new_ns # append to file with spec with a larger version string self.write_test_file('test_core', '0.2.0', 'a') ret = HDF5IO.get_namespaces(path=self.path) self.assertEqual(ret, {'test_core': '0.2.0'}) def test_get_namespaces_unversioned(self): """Test getting namespaces where file has one unversioned namespace.""" self.write_test_file('test_core', '0.1.0', 'w') # make the file have group name "unversioned" instead of "0.1.0" (namespace version is used as group name) with h5py.File(self.path, mode='r+') as f: # rename the group f.move('/specifications/test_core/0.1.0', '/specifications/test_core/unversioned') # replace the namespace dataset with a serialized dict without the version key new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core"}]}') f['/specifications/test_core/unversioned/namespace'][()] = new_ns ret = HDF5IO.get_namespaces(path=self.path) self.assertEqual(ret, {'test_core': 'unversioned'}) def test_get_namespaces_unversioned_and_other(self): """Test getting namespaces 
file has a namespace with a normal version and an 'unversioned" version.""" self.write_test_file('test_core', '0.1.0', 'w') # make the file have group name "unversioned" instead of "0.1.0" (namespace version is used as group name) with h5py.File(self.path, mode='r+') as f: # rename the group f.move('/specifications/test_core/0.1.0', '/specifications/test_core/unversioned') # replace the namespace dataset with a serialized dict without the version key new_ns = ('{"namespaces":[{"doc":"a test namespace","schema":[{"source":"test"}],"name":"test_core"}]}') f['/specifications/test_core/unversioned/namespace'][()] = new_ns # append to file with spec with a larger version string self.write_test_file('test_core', '0.2.0', 'a') ret = HDF5IO.get_namespaces(path=self.path) self.assertEqual(ret, {'test_core': '0.2.0'}) def test_get_namespaces_no_specloc(self): """Test getting namespaces where the file does not contain a SPEC_LOC_ATTR.""" self.write_test_file('test_core', '0.1.0', 'w') # delete the spec location attribute from the file with h5py.File(self.path, mode='r+') as f: del f.attrs[SPEC_LOC_ATTR] # load the namespace from file msg = "No cached namespaces found in %s" % self.path with self.assertWarnsWith(UserWarning, msg): ret = HDF5IO.get_namespaces(path=self.path) self.assertDictEqual(ret, {}) class TestExport(TestCase): """Test exporting HDF5 to HDF5 using HDF5IO.export_container_to_hdf5.""" def setUp(self): self.paths = [ get_temp_filepath(), get_temp_filepath(), get_temp_filepath(), get_temp_filepath(), ] self.ios = [] def tearDown(self): for io in self.ios: io.close_linked_files() for p in self.paths: if os.path.exists(p): os.remove(p) def test_basic(self): """Test that exporting a written container works.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export(src_io=read_io) self.assertTrue(os.path.exists(self.paths[1])) self.assertEqual(foofile.container_source, self.paths[0]) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: read_foofile = read_io.read() self.assertEqual(read_foofile.container_source, self.paths[1]) self.assertContainerEqual(foofile, read_foofile, ignore_hdmf_attrs=True) self.assertEqual(os.path.abspath(read_foofile.buckets['bucket1'].foos['foo1'].my_data.file.filename), self.paths[1]) def test_basic_container(self): """Test that exporting a written container, passing in the container arg, works.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: read_foofile = read_io.read() with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export(src_io=read_io, container=read_foofile) self.assertTrue(os.path.exists(self.paths[1])) self.assertEqual(foofile.container_source, self.paths[0]) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: read_foofile = read_io.read() self.assertEqual(read_foofile.container_source, self.paths[1]) self.assertContainerEqual(foofile, read_foofile, ignore_hdmf_attrs=True) def test_container_part(self): """Test that exporting a part of a 
written container raises an error.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: read_foofile = read_io.read() with HDF5IO(self.paths[1], mode='w') as export_io: msg = ("The provided container must be the root of the hierarchy of the source used to read the " "container.") with self.assertRaisesWith(ValueError, msg): export_io.export(src_io=read_io, container=read_foofile.buckets['bucket1']) def test_container_unknown(self): """Test that exporting a container that did not come from the src_io object raises an error.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: with HDF5IO(self.paths[1], mode='w') as export_io: dummy_file = FooFile([]) msg = "The provided container must have been read by the provided src_io." with self.assertRaisesWith(ValueError, msg): export_io.export(src_io=read_io, container=dummy_file) def test_cache_spec(self): """Test that exporting with cache_spec works.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: read_foofile = read_io.read() with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export( src_io=read_io, container=read_foofile, cache_spec=False, ) with File(self.paths[1], 'r') as f: self.assertNotIn('specifications', f) def test_soft_link_group(self): """Test that exporting a written file with soft linked groups keeps links within the file.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket], foo_link=foo1) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export(src_io=read_io) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() # make sure the linked group is within the same file self.assertEqual(read_foofile2.foo_link.container_source, self.paths[1]) def test_soft_link_dataset(self): """Test that exporting a written file with soft linked datasets keeps links within the file.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket], foofile_data=foo1.my_data) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export(src_io=read_io) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() # make sure the linked dataset 
is within the same file self.assertEqual(read_foofile2.foofile_data.file.filename, self.paths[1]) def test_external_link_group(self): """Test that exporting a written file with external linked groups maintains the links.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as read_io: read_io.write(foofile) manager = _get_manager() with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: read_foofile = read_io.read() # make external link to existing group foofile2 = FooFile(foo_link=read_foofile.buckets['bucket1'].foos['foo1']) with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: write_io.write(foofile2) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() with HDF5IO(self.paths[2], mode='w') as export_io: export_io.export(src_io=read_io) with HDF5IO(self.paths[2], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() # make sure the linked group is read from the first file self.assertEqual(read_foofile2.foo_link.container_source, self.paths[0]) def test_external_link_dataset(self): """Test that exporting a written file with external linked datasets maintains the links.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket], foofile_data=[1, 2, 3]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) manager = _get_manager() with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: read_foofile = read_io.read() foofile2 = FooFile(foofile_data=read_foofile.foofile_data) # make external link to existing dataset with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: write_io.write(foofile2) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown with HDF5IO(self.paths[2], mode='w') as export_io: export_io.export(src_io=read_io) with HDF5IO(self.paths[2], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() # make sure the linked dataset is read from the first file self.assertEqual(read_foofile2.foofile_data.file.filename, self.paths[0]) def test_external_link_link(self): """Test that exporting a written file with external links to external links maintains the links.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) manager = _get_manager() with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io: read_foofile = read_io.read() # make external link to existing group foofile2 = FooFile(foo_link=read_foofile.buckets['bucket1'].foos['foo1']) with HDF5IO(self.paths[1], manager=manager, mode='w') as write_io: write_io.write(foofile2) manager = _get_manager() with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() foofile3 = FooFile(foo_link=read_foofile2.foo_link) # make external link to external link with HDF5IO(self.paths[2], manager=manager, mode='w') as write_io: 
write_io.write(foofile3) with HDF5IO(self.paths[2], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown with HDF5IO(self.paths[3], mode='w') as export_io: export_io.export(src_io=read_io) with HDF5IO(self.paths[3], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile3 = read_io.read() # make sure the linked group is read from the first file self.assertEqual(read_foofile3.foo_link.container_source, self.paths[0]) def test_attr_reference(self): """Test that exporting a written file with attribute references maintains the references.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket], foo_ref_attr=foo1) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as read_io: read_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export(src_io=read_io) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: read_foofile2 = read_io.read() # make sure the attribute reference resolves to the container within the same file self.assertIs(read_foofile2.foo_ref_attr, read_foofile2.buckets['bucket1'].foos['foo1']) with File(self.paths[1], 'r') as f: self.assertIsInstance(f.attrs['foo_ref_attr'], h5py.Reference) def test_pop_data(self): """Test that exporting a written container after removing an element from it works.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: read_foofile = read_io.read() read_foofile.remove_bucket('bucket1') # remove child group with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export(src_io=read_io, container=read_foofile) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: read_foofile2 = read_io.read() # make sure the read foofile has no buckets self.assertDictEqual(read_foofile2.buckets, {}) # check that file size of file 2 is smaller self.assertTrue(os.path.getsize(self.paths[0]) > os.path.getsize(self.paths[1])) def test_pop_linked_group(self): """Test that exporting a written container after removing a linked element from it works.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket], foo_link=foo1) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: read_foofile = read_io.read() read_foofile.buckets['bucket1'].remove_foo('foo1') # remove child group with HDF5IO(self.paths[1], mode='w') as export_io: msg = ("links (links): Linked Foo 'foo1' has no parent. 
Remove the link or ensure the linked " "container is added properly.") with self.assertRaisesWith(OrphanContainerBuildError, msg): export_io.export(src_io=read_io, container=read_foofile) def test_append_data(self): """Test that exporting a written container after adding groups, links, and references to it works.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: read_foofile = read_io.read() # create a foo with link to existing dataset my_data, add the foo to new foobucket # this should make a soft link within the exported file foo2 = Foo('foo2', read_foofile.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 17, 3.14) foobucket2 = FooBucket('bucket2', [foo2]) read_foofile.add_bucket(foobucket2) # also add link from foofile to new foo2 container read_foofile.foo_link = foo2 # also add link from foofile to new foo2.my_data dataset which is a link to foo1.my_data dataset read_foofile.foofile_data = foo2.my_data # also add reference from foofile to new foo2 read_foofile.foo_ref_attr = foo2 with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export(src_io=read_io, container=read_foofile) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: self.ios.append(read_io) # track IO objects for tearDown read_foofile2 = read_io.read() # test new soft link to dataset in file self.assertIs(read_foofile2.buckets['bucket1'].foos['foo1'].my_data, read_foofile2.buckets['bucket2'].foos['foo2'].my_data) # test new soft link to group in file self.assertIs(read_foofile2.foo_link, read_foofile2.buckets['bucket2'].foos['foo2']) # test new soft link to new soft link to dataset in file self.assertIs(read_foofile2.buckets['bucket1'].foos['foo1'].my_data, read_foofile2.foofile_data) # test new attribute reference to new group in file self.assertIs(read_foofile2.foo_ref_attr, read_foofile2.buckets['bucket2'].foos['foo2']) with File(self.paths[1], 'r') as f: self.assertEqual(f['foofile_data'].file.filename, self.paths[1]) self.assertIsInstance(f.attrs['foo_ref_attr'], h5py.Reference) def test_append_external_link_data(self): """Test that exporting a written container after adding a link with link_data=True creates external links.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) foofile2 = FooFile([]) with HDF5IO(self.paths[1], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile2) manager = _get_manager() with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile1 = read_io1.read() with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io2: self.ios.append(read_io2) read_foofile2 = read_io2.read() # create a foo with link to existing dataset my_data (not in same file), add the foo to new foobucket # this should make an external link within the exported file foo2 = Foo('foo2', read_foofile1.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 17, 3.14) foobucket2 = FooBucket('bucket2', [foo2]) read_foofile2.add_bucket(foobucket2) # also add link from foofile to new foo2.my_data dataset which is a link to foo1.my_data dataset # this should make an external 
link within the exported file read_foofile2.foofile_data = foo2.my_data with HDF5IO(self.paths[2], mode='w') as export_io: export_io.export(src_io=read_io2, container=read_foofile2) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile3 = read_io1.read() with HDF5IO(self.paths[2], manager=_get_manager(), mode='r') as read_io2: self.ios.append(read_io2) # track IO objects for tearDown read_foofile4 = read_io2.read() self.assertEqual(read_foofile4.buckets['bucket2'].foos['foo2'].my_data, read_foofile3.buckets['bucket1'].foos['foo1'].my_data) self.assertEqual(read_foofile4.foofile_data, read_foofile3.buckets['bucket1'].foos['foo1'].my_data) with File(self.paths[2], 'r') as f: self.assertEqual(f['buckets/bucket2/foo_holder/foo2/my_data'].file.filename, self.paths[0]) self.assertEqual(f['foofile_data'].file.filename, self.paths[0]) self.assertIsInstance(f.get('buckets/bucket2/foo_holder/foo2/my_data', getlink=True), h5py.ExternalLink) self.assertIsInstance(f.get('foofile_data', getlink=True), h5py.ExternalLink) def test_append_external_link_copy_data(self): """Test that exporting a written container after adding a link with link_data=False copies the data.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) foofile2 = FooFile([]) with HDF5IO(self.paths[1], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile2) manager = _get_manager() with HDF5IO(self.paths[0], manager=manager, mode='r') as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile1 = read_io1.read() with HDF5IO(self.paths[1], manager=manager, mode='r') as read_io2: self.ios.append(read_io2) read_foofile2 = read_io2.read() # create a foo with link to existing dataset my_data (not in same file), add the foo to new foobucket # this would normally make an external link but because link_data=False, data will be copied foo2 = Foo('foo2', read_foofile1.buckets['bucket1'].foos['foo1'].my_data, "I am foo2", 17, 3.14) foobucket2 = FooBucket('bucket2', [foo2]) read_foofile2.add_bucket(foobucket2) # also add link from foofile to new foo2.my_data dataset which is a link to foo1.my_data dataset # this would normally make an external link but because link_data=False, data will be copied read_foofile2.foofile_data = foo2.my_data with HDF5IO(self.paths[2], mode='w') as export_io: export_io.export(src_io=read_io2, container=read_foofile2, write_args={'link_data': False}) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io1: self.ios.append(read_io1) # track IO objects for tearDown read_foofile3 = read_io1.read() with HDF5IO(self.paths[2], manager=_get_manager(), mode='r') as read_io2: self.ios.append(read_io2) # track IO objects for tearDown read_foofile4 = read_io2.read() # check that file can be read self.assertNotEqual(read_foofile4.buckets['bucket2'].foos['foo2'].my_data, read_foofile3.buckets['bucket1'].foos['foo1'].my_data) self.assertNotEqual(read_foofile4.foofile_data, read_foofile3.buckets['bucket1'].foos['foo1'].my_data) self.assertNotEqual(read_foofile4.foofile_data, read_foofile4.buckets['bucket2'].foos['foo2'].my_data) with File(self.paths[2], 'r') as f: self.assertEqual(f['buckets/bucket2/foo_holder/foo2/my_data'].file.filename, self.paths[2]) self.assertEqual(f['foofile_data'].file.filename, self.paths[2]) 
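    # The two tests above show how the 'link_data' write argument controls export behavior:
    # with the default link_data=True, a dataset that lives in another file is written as an
    # external link back to that file, while write_args={'link_data': False} copies the data
    # into the exported file. A minimal sketch of the kind of check used above to tell the two
    # apart (the helper name '_is_external_link' is hypothetical; File and ExternalLink are the
    # h5py objects already imported and used throughout this module):
    def _is_external_link(path, name):
        """Return True if 'name' in the HDF5 file at 'path' is stored as an external link."""
        with File(path, 'r') as f:
            return isinstance(f.get(name, getlink=True), ExternalLink)
    # e.g. _is_external_link(self.paths[2], 'foofile_data') would be True after the export in
    # test_append_external_link_data above, and False after test_append_external_link_copy_data.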
def test_export_io(self): """Test that exporting a written container using HDF5IO.export_io works.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], manager=_get_manager(), mode='r') as read_io: HDF5IO.export_io(src_io=read_io, path=self.paths[1]) self.assertTrue(os.path.exists(self.paths[1])) self.assertEqual(foofile.container_source, self.paths[0]) with HDF5IO(self.paths[1], manager=_get_manager(), mode='r') as read_io: read_foofile = read_io.read() self.assertEqual(read_foofile.container_source, self.paths[1]) self.assertContainerEqual(foofile, read_foofile, ignore_hdmf_attrs=True) def test_export_dset_refs(self): """Test that exporting a written container with a dataset of references works.""" bazs = [] num_bazs = 10 for i in range(num_bazs): bazs.append(Baz(name='baz%d' % i)) baz_data = BazData(name='baz_data1', data=bazs) bucket = BazBucket(name='bucket1', bazs=bazs.copy(), baz_data=baz_data) with HDF5IO(self.paths[0], manager=_get_baz_manager(), mode='w') as write_io: write_io.write(bucket) with HDF5IO(self.paths[0], manager=_get_baz_manager(), mode='r') as read_io: read_bucket1 = read_io.read() # NOTE: reference IDs might be the same between two identical files # adding a Baz with a smaller name should change the reference IDs on export new_baz = Baz(name='baz000') read_bucket1.add_baz(new_baz) with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export(src_io=read_io, container=read_bucket1) with HDF5IO(self.paths[1], manager=_get_baz_manager(), mode='r') as read_io: read_bucket2 = read_io.read() # remove and check the appended child, then compare the read container with the original read_new_baz = read_bucket2.remove_baz('baz000') self.assertContainerEqual(new_baz, read_new_baz, ignore_hdmf_attrs=True) self.assertContainerEqual(bucket, read_bucket2, ignore_name=True, ignore_hdmf_attrs=True) for i in range(num_bazs): baz_name = 'baz%d' % i self.assertIs(read_bucket2.baz_data.data[i], read_bucket2.bazs[baz_name]) def test_export_cpd_dset_refs(self): """Test that exporting a written container with a compound dataset with references works.""" bazs = [] baz_pairs = [] num_bazs = 10 for i in range(num_bazs): b = Baz(name='baz%d' % i) bazs.append(b) baz_pairs.append((i, b)) baz_cpd_data = BazCpdData(name='baz_cpd_data1', data=baz_pairs) bucket = BazBucket(name='bucket1', bazs=bazs.copy(), baz_cpd_data=baz_cpd_data) with HDF5IO(self.paths[0], manager=_get_baz_manager(), mode='w') as write_io: write_io.write(bucket) with HDF5IO(self.paths[0], manager=_get_baz_manager(), mode='r') as read_io: read_bucket1 = read_io.read() # NOTE: reference IDs might be the same between two identical files # adding a Baz with a smaller name should change the reference IDs on export new_baz = Baz(name='baz000') read_bucket1.add_baz(new_baz) with HDF5IO(self.paths[1], mode='w') as export_io: export_io.export(src_io=read_io, container=read_bucket1) with HDF5IO(self.paths[1], manager=_get_baz_manager(), mode='r') as read_io: read_bucket2 = read_io.read() # remove and check the appended child, then compare the read container with the original read_new_baz = read_bucket2.remove_baz(new_baz.name) self.assertContainerEqual(new_baz, read_new_baz, ignore_hdmf_attrs=True) self.assertContainerEqual(bucket, read_bucket2, ignore_name=True, ignore_hdmf_attrs=True) for i in range(num_bazs): baz_name = 'baz%d' 
% i self.assertEqual(read_bucket2.baz_cpd_data.data[i][0], i) self.assertIs(read_bucket2.baz_cpd_data.data[i][1], read_bucket2.bazs[baz_name]) def test_non_manager_container(self): """Test that exporting with a src_io without a manager raises an error.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) class OtherIO(HDMFIO): def read_builder(self): pass def write_builder(self, **kwargs): pass def open(self): pass def close(self): pass with OtherIO() as read_io: with HDF5IO(self.paths[1], mode='w') as export_io: msg = 'When a container is provided, src_io must have a non-None manager (BuildManager) property.' with self.assertRaisesWith(ValueError, msg): export_io.export(src_io=read_io, container=foofile, write_args={'link_data': False}) def test_non_HDF5_src_link_data_true(self): """Test that exporting from a non-HDF5 backend src_io with link_data=True raises an error.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) class OtherIO(HDMFIO): def __init__(self, manager): super().__init__(manager=manager) def read_builder(self): pass def write_builder(self, **kwargs): pass def open(self): pass def close(self): pass with OtherIO(manager=_get_manager()) as read_io: with HDF5IO(self.paths[1], mode='w') as export_io: msg = "Cannot export from non-HDF5 backend OtherIO to HDF5 with write argument link_data=True." with self.assertRaisesWith(UnsupportedOperation, msg): export_io.export(src_io=read_io, container=foofile) def test_wrong_mode(self): """Test that exporting to a file opened in a mode other than 'w' raises an error.""" foo1 = Foo('foo1', [1, 2, 3, 4, 5], "I am foo1", 17, 3.14) foobucket = FooBucket('bucket1', [foo1]) foofile = FooFile([foobucket]) with HDF5IO(self.paths[0], manager=_get_manager(), mode='w') as write_io: write_io.write(foofile) with HDF5IO(self.paths[0], mode='r') as read_io: with HDF5IO(self.paths[1], mode='a') as export_io: msg = "Cannot export to file %s in mode 'a'. Please use mode 'w'."
% self.paths[1] with self.assertRaisesWith(UnsupportedOperation, msg): export_io.export(src_io=read_io) class TestDatasetRefs(TestCase): def test_roundtrip(self): self.path = get_temp_filepath() bazs = [] num_bazs = 10 for i in range(num_bazs): bazs.append(Baz(name='baz%d' % i)) baz_data = BazData(name='baz_data1', data=bazs) bucket = BazBucket(name='bucket1', bazs=bazs.copy(), baz_data=baz_data) with HDF5IO(self.path, manager=_get_baz_manager(), mode='w') as write_io: write_io.write(bucket) with HDF5IO(self.path, manager=_get_baz_manager(), mode='r') as read_io: read_bucket = read_io.read() self.assertContainerEqual(bucket, read_bucket, ignore_name=True) for i in range(num_bazs): baz_name = 'baz%d' % i self.assertIs(read_bucket.baz_data.data[i], read_bucket.bazs[baz_name]) class TestCpdDatasetRefs(TestCase): def test_roundtrip(self): self.path = get_temp_filepath() bazs = [] baz_pairs = [] num_bazs = 10 for i in range(num_bazs): b = Baz(name='baz%d' % i) bazs.append(b) baz_pairs.append((i, b)) baz_cpd_data = BazCpdData(name='baz_cpd_data1', data=baz_pairs) bucket = BazBucket(name='bucket1', bazs=bazs.copy(), baz_cpd_data=baz_cpd_data) with HDF5IO(self.path, manager=_get_baz_manager(), mode='w') as write_io: write_io.write(bucket) with HDF5IO(self.path, manager=_get_baz_manager(), mode='r') as read_io: read_bucket = read_io.read() self.assertContainerEqual(bucket, read_bucket, ignore_name=True) for i in range(num_bazs): baz_name = 'baz%d' % i self.assertEqual(read_bucket.baz_cpd_data.data[i][0], i) self.assertIs(read_bucket.baz_cpd_data.data[i][1], read_bucket.bazs[baz_name]) class Baz(Container): pass class BazData(Data): pass class BazCpdData(Data): pass class BazBucket(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this bucket'}, {'name': 'bazs', 'type': list, 'doc': 'the Baz objects in this bucket'}, {'name': 'baz_data', 'type': BazData, 'doc': 'dataset of Baz references', 'default': None}, {'name': 'baz_cpd_data', 'type': BazCpdData, 'doc': 'dataset of Baz references', 'default': None}) def __init__(self, **kwargs): name, bazs, baz_data, baz_cpd_data = getargs('name', 'bazs', 'baz_data', 'baz_cpd_data', kwargs) super().__init__(name=name) self.__bazs = {b.name: b for b in bazs} # note: collections of groups are unordered in HDF5 for b in bazs: b.parent = self self.__baz_data = baz_data if self.__baz_data is not None: self.__baz_data.parent = self self.__baz_cpd_data = baz_cpd_data if self.__baz_cpd_data is not None: self.__baz_cpd_data.parent = self @property def bazs(self): return self.__bazs @property def baz_data(self): return self.__baz_data @property def baz_cpd_data(self): return self.__baz_cpd_data def add_baz(self, baz): self.__bazs[baz.name] = baz baz.parent = self def remove_baz(self, baz_name): baz = self.__bazs.pop(baz_name) self._remove_child(baz) return baz def _get_baz_manager(): baz_spec = GroupSpec( doc='A test group specification with a data type', data_type_def='Baz', ) baz_data_spec = DatasetSpec( doc='A test dataset of references specification with a data type', name='baz_data', data_type_def='BazData', dtype=RefSpec('Baz', 'object'), shape=[None], ) baz_cpd_data_spec = DatasetSpec( doc='A test compound dataset with references specification with a data type', name='baz_cpd_data', data_type_def='BazCpdData', dtype=[DtypeSpec(name='part1', doc='doc', dtype='int'), DtypeSpec(name='part2', doc='doc', dtype=RefSpec('Baz', 'object'))], shape=[None], ) baz_holder_spec = GroupSpec( doc='group of bazs', name='bazs', groups=[GroupSpec(doc='Baz', 
data_type_inc='Baz', quantity=ONE_OR_MANY)], ) baz_bucket_spec = GroupSpec( doc='A test group specification for a data type containing data type', data_type_def='BazBucket', groups=[baz_holder_spec], datasets=[DatasetSpec(doc='doc', data_type_inc='BazData', quantity=ZERO_OR_ONE), DatasetSpec(doc='doc', data_type_inc='BazCpdData', quantity=ZERO_OR_ONE)], ) spec_catalog = SpecCatalog() spec_catalog.register_spec(baz_spec, 'test.yaml') spec_catalog.register_spec(baz_data_spec, 'test.yaml') spec_catalog.register_spec(baz_cpd_data_spec, 'test.yaml') spec_catalog.register_spec(baz_bucket_spec, 'test.yaml') namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) type_map.register_container_type(CORE_NAMESPACE, 'Baz', Baz) type_map.register_container_type(CORE_NAMESPACE, 'BazData', BazData) type_map.register_container_type(CORE_NAMESPACE, 'BazCpdData', BazCpdData) type_map.register_container_type(CORE_NAMESPACE, 'BazBucket', BazBucket) class BazBucketMapper(ObjectMapper): def __init__(self, spec): super().__init__(spec) baz_holder_spec = spec.get_group('bazs') self.unmap(baz_holder_spec) baz_spec = baz_holder_spec.get_data_type('Baz') self.map_spec('bazs', baz_spec) type_map.register_map(BazBucket, BazBucketMapper) manager = BuildManager(type_map) return manager ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/test_multicontainerinterface.py0000644000655200065520000004452200000000000023240 0ustar00circlecicircleciimport inspect from hdmf.container import Container, Data, MultiContainerInterface from hdmf.testing import TestCase from hdmf.utils import LabelledDict, get_docval class OData(Data): pass class Foo(MultiContainerInterface): __clsconf__ = [ { 'attr': 'containers', 'add': 'add_container', 'type': (Container, ), 'get': 'get_container', }, { 'attr': 'data', 'add': 'add_data', 'type': (Data, OData), }, { 'attr': 'foo_data', 'add': 'add_foo_data', 'type': OData, 'create': 'create_foo_data', }, { 'attr': 'things', 'add': 'add_thing', 'type': (Container, Data, OData), }, ] class FooSingle(MultiContainerInterface): __clsconf__ = { 'attr': 'containers', 'add': 'add_container', 'type': (Container, ), } class Baz(MultiContainerInterface): __containers = dict() __clsconf__ = [ { 'attr': 'containers', 'add': 'add_container', 'type': Container, 'get': 'get_container', }, ] # use custom keyword arguments def __init__(self, name, other_arg, my_containers): super().__init__(name=name) self.other_arg = other_arg self.containers = {'my ' + v.name: v for v in my_containers} @property def containers(self): return self.__containers @containers.setter def containers(self, value): self.__containers = value class TestBasic(TestCase): def test_init_docval(self): """Test that the docval for the __init__ method is set correctly.""" dv = get_docval(Foo.__init__) self.assertEqual(dv[0]['name'], 'containers') self.assertEqual(dv[1]['name'], 'data') self.assertEqual(dv[2]['name'], 'foo_data') self.assertEqual(dv[3]['name'], 'things') self.assertTupleEqual(dv[0]['type'], (list, tuple, dict, Container)) self.assertTupleEqual(dv[1]['type'], (list, tuple, dict, Data, OData)) self.assertTupleEqual(dv[2]['type'], (list, tuple, dict, OData)) self.assertTupleEqual(dv[3]['type'], (list, tuple, dict, Container, Data, OData)) self.assertEqual(dv[0]['doc'], 
'Container to store in this interface') self.assertEqual(dv[1]['doc'], 'Data or OData to store in this interface') self.assertEqual(dv[2]['doc'], 'OData to store in this interface') self.assertEqual(dv[3]['doc'], 'Container, Data, or OData to store in this interface') for i in range(4): self.assertDictEqual(dv[i]['default'], {}) self.assertEqual(dv[4]['name'], 'name') self.assertEqual(dv[4]['type'], str) self.assertEqual(dv[4]['doc'], 'the name of this container') self.assertEqual(dv[4]['default'], 'Foo') def test_add_docval(self): """Test that the docval for the add method is set correctly.""" dv = get_docval(Foo.add_container) self.assertEqual(dv[0]['name'], 'containers') self.assertTupleEqual(dv[0]['type'], (list, tuple, dict, Container)) self.assertEqual(dv[0]['doc'], 'the Container to add') self.assertFalse('default' in dv[0]) def test_create_docval(self): """Test that the docval for the create method is set correctly.""" dv = get_docval(Foo.create_foo_data) self.assertEqual(dv[0]['name'], 'name') self.assertEqual(dv[1]['name'], 'data') def test_getter_docval(self): """Test that the docval for the get method is set correctly.""" dv = get_docval(Foo.get_container) self.assertEqual(dv[0]['doc'], 'the name of the Container') self.assertIsNone(dv[0]['default']) def test_getitem_docval(self): """Test that the docval for __getitem__ is set correctly.""" dv = get_docval(Baz.__getitem__) self.assertEqual(dv[0]['doc'], 'the name of the Container') self.assertIsNone(dv[0]['default']) def test_attr_property(self): """Test that a property is created for the attribute.""" properties = inspect.getmembers(Foo, lambda o: isinstance(o, property)) match = [p for p in properties if p[0] == 'containers'] self.assertEqual(len(match), 1) def test_attr_getter(self): """Test that the getter for the attribute dict returns a LabelledDict.""" foo = Foo() self.assertTrue(isinstance(foo.containers, LabelledDict)) def test_init_empty(self): """Test that initializing the MCI with no arguments initializes the attribute dict empty.""" foo = Foo() self.assertDictEqual(foo.containers, {}) self.assertEqual(foo.name, 'Foo') def test_init_multi(self): """Test that initializing the MCI with no arguments initializes the attribute dict empty.""" obj1 = Container('obj1') data1 = Data('data1', [1, 2, 3]) foo = Foo(containers=obj1, data=data1) self.assertDictEqual(foo.containers, {'obj1': obj1}) self.assertDictEqual(foo.data, {'data1': data1}) def test_init_custom_name(self): """Test that initializing the MCI with a custom name works.""" foo = Foo(name='test_foo') self.assertEqual(foo.name, 'test_foo') # init, create, and setter calls add, so just test add def test_add_single(self): """Test that adding a container to the attribute dict correctly adds the container.""" obj1 = Container('obj1') foo = Foo() foo.add_container(obj1) self.assertDictEqual(foo.containers, {'obj1': obj1}) self.assertIs(obj1.parent, foo) def test_add_single_not_parent(self): """Test that adding a container with a parent to the attribute dict correctly adds the container.""" obj1 = Container('obj1') obj2 = Container('obj2') obj1.parent = obj2 foo = Foo() foo.add_container(obj1) self.assertDictEqual(foo.containers, {'obj1': obj1}) self.assertIs(obj1.parent, obj2) def test_add_single_dup(self): """Test that adding a container to the attribute dict correctly adds the container.""" obj1 = Container('obj1') foo = Foo(obj1) msg = "'obj1' already exists in Foo 'Foo'" with self.assertRaisesWith(ValueError, msg): foo.add_container(obj1) def test_add_list(self): 
"""Test that adding a list to the attribute dict correctly adds the items.""" obj1 = Container('obj1') obj2 = Container('obj2') foo = Foo() foo.add_container([obj1, obj2]) self.assertDictEqual(foo.containers, {'obj1': obj1, 'obj2': obj2}) def test_add_dict(self): """Test that adding a dict to the attribute dict correctly adds the input dict values.""" obj1 = Container('obj1') obj2 = Container('obj2') foo = Foo() foo.add_container({'a': obj1, 'b': obj2}) self.assertDictEqual(foo.containers, {'obj1': obj1, 'obj2': obj2}) def test_attr_setter_none(self): """Test that setting the attribute dict to None does not alter the dict.""" obj1 = Container('obj1') foo = Foo(obj1) foo.containers = None self.assertDictEqual(foo.containers, {'obj1': obj1}) def test_remove_child(self): """Test that removing a child container from the attribute dict resets the parent to None.""" obj1 = Container('obj1') foo = Foo(obj1) del foo.containers['obj1'] self.assertDictEqual(foo.containers, {}) self.assertIsNone(obj1.parent) def test_remove_non_child(self): """Test that removing a non-child container from the attribute dict resets the parent to None.""" obj1 = Container('obj1') obj2 = Container('obj2') obj1.parent = obj2 foo = Foo(obj1) del foo.containers['obj1'] self.assertDictEqual(foo.containers, {}) self.assertIs(obj1.parent, obj2) def test_getter_empty(self): """Test that calling the getter with no args and no items in the attribute dict raises an error.""" foo = Foo() msg = "containers of Foo 'Foo' is empty." with self.assertRaisesWith(ValueError, msg): foo.get_container() def test_getter_none(self): """Test that calling the getter with no args and one item in the attribute returns the item.""" obj1 = Container('obj1') foo = Foo(obj1) self.assertIs(foo.get_container(), obj1) def test_getter_none_multiple(self): """Test that calling the getter with no args and multiple items in the attribute dict raises an error.""" obj1 = Container('obj1') obj2 = Container('obj2') foo = Foo([obj1, obj2]) msg = "More than one element in containers of Foo 'Foo' -- must specify a name." with self.assertRaisesWith(ValueError, msg): foo.get_container() def test_getter_name(self): """Test that calling the getter with a correct key works.""" obj1 = Container('obj1') foo = Foo(obj1) self.assertIs(foo.get_container('obj1'), obj1) def test_getter_name_not_found(self): """Test that calling the getter with a key not in the attribute dict raises a KeyError.""" foo = Foo() msg = "\"'obj1' not found in containers of Foo 'Foo'.\"" with self.assertRaisesWith(KeyError, msg): foo.get_container('obj1') def test_getitem_multiconf(self): """Test that classes with multiple attribute configurations cannot use getitem.""" foo = Foo() msg = "'Foo' object is not subscriptable" with self.assertRaisesWith(TypeError, msg): foo['aa'] def test_getitem(self): """Test that getitem works.""" obj1 = Container('obj1') foo = FooSingle(obj1) self.assertIs(foo['obj1'], obj1) def test_getitem_single_none(self): """Test that getitem works wwhen there is a single item and no name is given to getitem.""" obj1 = Container('obj1') foo = FooSingle(obj1) self.assertIs(foo[None], obj1) def test_getitem_empty(self): """Test that an error is raised if the attribute dict is empty and no name is given to getitem.""" foo = FooSingle() msg = "FooSingle 'FooSingle' is empty." 
        with self.assertRaisesWith(ValueError, msg):
            foo[None]

    def test_getitem_multiple(self):
        """Test that an error is raised if the attribute dict has multiple values and no name is given to getitem."""
        obj1 = Container('obj1')
        obj2 = Container('obj2')
        foo = FooSingle([obj1, obj2])
        msg = "More than one Container in FooSingle 'FooSingle' -- must specify a name."
        with self.assertRaisesWith(ValueError, msg):
            foo[None]

    def test_getitem_not_found(self):
        """Test that a KeyError is raised if the key is not found using getitem."""
        obj1 = Container('obj1')
        foo = FooSingle(obj1)
        msg = "\"'obj2' not found in FooSingle 'FooSingle'.\""
        with self.assertRaisesWith(KeyError, msg):
            foo['obj2']


class TestOverrideInit(TestCase):

    def test_override_init(self):
        """Test that overriding __init__ works."""
        obj1 = Container('obj1')
        obj2 = Container('obj2')
        containers = [obj1, obj2]
        baz = Baz(name='test_baz', other_arg=1, my_containers=containers)
        self.assertEqual(baz.name, 'test_baz')
        self.assertEqual(baz.other_arg, 1)

    def test_override_property(self):
        """Test that overriding the attribute property works."""
        obj1 = Container('obj1')
        obj2 = Container('obj2')
        containers = [obj1, obj2]
        baz = Baz(name='test_baz', other_arg=1, my_containers=containers)
        self.assertDictEqual(baz.containers, {'my obj1': obj1, 'my obj2': obj2})
        self.assertFalse(isinstance(baz.containers, LabelledDict))
        self.assertIs(baz.get_container('my obj1'), obj1)
        baz.containers = {}
        self.assertDictEqual(baz.containers, {})


class TestNoClsConf(TestCase):

    def test_mci_init(self):
        """Test that MultiContainerInterface cannot be instantiated."""
        msg = "Can't instantiate class MultiContainerInterface."
        with self.assertRaisesWith(TypeError, msg):
            MultiContainerInterface(name='a')

    def test_init_no_cls_conf(self):
        """Test that initializing an MCI subclass without __clsconf__ raises an error."""

        class Bar(MultiContainerInterface):
            pass

        msg = ("MultiContainerInterface subclass Bar is missing __clsconf__ attribute. Please check that "
               "the class is properly defined.")
        with self.assertRaisesWith(TypeError, msg):
            Bar(name='a')

    def test_init_superclass_no_cls_conf(self):
        """Test that a subclass with __clsconf__ of an MCI subclass without __clsconf__ can be initialized."""

        class Bar(MultiContainerInterface):
            pass

        class Qux(Bar):
            __clsconf__ = {
                'attr': 'containers',
                'add': 'add_container',
                'type': Container,
            }

        obj1 = Container('obj1')
        qux = Qux(obj1)
        self.assertDictEqual(qux.containers, {'obj1': obj1})


class TestBadClsConf(TestCase):

    def test_wrong_type(self):
        """Test that an error is raised if __clsconf__ is not a dict or a list of dicts."""
        msg = "'__clsconf__' for MultiContainerInterface subclass Bar must be a dict or a list of dicts."
with self.assertRaisesWith(TypeError, msg): class Bar(MultiContainerInterface): __clsconf__ = ( { 'attr': 'data', 'add': 'add_data', 'type': (Data, ), }, ) def test_missing_add(self): """Test that an error is raised if __clsconf__ is missing the add key.""" msg = "MultiContainerInterface subclass Bar is missing 'add' key in __clsconf__" with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): __clsconf__ = {} def test_missing_attr(self): """Test that an error is raised if __clsconf__ is missing the attr key.""" msg = "MultiContainerInterface subclass Bar is missing 'attr' key in __clsconf__" with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): __clsconf__ = { 'add': 'add_container', } def test_missing_type(self): """Test that an error is raised if __clsconf__ is missing the type key.""" msg = "MultiContainerInterface subclass Bar is missing 'type' key in __clsconf__" with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): __clsconf__ = { 'add': 'add_container', 'attr': 'containers', } def test_create_multiple_types(self): """Test that an error is raised if __clsconf__ specifies 'create' key with multiple types.""" msg = ("Cannot specify 'create' key in __clsconf__ for MultiContainerInterface subclass Bar " "when 'type' key is not a single type") with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): __clsconf__ = { 'attr': 'data', 'add': 'add_data', 'type': (Data, ), 'create': 'create_data', } def test_missing_add_multi(self): """Test that an error is raised if one item of a __clsconf__ list is missing the add key.""" msg = "MultiContainerInterface subclass Bar is missing 'add' key in __clsconf__ at index 1" with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): __clsconf__ = [ { 'attr': 'data', 'add': 'add_data', 'type': (Data, ), }, {} ] def test_missing_attr_multi(self): """Test that an error is raised if one item of a __clsconf__ list is missing the attr key.""" msg = "MultiContainerInterface subclass Bar is missing 'attr' key in __clsconf__ at index 1" with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): __clsconf__ = [ { 'attr': 'data', 'add': 'add_data', 'type': (Data, ), }, { 'add': 'add_container', } ] def test_missing_type_multi(self): """Test that an error is raised if one item of a __clsconf__ list is missing the type key.""" msg = "MultiContainerInterface subclass Bar is missing 'type' key in __clsconf__ at index 1" with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): __clsconf__ = [ { 'attr': 'data', 'add': 'add_data', 'type': (Data, ), }, { 'add': 'add_container', 'attr': 'containers', } ] def test_create_multiple_types_multi(self): """Test that an error is raised if one item of a __clsconf__ list specifies 'create' key with multiple types.""" msg = ("Cannot specify 'create' key in __clsconf__ for MultiContainerInterface subclass Bar " "when 'type' key is not a single type at index 1") with self.assertRaisesWith(ValueError, msg): class Bar(MultiContainerInterface): __clsconf__ = [ { 'attr': 'data', 'add': 'add_data', 'type': (Data, ), }, { 'add': 'add_container', 'attr': 'containers', 'type': (Container, ), 'create': 'create_container', } ] ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/test_query.py0000644000655200065520000001121000000000000017453 0ustar00circlecicircleciimport os from abc import ABCMeta, abstractmethod 
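# --------------------------------------------------------------------------
# Illustrative aside (not part of the original test suite): a minimal sketch
# of the MultiContainerInterface pattern exercised by the tests above. The
# class, attribute, and method names ("Toys", "add_toy", etc.) are
# hypothetical; the __clsconf__ keys ('attr', 'type', 'add', 'get', 'create')
# mirror the ones referenced by the tests, and the import path assumes
# MultiContainerInterface lives in hdmf.container as in this HDMF version.
from hdmf.container import Container, MultiContainerInterface


class Toys(MultiContainerInterface):
    """A hypothetical holder of Container objects."""

    __clsconf__ = {
        'attr': 'toys',           # name of the generated LabelledDict attribute
        'type': Container,        # type(s) accepted by the generated methods
        'add': 'add_toy',         # name of the generated add method
        'get': 'get_toy',         # name of the generated getter
        'create': 'create_toy',   # allowed only because 'type' is a single type
    }


holder = Toys(name='my_toys')
holder.add_toy(Container('ball'))                   # add a single container
holder.add_toy([Container('a'), Container('b')])    # or a list of containers
ball = holder.get_toy('ball')                       # retrieve by name
same = holder['ball']                               # __getitem__ works for single-config classes
# --------------------------------------------------------------------------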
import numpy as np from h5py import File from hdmf.array import SortedArray, LinSpace from hdmf.query import HDMFDataset, Query from hdmf.testing import TestCase class AbstractQueryMixin(metaclass=ABCMeta): @abstractmethod def getDataset(self): raise NotImplementedError('Cannot run test unless getDataset is implemented') def setUp(self): self.dset = self.getDataset() self.wrapper = HDMFDataset(self.dset) def test_get_dataset(self): array = self.wrapper.dataset self.assertIsInstance(array, SortedArray) def test___gt__(self): ''' Test wrapper greater than magic method ''' q = self.wrapper > 5 self.assertIsInstance(q, Query) result = q.evaluate() expected = [False, False, False, False, False, False, True, True, True, True] expected = slice(6, 10) self.assertEqual(result, expected) def test___ge__(self): ''' Test wrapper greater than or equal magic method ''' q = self.wrapper >= 5 self.assertIsInstance(q, Query) result = q.evaluate() expected = [False, False, False, False, False, True, True, True, True, True] expected = slice(5, 10) self.assertEqual(result, expected) def test___lt__(self): ''' Test wrapper less than magic method ''' q = self.wrapper < 5 self.assertIsInstance(q, Query) result = q.evaluate() expected = [True, True, True, True, True, False, False, False, False, False] expected = slice(0, 5) self.assertEqual(result, expected) def test___le__(self): ''' Test wrapper less than or equal magic method ''' q = self.wrapper <= 5 self.assertIsInstance(q, Query) result = q.evaluate() expected = [True, True, True, True, True, True, False, False, False, False] expected = slice(0, 6) self.assertEqual(result, expected) def test___eq__(self): ''' Test wrapper equals magic method ''' q = self.wrapper == 5 self.assertIsInstance(q, Query) result = q.evaluate() expected = [False, False, False, False, False, True, False, False, False, False] expected = 5 self.assertTrue(np.array_equal(result, expected)) def test___ne__(self): ''' Test wrapper not equal magic method ''' q = self.wrapper != 5 self.assertIsInstance(q, Query) result = q.evaluate() expected = [True, True, True, True, True, False, True, True, True, True] expected = [slice(0, 5), slice(6, 10)] self.assertTrue(np.array_equal(result, expected)) def test___getitem__(self): ''' Test wrapper getitem using slice ''' result = self.wrapper[0:5] expected = [0, 1, 2, 3, 4] self.assertTrue(np.array_equal(result, expected)) def test___getitem__query(self): ''' Test wrapper getitem using query ''' q = self.wrapper < 5 result = self.wrapper[q] expected = [0, 1, 2, 3, 4] self.assertTrue(np.array_equal(result, expected)) class SortedQueryTest(AbstractQueryMixin, TestCase): path = 'SortedQueryTest.h5' def getDataset(self): self.f = File(self.path, 'w') self.input = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] self.d = self.f.create_dataset('dset', data=self.input) return SortedArray(self.d) def tearDown(self): self.f.close() if os.path.exists(self.path): os.remove(self.path) class LinspaceQueryTest(AbstractQueryMixin, TestCase): path = 'LinspaceQueryTest.h5' def getDataset(self): return LinSpace(0, 10, 1) class CompoundQueryTest(TestCase): def getM(self): return SortedArray(np.arange(10, 20, 1)) def getN(self): return SortedArray(np.arange(10.0, 20.0, 0.5)) def setUp(self): self.m = HDMFDataset(self.getM()) self.n = HDMFDataset(self.getN()) # TODO: test not completed # def test_map(self): # q = self.m == (12, 16) # IN operation # q.evaluate() # [2,3,4,5] # q.evaluate(False) # RangeResult(2,6) # r = self.m[q] # noqa: F841 # r = self.m[q.evaluate()] # noqa: F841 # r = 
self.m[q.evaluate(False)] # noqa: F841 def tearDown(self): pass ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/test_table.py0000644000655200065520000000732100000000000017405 0ustar00circlecicircleciimport pandas as pd from hdmf.container import Table, Row, RowGetter from hdmf.testing import TestCase class TestTable(TestCase): @classmethod def get_table_class(cls): class MyTable(Table): __defaultname__ = 'my_table' __columns__ = [ {'name': 'col1', 'type': str, 'help': 'a string column'}, {'name': 'col2', 'type': int, 'help': 'an integer column'}, ] return MyTable def test_init(self): MyTable = TestTable.get_table_class() table = MyTable('test_table') self.assertTrue(hasattr(table, '__colidx__')) self.assertEqual(table.__colidx__, {'col1': 0, 'col2': 1}) def test_add_row_getitem(self): MyTable = TestTable.get_table_class() table = MyTable('test_table') table.add_row(col1='foo', col2=100) table.add_row(col1='bar', col2=200) row1 = table[0] row2 = table[1] self.assertEqual(row1, ('foo', 100)) self.assertEqual(row2, ('bar', 200)) def test_to_dataframe(self): MyTable = TestTable.get_table_class() table = MyTable('test_table') table.add_row(col1='foo', col2=100) table.add_row(col1='bar', col2=200) df = table.to_dataframe() exp = pd.DataFrame(data=[{'col1': 'foo', 'col2': 100}, {'col1': 'bar', 'col2': 200}]) pd.testing.assert_frame_equal(df, exp) def test_from_dataframe(self): MyTable = TestTable.get_table_class() exp = pd.DataFrame(data=[{'col1': 'foo', 'col2': 100}, {'col1': 'bar', 'col2': 200}]) table = MyTable.from_dataframe(exp) row1 = table[0] row2 = table[1] self.assertEqual(row1, ('foo', 100)) self.assertEqual(row2, ('bar', 200)) class TestRow(TestCase): def setUp(self): self.MyTable = TestTable.get_table_class() class MyRow(Row): __table__ = self.MyTable self.MyRow = MyRow self.table = self.MyTable('test_table') def test_row_no_table(self): with self.assertRaisesRegex(ValueError, '__table__ must be set if sub-classing Row'): class MyRow(Row): pass def test_table_init(self): MyTable = TestTable.get_table_class() table = MyTable('test_table') self.assertFalse(hasattr(table, 'row')) table_w_row = self.MyTable('test_table') self.assertTrue(hasattr(table_w_row, 'row')) self.assertIsInstance(table_w_row.row, RowGetter) self.assertIs(table_w_row.row.table, table_w_row) def test_init(self): row1 = self.MyRow(col1='foo', col2=100, table=self.table) # make sure Row object set up properly self.assertEqual(row1.idx, 0) self.assertEqual(row1.col1, 'foo') self.assertEqual(row1.col2, 100) # make sure Row object is stored in Table peroperly tmp_row1 = self.table.row[0] self.assertEqual(tmp_row1, row1) def test_add_row_getitem(self): self.table.add_row(col1='foo', col2=100) self.table.add_row(col1='bar', col2=200) row1 = self.table.row[0] self.assertIsInstance(row1, self.MyRow) self.assertEqual(row1.idx, 0) self.assertEqual(row1.col1, 'foo') self.assertEqual(row1.col2, 100) row2 = self.table.row[1] self.assertIsInstance(row2, self.MyRow) self.assertEqual(row2.idx, 1) self.assertEqual(row2.col1, 'bar') self.assertEqual(row2.col2, 200) # test memoization row3 = self.table.row[0] self.assertIs(row3, row1) def test_todict(self): row1 = self.MyRow(col1='foo', col2=100, table=self.table) self.assertEqual(row1.todict(), {'col1': 'foo', 'col2': 100}) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/utils.py0000644000655200065520000002117000000000000016415 
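# --------------------------------------------------------------------------
# Illustrative aside (not part of the original test suite): a minimal sketch
# of the Table/Row pattern exercised by test_table.py above. The class and
# column names ("PeopleTable", "first", "age") are hypothetical examples.
from hdmf.container import Table, Row


class PeopleTable(Table):
    __defaultname__ = 'people'
    __columns__ = [
        {'name': 'first', 'type': str, 'help': 'first name'},
        {'name': 'age', 'type': int, 'help': 'age in years'},
    ]


class Person(Row):
    # associating a Row subclass with the Table subclass enables table.row[i]
    __table__ = PeopleTable


table = PeopleTable('people')
table.add_row(first='Ada', age=36)
table.add_row(first='Alan', age=41)
print(table[0])              # -> ('Ada', 36), rows are returned as tuples
row = table.row[1]           # Row instance, memoized per index
print(row.first, row.age)    # -> Alan 41
df = table.to_dataframe()    # round-trip through pandas
same_table = PeopleTable.from_dataframe(df)
# --------------------------------------------------------------------------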
0ustar00circlecicircleciimport os import tempfile from copy import copy, deepcopy from hdmf.build import TypeMap from hdmf.container import Container from hdmf.spec import GroupSpec, DatasetSpec, NamespaceCatalog, SpecCatalog, SpecNamespace, NamespaceBuilder from hdmf.utils import docval, getargs, get_docval CORE_NAMESPACE = 'test_core' class Foo(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this Foo'}, {'name': 'my_data', 'type': ('array_data', 'data'), 'doc': 'some data'}, {'name': 'attr1', 'type': str, 'doc': 'an attribute'}, {'name': 'attr2', 'type': int, 'doc': 'another attribute'}, {'name': 'attr3', 'type': float, 'doc': 'a third attribute', 'default': 3.14}) def __init__(self, **kwargs): name, my_data, attr1, attr2, attr3 = getargs('name', 'my_data', 'attr1', 'attr2', 'attr3', kwargs) super().__init__(name=name) self.__data = my_data self.__attr1 = attr1 self.__attr2 = attr2 self.__attr3 = attr3 def __eq__(self, other): attrs = ('name', 'my_data', 'attr1', 'attr2', 'attr3') return all(getattr(self, a) == getattr(other, a) for a in attrs) def __str__(self): attrs = ('name', 'my_data', 'attr1', 'attr2', 'attr3') return '<' + ','.join('%s=%s' % (a, getattr(self, a)) for a in attrs) + '>' @property def my_data(self): return self.__data @property def attr1(self): return self.__attr1 @property def attr2(self): return self.__attr2 @property def attr3(self): return self.__attr3 def __hash__(self): return hash(self.name) class FooBucket(Container): @docval({'name': 'name', 'type': str, 'doc': 'the name of this bucket'}, {'name': 'foos', 'type': list, 'doc': 'the Foo objects in this bucket', 'default': list()}) def __init__(self, **kwargs): name, foos = getargs('name', 'foos', kwargs) super().__init__(name=name) self.__foos = {f.name: f for f in foos} # note: collections of groups are unordered in HDF5 for f in foos: f.parent = self def __eq__(self, other): return self.name == other.name and self.foos == other.foos def __str__(self): return 'name=%s, foos=%s' % (self.name, self.foos) @property def foos(self): return self.__foos def remove_foo(self, foo_name): foo = self.__foos.pop(foo_name) if foo.parent is self: self._remove_child(foo) return foo def get_temp_filepath(): # On Windows, h5py cannot truncate an open file in write mode. # The temp file will be closed before h5py truncates it and will be removed during the tearDown step. temp_file = tempfile.NamedTemporaryFile() temp_file.close() return temp_file.name def create_test_type_map(specs, container_classes, mappers=None): """ Create a TypeMap with the specs registered under a test namespace, and classes and mappers registered to type names. 
:param specs: list of specs :param container_classes: dict of type name to container class :param mappers: (optional) dict of type name to mapper class :return: the constructed TypeMap """ spec_catalog = SpecCatalog() schema_file = 'test.yaml' for s in specs: spec_catalog.register_spec(s, schema_file) namespace = SpecNamespace( doc='a test namespace', name=CORE_NAMESPACE, schema=[{'source': schema_file}], version='0.1.0', catalog=spec_catalog ) namespace_catalog = NamespaceCatalog() namespace_catalog.add_namespace(CORE_NAMESPACE, namespace) type_map = TypeMap(namespace_catalog) for type_name, container_cls in container_classes.items(): type_map.register_container_type(CORE_NAMESPACE, type_name, container_cls) if mappers: for type_name, mapper_cls in mappers.items(): container_cls = container_classes[type_name] type_map.register_map(container_cls, mapper_cls) return type_map def create_load_namespace_yaml(namespace_name, specs, output_dir, incl_types, type_map): """ Create a TypeMap with the specs loaded from YAML files and dependencies resolved. This writes namespaces and specs to YAML files, creates an empty TypeMap, and calls load_namespaces on the TypeMap, instead of manually creating a SpecCatalog, SpecNamespace, NamespaceCatalog and manually registering container types. Importantly, this process resolves dependencies across namespaces. :param namespace_name: Name of the new namespace. :param specs: List of specs of new data types to add. :param incl_types: Dict mapping included namespace name to list of data types to include or None to include all. :param type_map: The type map to load the namespace into. """ ns_builder = NamespaceBuilder( name=namespace_name, doc='a test namespace', version='0.1.0', ) ns_filename = ns_builder.name + '.namespace.yaml' ext_filename = ns_builder.name + '.extensions.yaml' for ns, types in incl_types.items(): if types is None: # include all types ns_builder.include_namespace(ns) else: for dt in types: ns_builder.include_type(dt, namespace=ns) for data_type in specs: ns_builder.add_spec(ext_filename, data_type) ns_builder.export(ns_filename, outdir=output_dir) ns_path = os.path.join(output_dir, ns_filename) type_map.load_namespaces(ns_path) # ##### custom spec classes ##### def swap_inc_def(cls, custom_cls): args = get_docval(cls.__init__) ret = list() for arg in args: if arg['name'] == 'data_type_def': ret.append({'name': 'my_data_type_def', 'type': str, 'doc': 'the NWB data type this spec defines', 'default': None}) elif arg['name'] == 'data_type_inc': ret.append({'name': 'my_data_type_inc', 'type': (custom_cls, str), 'doc': 'the NWB data type this spec includes', 'default': None}) else: ret.append(copy(arg)) return ret class BaseStorageOverride: __type_key = 'my_data_type' __inc_key = 'my_data_type_inc' __def_key = 'my_data_type_def' @classmethod def type_key(cls): ''' Get the key used to store data type on an instance''' return cls.__type_key @classmethod def inc_key(cls): ''' Get the key used to define a data_type include.''' return cls.__inc_key @classmethod def def_key(cls): ''' Get the key used to define a data_type definition.''' return cls.__def_key @classmethod def build_const_args(cls, spec_dict): """Extend base functionality to remap data_type_def and data_type_inc keys""" spec_dict = copy(spec_dict) proxy = super(BaseStorageOverride, cls) if proxy.inc_key() in spec_dict: spec_dict[cls.inc_key()] = spec_dict.pop(proxy.inc_key()) if proxy.def_key() in spec_dict: spec_dict[cls.def_key()] = spec_dict.pop(proxy.def_key()) ret = 
proxy.build_const_args(spec_dict) return ret @classmethod def _translate_kwargs(cls, kwargs): """Swap mydata_type_def and mydata_type_inc for data_type_def and data_type_inc, respectively""" proxy = super(BaseStorageOverride, cls) kwargs[proxy.def_key()] = kwargs.pop(cls.def_key()) kwargs[proxy.inc_key()] = kwargs.pop(cls.inc_key()) return kwargs class CustomGroupSpec(BaseStorageOverride, GroupSpec): @docval(*deepcopy(swap_inc_def(GroupSpec, 'CustomGroupSpec'))) def __init__(self, **kwargs): kwargs = self._translate_kwargs(kwargs) super().__init__(**kwargs) @classmethod def dataset_spec_cls(cls): return CustomDatasetSpec @docval(*deepcopy(swap_inc_def(GroupSpec, 'CustomGroupSpec'))) def add_group(self, **kwargs): spec = CustomGroupSpec(**kwargs) self.set_group(spec) return spec @docval(*deepcopy(swap_inc_def(DatasetSpec, 'CustomDatasetSpec'))) def add_dataset(self, **kwargs): ''' Add a new specification for a subgroup to this group specification ''' spec = CustomDatasetSpec(**kwargs) self.set_dataset(spec) return spec class CustomDatasetSpec(BaseStorageOverride, DatasetSpec): @docval(*deepcopy(swap_inc_def(DatasetSpec, 'CustomDatasetSpec'))) def __init__(self, **kwargs): kwargs = self._translate_kwargs(kwargs) super().__init__(**kwargs) class CustomSpecNamespace(SpecNamespace): __types_key = 'my_data_types' @classmethod def types_key(cls): return cls.__types_key ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1886272 hdmf-3.1.1/tests/unit/utils_test/0000755000655200065520000000000000000000000017101 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/utils_test/__init__.py0000644000655200065520000000000000000000000021200 0ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/utils_test/test_core_DataChunk.py0000644000655200065520000000220200000000000023360 0ustar00circlecicirclecifrom copy import copy, deepcopy import numpy as np from hdmf.data_utils import DataChunk from hdmf.testing import TestCase class DataChunkTests(TestCase): def setUp(self): pass def tearDown(self): pass def test_datachunk_copy(self): obj = DataChunk(data=np.arange(3), selection=np.s_[0:3]) obj_copy = copy(obj) self.assertNotEqual(id(obj), id(obj_copy)) self.assertEqual(id(obj.data), id(obj_copy.data)) self.assertEqual(id(obj.selection), id(obj_copy.selection)) def test_datachunk_deepcopy(self): obj = DataChunk(data=np.arange(3), selection=np.s_[0:3]) obj_copy = deepcopy(obj) self.assertNotEqual(id(obj), id(obj_copy)) self.assertNotEqual(id(obj.data), id(obj_copy.data)) self.assertNotEqual(id(obj.selection), id(obj_copy.selection)) def test_datachunk_astype(self): obj = DataChunk(data=np.arange(3), selection=np.s_[0:3]) newtype = np.dtype('int16') obj_astype = obj.astype(newtype) self.assertNotEqual(id(obj), id(obj_astype)) self.assertEqual(obj_astype.dtype, np.dtype(newtype)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/utils_test/test_core_DataChunkIterator.py0000644000655200065520000004476100000000000025112 0ustar00circlecicircleciimport numpy as np from hdmf.data_utils import DataChunkIterator, DataChunk from hdmf.testing import TestCase class DataChunkIteratorTests(TestCase): def setUp(self): pass def tearDown(self): pass def test_none_iter(self): """Test that DataChunkIterator __init__ sets defaults correctly and all chunks and 
recommended shapes are None. """ dci = DataChunkIterator(dtype=np.dtype('int')) self.assertIsNone(dci.maxshape) self.assertEqual(dci.dtype, np.dtype('int')) self.assertEqual(dci.buffer_size, 1) self.assertEqual(dci.iter_axis, 0) count = 0 for chunk in dci: pass self.assertEqual(count, 0) self.assertIsNone(dci.recommended_data_shape()) self.assertIsNone(dci.recommended_chunk_shape()) def test_list_none(self): """Test that DataChunkIterator has no dtype or chunks when given a list of None. """ a = [None, None, None] with self.assertRaisesWith(Exception, 'Data type could not be determined. Please specify dtype in ' 'DataChunkIterator init.'): DataChunkIterator(a) def test_list_none_dtype(self): """Test that DataChunkIterator has the passed-in dtype and no chunks when given a list of None. """ a = [None, None, None] dci = DataChunkIterator(a, dtype=np.dtype('int')) self.assertTupleEqual(dci.maxshape, (3,)) self.assertEqual(dci.dtype, np.dtype('int')) count = 0 for chunk in dci: pass self.assertEqual(count, 0) self.assertTupleEqual(dci.recommended_data_shape(), (3,)) self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_unbuffered_first_axis(self): """Test DataChunkIterator with numpy data, no buffering, and iterating on the first dimension. """ a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=1) count = 0 for chunk in dci: self.assertTupleEqual(chunk.shape, (1, 2, 3)) count += 1 self.assertEqual(count, 5) self.assertTupleEqual(dci.recommended_data_shape(), a.shape) self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_unbuffered_middle_axis(self): """Test DataChunkIterator with numpy data, no buffering, and iterating on a middle dimension. """ a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=1, iter_axis=1) count = 0 for chunk in dci: self.assertTupleEqual(chunk.shape, (5, 1, 3)) count += 1 self.assertEqual(count, 2) self.assertTupleEqual(dci.recommended_data_shape(), a.shape) self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_unbuffered_last_axis(self): """Test DataChunkIterator with numpy data, no buffering, and iterating on the last dimension. """ a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=1, iter_axis=2) count = 0 for chunk in dci: self.assertTupleEqual(chunk.shape, (5, 2, 1)) count += 1 self.assertEqual(count, 3) self.assertTupleEqual(dci.recommended_data_shape(), a.shape) self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_buffered_first_axis(self): """Test DataChunkIterator with numpy data, buffering, and iterating on the first dimension. """ a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=2) count = 0 for chunk in dci: if count < 2: self.assertTupleEqual(chunk.shape, (2, 2, 3)) else: self.assertTupleEqual(chunk.shape, (1, 2, 3)) count += 1 self.assertEqual(count, 3) self.assertTupleEqual(dci.recommended_data_shape(), a.shape) self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_buffered_middle_axis(self): """Test DataChunkIterator with numpy data, buffering, and iterating on a middle dimension. 
""" a = np.arange(45).reshape(5, 3, 3) dci = DataChunkIterator(data=a, buffer_size=2, iter_axis=1) count = 0 for chunk in dci: if count < 1: self.assertTupleEqual(chunk.shape, (5, 2, 3)) else: self.assertTupleEqual(chunk.shape, (5, 1, 3)) count += 1 self.assertEqual(count, 2) self.assertTupleEqual(dci.recommended_data_shape(), a.shape) self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_buffered_last_axis(self): """Test DataChunkIterator with numpy data, buffering, and iterating on the last dimension. """ a = np.arange(30).reshape(5, 2, 3) dci = DataChunkIterator(data=a, buffer_size=2, iter_axis=2) count = 0 for chunk in dci: if count < 1: self.assertTupleEqual(chunk.shape, (5, 2, 2)) else: self.assertTupleEqual(chunk.shape, (5, 2, 1)) count += 1 self.assertEqual(count, 2) self.assertTupleEqual(dci.recommended_data_shape(), a.shape) self.assertIsNone(dci.recommended_chunk_shape()) def test_numpy_iter_unmatched_buffer_size(self): a = np.arange(10) dci = DataChunkIterator(data=a, buffer_size=3) self.assertTupleEqual(dci.maxshape, a.shape) self.assertEqual(dci.dtype, a.dtype) count = 0 for chunk in dci: if count < 3: self.assertTupleEqual(chunk.data.shape, (3,)) else: self.assertTupleEqual(chunk.data.shape, (1,)) count += 1 self.assertEqual(count, 4) self.assertTupleEqual(dci.recommended_data_shape(), a.shape) self.assertIsNone(dci.recommended_chunk_shape()) def test_standard_iterator_unbuffered(self): dci = DataChunkIterator(data=range(10), buffer_size=1) self.assertEqual(dci.dtype, np.dtype(int)) self.assertTupleEqual(dci.maxshape, (10,)) self.assertTupleEqual(dci.recommended_data_shape(), (10,)) # Test before and after iteration count = 0 for chunk in dci: self.assertTupleEqual(chunk.data.shape, (1,)) count += 1 self.assertEqual(count, 10) self.assertTupleEqual(dci.recommended_data_shape(), (10,)) # Test before and after iteration self.assertIsNone(dci.recommended_chunk_shape()) def test_standard_iterator_unmatched_buffersized(self): dci = DataChunkIterator(data=range(10), buffer_size=3) self.assertEqual(dci.dtype, np.dtype(int)) self.assertTupleEqual(dci.maxshape, (10,)) self.assertIsNone(dci.recommended_chunk_shape()) self.assertTupleEqual(dci.recommended_data_shape(), (10,)) # Test before and after iteration count = 0 for chunk in dci: if count < 3: self.assertTupleEqual(chunk.data.shape, (3,)) else: self.assertTupleEqual(chunk.data.shape, (1,)) count += 1 self.assertEqual(count, 4) self.assertTupleEqual(dci.recommended_data_shape(), (10,)) # Test before and after iteration def test_multidimensional_list_first_axis(self): """Test DataChunkIterator with multidimensional list data, no buffering, and iterating on the first dimension. """ a = np.arange(30).reshape(5, 2, 3).tolist() dci = DataChunkIterator(a) self.assertTupleEqual(dci.maxshape, (5, 2, 3)) self.assertEqual(dci.dtype, np.dtype(int)) count = 0 for chunk in dci: self.assertTupleEqual(chunk.data.shape, (1, 2, 3)) count += 1 self.assertEqual(count, 5) self.assertTupleEqual(dci.recommended_data_shape(), (5, 2, 3)) self.assertIsNone(dci.recommended_chunk_shape()) def test_multidimensional_list_middle_axis(self): """Test DataChunkIterator with multidimensional list data, no buffering, and iterating on a middle dimension. 
""" a = np.arange(30).reshape(5, 2, 3).tolist() warn_msg = ('Iterating over an axis other than the first dimension of list or tuple data ' 'involves converting the data object to a numpy ndarray, which may incur a computational ' 'cost.') with self.assertWarnsWith(UserWarning, warn_msg): dci = DataChunkIterator(a, iter_axis=1) self.assertTupleEqual(dci.maxshape, (5, 2, 3)) self.assertEqual(dci.dtype, np.dtype(int)) count = 0 for chunk in dci: self.assertTupleEqual(chunk.data.shape, (5, 1, 3)) count += 1 self.assertEqual(count, 2) self.assertTupleEqual(dci.recommended_data_shape(), (5, 2, 3)) self.assertIsNone(dci.recommended_chunk_shape()) def test_multidimensional_list_last_axis(self): """Test DataChunkIterator with multidimensional list data, no buffering, and iterating on the last dimension. """ a = np.arange(30).reshape(5, 2, 3).tolist() warn_msg = ('Iterating over an axis other than the first dimension of list or tuple data ' 'involves converting the data object to a numpy ndarray, which may incur a computational ' 'cost.') with self.assertWarnsWith(UserWarning, warn_msg): dci = DataChunkIterator(a, iter_axis=2) self.assertTupleEqual(dci.maxshape, (5, 2, 3)) self.assertEqual(dci.dtype, np.dtype(int)) count = 0 for chunk in dci: self.assertTupleEqual(chunk.data.shape, (5, 2, 1)) count += 1 self.assertEqual(count, 3) self.assertTupleEqual(dci.recommended_data_shape(), (5, 2, 3)) self.assertIsNone(dci.recommended_chunk_shape()) def test_maxshape(self): a = np.arange(30).reshape(5, 2, 3) aiter = iter(a) daiter = DataChunkIterator.from_iterable(aiter, buffer_size=2) self.assertEqual(daiter.maxshape, (None, 2, 3)) def test_dtype(self): a = np.arange(30, dtype='int32').reshape(5, 2, 3) aiter = iter(a) daiter = DataChunkIterator.from_iterable(aiter, buffer_size=2) self.assertEqual(daiter.dtype, a.dtype) def test_sparse_data_buffer_aligned(self): a = [1, 2, 3, 4, None, None, 7, 8, None, None] dci = DataChunkIterator(a, buffer_size=2) self.assertTupleEqual(dci.maxshape, (10,)) self.assertEqual(dci.dtype, np.dtype(int)) count = 0 for chunk in dci: self.assertTupleEqual(chunk.data.shape, (2,)) self.assertEqual(len(chunk.selection), 1) self.assertEqual(chunk.selection[0], slice(chunk.data[0] - 1, chunk.data[1])) count += 1 self.assertEqual(count, 3) self.assertTupleEqual(dci.recommended_data_shape(), (10,)) self.assertIsNone(dci.recommended_chunk_shape()) def test_sparse_data_buffer_notaligned(self): a = [1, 2, 3, None, None, None, None, 8, 9, 10] dci = DataChunkIterator(a, buffer_size=2) self.assertTupleEqual(dci.maxshape, (10,)) self.assertEqual(dci.dtype, np.dtype(int)) count = 0 for chunk in dci: self.assertEqual(len(chunk.selection), 1) if count == 0: # [1, 2] self.assertListEqual(chunk.data.tolist(), [1, 2]) self.assertEqual(chunk.selection[0], slice(chunk.data[0] - 1, chunk.data[1])) elif count == 1: # [3, None] self.assertListEqual(chunk.data.tolist(), [3, ]) self.assertEqual(chunk.selection[0], slice(chunk.data[0] - 1, chunk.data[0])) elif count == 2: # [8, 9] self.assertListEqual(chunk.data.tolist(), [8, 9]) self.assertEqual(chunk.selection[0], slice(chunk.data[0] - 1, chunk.data[1])) else: # count == 3, [10] self.assertListEqual(chunk.data.tolist(), [10, ]) self.assertEqual(chunk.selection[0], slice(chunk.data[0] - 1, chunk.data[0])) count += 1 self.assertEqual(count, 4) self.assertTupleEqual(dci.recommended_data_shape(), (10,)) self.assertIsNone(dci.recommended_chunk_shape()) def test_start_with_none(self): a = [None, None, 3] dci = DataChunkIterator(a, buffer_size=2) 
self.assertTupleEqual(dci.maxshape, (3,)) self.assertEqual(dci.dtype, np.dtype(int)) count = 0 for chunk in dci: self.assertListEqual(chunk.data.tolist(), [3]) self.assertEqual(len(chunk.selection), 1) self.assertEqual(chunk.selection[0], slice(2, 3)) count += 1 self.assertEqual(count, 1) self.assertTupleEqual(dci.recommended_data_shape(), (3,)) self.assertIsNone(dci.recommended_chunk_shape()) def test_list_scalar(self): a = [3] dci = DataChunkIterator(a, buffer_size=2) self.assertTupleEqual(dci.maxshape, (1,)) self.assertEqual(dci.dtype, np.dtype(int)) count = 0 for chunk in dci: self.assertListEqual(chunk.data.tolist(), [3]) self.assertEqual(len(chunk.selection), 1) self.assertEqual(chunk.selection[0], slice(0, 1)) count += 1 self.assertEqual(count, 1) self.assertTupleEqual(dci.recommended_data_shape(), (1,)) self.assertIsNone(dci.recommended_chunk_shape()) def test_list_numpy_scalar(self): a = np.array([3]) dci = DataChunkIterator(a, buffer_size=2) self.assertTupleEqual(dci.maxshape, (1,)) self.assertEqual(dci.dtype, np.dtype(int)) count = 0 for chunk in dci: self.assertListEqual(chunk.data.tolist(), [3]) self.assertEqual(len(chunk.selection), 1) self.assertEqual(chunk.selection[0], slice(0, 1)) count += 1 self.assertEqual(count, 1) self.assertTupleEqual(dci.recommended_data_shape(), (1,)) self.assertIsNone(dci.recommended_chunk_shape()) def test_set_maxshape(self): a = np.array([3]) dci = DataChunkIterator(a, maxshape=(5, 2, 3), buffer_size=2) self.assertTupleEqual(dci.maxshape, (5, 2, 3)) self.assertEqual(dci.dtype, np.dtype(int)) count = 0 for chunk in dci: self.assertListEqual(chunk.data.tolist(), [3]) self.assertTupleEqual(chunk.selection, (slice(0, 1), slice(None), slice(None))) count += 1 self.assertEqual(count, 1) self.assertTupleEqual(dci.recommended_data_shape(), (5, 2, 3)) self.assertIsNone(dci.recommended_chunk_shape()) def test_custom_iter_first_axis(self): def my_iter(): count = 0 a = np.arange(30).reshape(5, 2, 3) while count < a.shape[0]: val = a[count, :, :] count = count + 1 yield val return dci = DataChunkIterator(data=my_iter(), buffer_size=2) count = 0 for chunk in dci: if count < 2: self.assertTupleEqual(chunk.shape, (2, 2, 3)) else: self.assertTupleEqual(chunk.shape, (1, 2, 3)) count += 1 self.assertEqual(count, 3) # self.assertTupleEqual(dci.recommended_data_shape(), (2, 2, 3)) self.assertIsNone(dci.recommended_chunk_shape()) def test_custom_iter_middle_axis(self): def my_iter(): count = 0 a = np.arange(45).reshape(5, 3, 3) while count < a.shape[1]: val = a[:, count, :] count = count + 1 yield val return dci = DataChunkIterator(data=my_iter(), buffer_size=2, iter_axis=1) count = 0 for chunk in dci: if count < 1: self.assertTupleEqual(chunk.shape, (5, 2, 3)) else: self.assertTupleEqual(chunk.shape, (5, 1, 3)) count += 1 self.assertEqual(count, 2) # self.assertTupleEqual(dci.recommended_data_shape(), (5, 2, 3)) self.assertIsNone(dci.recommended_chunk_shape()) def test_custom_iter_last_axis(self): def my_iter(): count = 0 a = np.arange(30).reshape(5, 2, 3) while count < a.shape[2]: val = a[:, :, count] count = count + 1 yield val return dci = DataChunkIterator(data=my_iter(), buffer_size=2, iter_axis=2) count = 0 for chunk in dci: if count < 1: self.assertTupleEqual(chunk.shape, (5, 2, 2)) else: self.assertTupleEqual(chunk.shape, (5, 2, 1)) count += 1 self.assertEqual(count, 2) # self.assertTupleEqual(dci.recommended_data_shape(), (5, 2, 2)) self.assertIsNone(dci.recommended_chunk_shape()) def test_custom_iter_mismatched_axis(self): def my_iter(): count = 0 a = 
np.arange(30).reshape(5, 2, 3) while count < a.shape[2]: val = a[:, :, count] count = count + 1 yield val return # iterator returns slices of size (5, 2) # because iter_axis is by default 0, these chunks will be placed along the first dimension dci = DataChunkIterator(data=my_iter(), buffer_size=2) count = 0 for chunk in dci: if count < 1: self.assertTupleEqual(chunk.shape, (2, 5, 2)) else: self.assertTupleEqual(chunk.shape, (1, 5, 2)) count += 1 self.assertEqual(count, 2) # self.assertTupleEqual(dci.recommended_data_shape(), (5, 2, 2)) self.assertIsNone(dci.recommended_chunk_shape()) class DataChunkTests(TestCase): def setUp(self): pass def tearDown(self): pass def test_len_operator_no_data(self): temp = DataChunk() self.assertEqual(len(temp), 0) def test_len_operator_with_data(self): temp = DataChunk(np.arange(10).reshape(5, 2)) self.assertEqual(len(temp), 5) def test_dtype(self): temp = DataChunk(np.arange(10).astype('int')) temp_dtype = temp.dtype self.assertEqual(temp_dtype, np.dtype('int')) def test_astype(self): temp1 = DataChunk(np.arange(10).reshape(5, 2)) temp2 = temp1.astype('float32') self.assertEqual(temp2.dtype, np.dtype('float32')) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/utils_test/test_core_DataIO.py0000644000655200065520000000331700000000000022627 0ustar00circlecicirclecifrom copy import copy, deepcopy import numpy as np from hdmf.container import Data from hdmf.data_utils import DataIO from hdmf.testing import TestCase class DataIOTests(TestCase): def setUp(self): pass def tearDown(self): pass def test_copy(self): obj = DataIO(data=[1., 2., 3.]) obj_copy = copy(obj) self.assertNotEqual(id(obj), id(obj_copy)) self.assertEqual(id(obj.data), id(obj_copy.data)) def test_deepcopy(self): obj = DataIO(data=[1., 2., 3.]) obj_copy = deepcopy(obj) self.assertNotEqual(id(obj), id(obj_copy)) self.assertNotEqual(id(obj.data), id(obj_copy.data)) def test_dataio_slice_delegation(self): indata = np.arange(30) dset = DataIO(indata) self.assertTrue(np.all(dset[2:15] == indata[2:15])) indata = np.arange(50).reshape(5, 10) dset = DataIO(indata) self.assertTrue(np.all(dset[1:3, 5:8] == indata[1:3, 5:8])) def test_set_dataio(self): """ Test that Data.set_dataio works as intended """ dataio = DataIO() data = np.arange(30).reshape(5, 2, 3) container = Data('wrapped_data', data) container.set_dataio(dataio) self.assertIs(dataio.data, data) self.assertIs(dataio, container.data) def test_set_dataio_data_already_set(self): """ Test that Data.set_dataio works as intended """ dataio = DataIO(data=np.arange(30).reshape(5, 2, 3)) data = np.arange(30).reshape(5, 2, 3) container = Data('wrapped_data', data) with self.assertRaisesWith(ValueError, "cannot overwrite 'data' on DataIO"): container.set_dataio(dataio) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/utils_test/test_core_ShapeValidator.py0000644000655200065520000002227100000000000024434 0ustar00circlecicircleciimport numpy as np from hdmf.common.table import DynamicTable, DynamicTableRegion, VectorData from hdmf.data_utils import ShapeValidatorResult, DataChunkIterator, assertEqualShape from hdmf.testing import TestCase class ShapeValidatorTests(TestCase): def setUp(self): pass def tearDown(self): pass def test_array_all_dimensions_match(self): # Test match d1 = np.arange(10).reshape(2, 5) d2 = np.arange(10).reshape(2, 5) res = assertEqualShape(d1, d2) self.assertTrue(res.result) 
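# --------------------------------------------------------------------------
# Illustrative aside (not part of the test above): assertEqualShape compares
# the shapes of two array-like objects and returns a ShapeValidatorResult
# rather than raising. The arrays below are arbitrary examples.
import numpy as np
from hdmf.data_utils import assertEqualShape

arr_a = np.arange(10).reshape(2, 5)
arr_b = np.arange(10).reshape(5, 2)
res_example = assertEqualShape(arr_a, arr_b)
print(res_example.result)      # False -- axis lengths do not match
print(res_example.error)       # 'AXIS_LEN_ERROR'
print(res_example.unmatched)   # ((0, 0), (1, 1))
# specific axes can be compared explicitly, e.g. axis 0 of arr_a vs. axis 1 of arr_b
res_axes = assertEqualShape(arr_a, arr_b, 0, 1)
print(res_axes.result)         # True -- both axes have length 2
# --------------------------------------------------------------------------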
self.assertIsNone(res.error) self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) self.assertTupleEqual(res.shape2, (2, 5)) self.assertTupleEqual(res.axes1, (0, 1)) self.assertTupleEqual(res.axes2, (0, 1)) def test_array_dimensions_mismatch(self): # Test unmatched d1 = np.arange(10).reshape(2, 5) d2 = np.arange(10).reshape(5, 2) res = assertEqualShape(d1, d2) self.assertFalse(res.result) self.assertEqual(res.error, 'AXIS_LEN_ERROR') self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ((0, 0), (1, 1))) self.assertTupleEqual(res.shape1, (2, 5)) self.assertTupleEqual(res.shape2, (5, 2)) self.assertTupleEqual(res.axes1, (0, 1)) self.assertTupleEqual(res.axes2, (0, 1)) def test_array_unequal_number_of_dimensions(self): # Test unequal num dims d1 = np.arange(10).reshape(2, 5) d2 = np.arange(20).reshape(5, 2, 2) res = assertEqualShape(d1, d2) self.assertFalse(res.result) self.assertEqual(res.error, 'NUM_AXES_ERROR') self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) self.assertTupleEqual(res.shape2, (5, 2, 2)) self.assertTupleEqual(res.axes1, (0, 1)) self.assertTupleEqual(res.axes2, (0, 1, 2)) def test_array_unequal_number_of_dimensions_check_one_axis_only(self): # Test unequal num dims compare one axis d1 = np.arange(10).reshape(2, 5) d2 = np.arange(20).reshape(2, 5, 2) res = assertEqualShape(d1, d2, 0, 0) self.assertTrue(res.result) self.assertIsNone(res.error) self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) self.assertTupleEqual(res.shape2, (2, 5, 2)) self.assertTupleEqual(res.axes1, (0,)) self.assertTupleEqual(res.axes2, (0,)) def test_array_unequal_number_of_dimensions_check_multiple_axesy(self): # Test unequal num dims compare multiple axes d1 = np.arange(10).reshape(2, 5) d2 = np.arange(20).reshape(5, 2, 2) res = assertEqualShape(d1, d2, [0, 1], [1, 0]) self.assertTrue(res.result) self.assertIsNone(res.error) self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) self.assertTupleEqual(res.shape2, (5, 2, 2)) self.assertTupleEqual(res.axes1, (0, 1)) self.assertTupleEqual(res.axes2, (1, 0)) def test_array_unequal_number_of_axes_for_comparison(self): # Test unequal num axes for comparison d1 = np.arange(10).reshape(2, 5) d2 = np.arange(20).reshape(5, 2, 2) res = assertEqualShape(d1, d2, [0, 1], 1) self.assertFalse(res.result) self.assertEqual(res.error, "NUM_AXES_ERROR") self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) self.assertTupleEqual(res.shape2, (5, 2, 2)) self.assertTupleEqual(res.axes1, (0, 1)) self.assertTupleEqual(res.axes2, (1,)) def test_array_axis_index_out_of_bounds_single_axis(self): # Test too large frist axis d1 = np.arange(10).reshape(2, 5) d2 = np.arange(20).reshape(5, 2, 2) res = assertEqualShape(d1, d2, 4, 1) self.assertFalse(res.result) self.assertEqual(res.error, 'AXIS_OUT_OF_BOUNDS') self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) self.assertTupleEqual(res.shape2, (5, 2, 2)) self.assertTupleEqual(res.axes1, (4,)) self.assertTupleEqual(res.axes2, (1,)) def test_array_axis_index_out_of_bounds_mutilple_axis(self): # Test too large second axis d1 = np.arange(10).reshape(2, 5) d2 = np.arange(20).reshape(5, 2, 2) res = 
assertEqualShape(d1, d2, [0, 1], [5, 0]) self.assertFalse(res.result) self.assertEqual(res.error, 'AXIS_OUT_OF_BOUNDS') self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) self.assertTupleEqual(res.shape2, (5, 2, 2)) self.assertTupleEqual(res.axes1, (0, 1)) self.assertTupleEqual(res.axes2, (5, 0)) def test_DataChunkIterators_match(self): # Compare data chunk iterators d1 = DataChunkIterator(data=np.arange(10).reshape(2, 5)) d2 = DataChunkIterator(data=np.arange(10).reshape(2, 5)) res = assertEqualShape(d1, d2) self.assertTrue(res.result) self.assertIsNone(res.error) self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (2, 5)) self.assertTupleEqual(res.shape2, (2, 5)) self.assertTupleEqual(res.axes1, (0, 1)) self.assertTupleEqual(res.axes2, (0, 1)) def test_DataChunkIterator_ignore_undetermined_axis(self): # Compare data chunk iterators with undetermined axis (ignore axis) d1 = DataChunkIterator(data=np.arange(10).reshape(2, 5), maxshape=(None, 5)) d2 = DataChunkIterator(data=np.arange(10).reshape(2, 5)) res = assertEqualShape(d1, d2, ignore_undetermined=True) self.assertTrue(res.result) self.assertIsNone(res.error) self.assertTupleEqual(res.ignored, ((0, 0),)) self.assertTupleEqual(res.unmatched, ()) self.assertTupleEqual(res.shape1, (None, 5)) self.assertTupleEqual(res.shape2, (2, 5)) self.assertTupleEqual(res.axes1, (0, 1)) self.assertTupleEqual(res.axes2, (0, 1)) def test_DataChunkIterator_error_on_undetermined_axis(self): # Compare data chunk iterators with undetermined axis (error on undetermined axis) d1 = DataChunkIterator(data=np.arange(10).reshape(2, 5), maxshape=(None, 5)) d2 = DataChunkIterator(data=np.arange(10).reshape(2, 5)) res = assertEqualShape(d1, d2, ignore_undetermined=False) self.assertFalse(res.result) self.assertEqual(res.error, 'AXIS_LEN_ERROR') self.assertTupleEqual(res.ignored, ()) self.assertTupleEqual(res.unmatched, ((0, 0),)) self.assertTupleEqual(res.shape1, (None, 5)) self.assertTupleEqual(res.shape2, (2, 5)) self.assertTupleEqual(res.axes1, (0, 1)) self.assertTupleEqual(res.axes2, (0, 1)) def test_DynamicTableRegion_shape_validation(self): # Create a test DynamicTable dt_spec = [ {'name': 'foo', 'description': 'foo column'}, {'name': 'bar', 'description': 'bar column'}, {'name': 'baz', 'description': 'baz column'}, ] dt_data = [ [1, 2, 3, 4, 5], [10.0, 20.0, 30.0, 40.0, 50.0], ['cat', 'dog', 'bird', 'fish', 'lizard'] ] columns = [ VectorData(name=s['name'], description=s['description'], data=d) for s, d in zip(dt_spec, dt_data) ] dt = DynamicTable("with_columns_and_data", "a test table", columns=columns) # Create test DynamicTableRegion dtr = DynamicTableRegion('dtr', [1, 2, 2], 'desc', table=dt) # Confirm that the shapes match res = assertEqualShape(dtr, np.arange(9).reshape(3, 3)) self.assertTrue(res.result) def with_table_columns(self): cols = [VectorData(**d) for d in self.spec] table = DynamicTable("with_table_columns", 'a test table', columns=cols) return table def with_columns_and_data(self): return class ShapeValidatorResultTests(TestCase): def setUp(self): pass def tearDown(self): pass def test_default_message(self): temp = ShapeValidatorResult() temp.error = 'AXIS_LEN_ERROR' self.assertEqual(temp.default_message, ShapeValidatorResult.SHAPE_ERROR[temp.error]) def test_set_error_to_illegal_type(self): temp = ShapeValidatorResult() with self.assertRaises(ValueError): temp.error = 'MY_ILLEGAL_ERROR_TYPE' def 
test_ensure_use_of_tuples_during_asignment(self): temp = ShapeValidatorResult() temp_d = [1, 2] temp_cases = ['shape1', 'shape2', 'axes1', 'axes2', 'ignored', 'unmatched'] for var in temp_cases: setattr(temp, var, temp_d) self.assertIsInstance(getattr(temp, var), tuple, var) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/utils_test/test_docval.py0000644000655200065520000012503300000000000021766 0ustar00circlecicircleciimport numpy as np from hdmf.testing import TestCase from hdmf.utils import (docval, fmt_docval_args, get_docval, getargs, popargs, AllowPositional, get_docval_macro, docval_macro) class MyTestClass(object): @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}) def basic_add(self, **kwargs): return kwargs @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'}) def basic_add2(self, **kwargs): return kwargs @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, {'name': 'arg2', 'type': 'int', 'doc': 'argument2 is a int'}, {'name': 'arg3', 'type': bool, 'doc': 'argument3 is a bool. it defaults to False', 'default': False}) def basic_add2_kw(self, **kwargs): return kwargs @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str', 'default': 'a'}, {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int', 'default': 1}) def basic_only_kw(self, **kwargs): return kwargs @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, {'name': 'arg2', 'type': 'int', 'doc': 'argument2 is a int'}, {'name': 'arg3', 'type': bool, 'doc': 'argument3 is a bool. it defaults to False', 'default': False}, allow_extra=True) def basic_add2_kw_allow_extra(self, **kwargs): return kwargs class MyTestSubclass(MyTestClass): @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'}) def basic_add(self, **kwargs): return kwargs @docval({'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'}, {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'}, {'name': 'arg3', 'type': bool, 'doc': 'argument3 is a bool. it defaults to False', 'default': False}, {'name': 'arg4', 'type': str, 'doc': 'argument4 is a str'}, {'name': 'arg5', 'type': 'float', 'doc': 'argument5 is a float'}, {'name': 'arg6', 'type': bool, 'doc': 'argument6 is a bool. it defaults to None', 'default': None}) def basic_add2_kw(self, **kwargs): return kwargs class MyChainClass(MyTestClass): @docval({'name': 'arg1', 'type': (str, 'MyChainClass'), 'doc': 'arg1 is a string or MyChainClass'}, {'name': 'arg2', 'type': ('array_data', 'MyChainClass'), 'doc': 'arg2 is array data or MyChainClass. it defaults to None', 'default': None}, {'name': 'arg3', 'type': ('array_data', 'MyChainClass'), 'doc': 'arg3 is array data or MyChainClass', 'shape': (None, 2)}, {'name': 'arg4', 'type': ('array_data', 'MyChainClass'), 'doc': 'arg3 is array data or MyChainClass. 
it defaults to None.', 'shape': (None, 2), 'default': None}) def __init__(self, **kwargs): self._arg1, self._arg2, self._arg3, self._arg4 = popargs('arg1', 'arg2', 'arg3', 'arg4', kwargs) @property def arg1(self): if isinstance(self._arg1, MyChainClass): return self._arg1.arg1 else: return self._arg1 @property def arg2(self): if isinstance(self._arg2, MyChainClass): return self._arg2.arg2 else: return self._arg2 @property def arg3(self): if isinstance(self._arg3, MyChainClass): return self._arg3.arg3 else: return self._arg3 @arg3.setter def arg3(self, val): self._arg3 = val @property def arg4(self): if isinstance(self._arg4, MyChainClass): return self._arg4.arg4 else: return self._arg4 @arg4.setter def arg4(self, val): self._arg4 = val class TestDocValidator(TestCase): def setUp(self): self.test_obj = MyTestClass() self.test_obj_sub = MyTestSubclass() def test_bad_type(self): exp_msg = (r"docval for arg1: error parsing argument type: argtype must be a type, " r"a str, a list, a tuple, or None - got ") with self.assertRaisesRegex(Exception, exp_msg): @docval({'name': 'arg1', 'type': {'a': 1}, 'doc': 'this is a bad type'}) def method(self, **kwargs): pass method(self, arg1=1234560) def test_bad_shape(self): @docval({'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape', 'shape': (None, 2)}) def method(self, **kwargs): pass with self.assertRaises(ValueError): method(self, arg1=[[1]]) with self.assertRaises(ValueError): method(self, arg1=[1]) # this should work method(self, arg1=[[1, 1]]) def test_multi_shape(self): @docval({'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape', 'shape': ((None,), (None, 2))}) def method1(self, **kwargs): pass method1(self, arg1=[[1, 1]]) method1(self, arg1=[1, 2]) with self.assertRaises(ValueError): method1(self, arg1=[[1, 1, 1]]) def test_fmt_docval_args(self): """ Test that fmt_docval_args works """ test_kwargs = { 'arg1': 'a string', 'arg2': 1, 'arg3': True, } rec_args, rec_kwargs = fmt_docval_args(self.test_obj.basic_add2_kw, test_kwargs) exp_args = ['a string', 1] self.assertListEqual(rec_args, exp_args) exp_kwargs = {'arg3': True} self.assertDictEqual(rec_kwargs, exp_kwargs) def test_fmt_docval_args_no_docval(self): """ Test that fmt_docval_args raises an error when run on function without docval """ def method1(self, **kwargs): pass with self.assertRaisesRegex(ValueError, r"no docval found on .*method1.*"): fmt_docval_args(method1, {}) def test_fmt_docval_args_allow_extra(self): """ Test that fmt_docval_args works """ test_kwargs = { 'arg1': 'a string', 'arg2': 1, 'arg3': True, 'hello': 'abc', 'list': ['abc', 1, 2, 3] } rec_args, rec_kwargs = fmt_docval_args(self.test_obj.basic_add2_kw_allow_extra, test_kwargs) exp_args = ['a string', 1] self.assertListEqual(rec_args, exp_args) exp_kwargs = {'arg3': True, 'hello': 'abc', 'list': ['abc', 1, 2, 3]} self.assertDictEqual(rec_kwargs, exp_kwargs) def test_docval_add(self): """Test that docval works with a single positional argument """ kwargs = self.test_obj.basic_add('a string') self.assertDictEqual(kwargs, {'arg1': 'a string'}) def test_docval_add_kw(self): """Test that docval works with a single positional argument passed as key-value """ kwargs = self.test_obj.basic_add(arg1='a string') self.assertDictEqual(kwargs, {'arg1': 'a string'}) def test_docval_add_missing_args(self): """Test that docval catches missing argument with a single positional argument """ with self.assertRaisesWith(TypeError, "MyTestClass.basic_add: missing argument 'arg1'"): self.test_obj.basic_add() def 
test_docval_add2(self): """Test that docval works with two positional arguments """ kwargs = self.test_obj.basic_add2('a string', 100) self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100}) def test_docval_add2_w_unicode(self): """Test that docval works with two positional arguments """ kwargs = self.test_obj.basic_add2(u'a string', 100) self.assertDictEqual(kwargs, {'arg1': u'a string', 'arg2': 100}) def test_docval_add2_kw_default(self): """Test that docval works with two positional arguments and a keyword argument when using default keyword argument value """ kwargs = self.test_obj.basic_add2_kw('a string', 100) self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': False}) def test_docval_add2_pos_as_kw(self): """Test that docval works with two positional arguments and a keyword argument when using default keyword argument value, but pass positional arguments by key-value """ kwargs = self.test_obj.basic_add2_kw(arg1='a string', arg2=100) self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': False}) def test_docval_add2_kw_kw_syntax(self): """Test that docval works with two positional arguments and a keyword argument when specifying keyword argument value with keyword syntax """ kwargs = self.test_obj.basic_add2_kw('a string', 100, arg3=True) self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': True}) def test_docval_add2_kw_all_kw_syntax(self): """Test that docval works with two positional arguments and a keyword argument when specifying all arguments by key-value """ kwargs = self.test_obj.basic_add2_kw(arg1='a string', arg2=100, arg3=True) self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': True}) def test_docval_add2_kw_pos_syntax(self): """Test that docval works with two positional arguments and a keyword argument when specifying keyword argument value with positional syntax """ kwargs = self.test_obj.basic_add2_kw('a string', 100, True) self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': True}) def test_docval_add2_kw_pos_syntax_missing_args(self): """Test that docval catches incorrect type with two positional arguments and a keyword argument when specifying keyword argument value with positional syntax """ msg = "MyTestClass.basic_add2_kw: incorrect type for 'arg2' (got 'str', expected 'int')" with self.assertRaisesWith(TypeError, msg): self.test_obj.basic_add2_kw('a string', 'bad string') def test_docval_add_sub(self): """Test that docval works with a two positional arguments, where the second is specified by the subclass implementation """ kwargs = self.test_obj_sub.basic_add('a string', 100) expected = {'arg1': 'a string', 'arg2': 100} self.assertDictEqual(kwargs, expected) def test_docval_add2_kw_default_sub(self): """Test that docval works with a four positional arguments and two keyword arguments, where two positional and one keyword argument is specified in both the parent and sublcass implementations """ kwargs = self.test_obj_sub.basic_add2_kw('a string', 100, 'another string', 200.0) expected = {'arg1': 'a string', 'arg2': 100, 'arg4': 'another string', 'arg5': 200.0, 'arg3': False, 'arg6': None} self.assertDictEqual(kwargs, expected) def test_docval_add2_kw_default_sub_missing_args(self): """Test that docval catches missing arguments with a four positional arguments and two keyword arguments, where two positional and one keyword argument is specified in both the parent and sublcass implementations, when using default values for keyword arguments """ with 
self.assertRaisesWith(TypeError, "MyTestSubclass.basic_add2_kw: missing argument 'arg5'"): self.test_obj_sub.basic_add2_kw('a string', 100, 'another string') def test_docval_add2_kw_kwsyntax_sub(self): """Test that docval works when called with four positional arguments and two keyword arguments, where two positional and one keyword argument are specified in both the parent and subclass implementations """ kwargs = self.test_obj_sub.basic_add2_kw('a string', 100, 'another string', 200.0, arg6=True) expected = {'arg1': 'a string', 'arg2': 100, 'arg4': 'another string', 'arg5': 200.0, 'arg3': False, 'arg6': True} self.assertDictEqual(kwargs, expected) def test_docval_add2_kw_kwsyntax_sub_missing_args(self): """Test that docval catches missing arguments when called with four positional arguments and two keyword arguments, where two positional and one keyword argument are specified in both the parent and subclass implementations """ with self.assertRaisesWith(TypeError, "MyTestSubclass.basic_add2_kw: missing argument 'arg5'"): self.test_obj_sub.basic_add2_kw('a string', 100, 'another string', arg6=True) def test_docval_add2_kw_kwsyntax_sub_nonetype_arg(self): """Test that docval catches NoneType when called with four positional arguments and two keyword arguments, where two positional and one keyword argument are specified in both the parent and subclass implementations """ msg = "MyTestSubclass.basic_add2_kw: None is not allowed for 'arg5' (expected 'float', not None)" with self.assertRaisesWith(TypeError, msg): self.test_obj_sub.basic_add2_kw('a string', 100, 'another string', None, arg6=True) def test_only_kw_no_args(self): """Test that docval parses arguments when only keyword arguments exist, and no arguments are specified """ kwargs = self.test_obj.basic_only_kw() self.assertDictEqual(kwargs, {'arg1': 'a', 'arg2': 1}) def test_only_kw_arg1_no_arg2(self): """Test that docval parses arguments when only keyword arguments exist, and only first argument is specified as key-value """ kwargs = self.test_obj.basic_only_kw(arg1='b') self.assertDictEqual(kwargs, {'arg1': 'b', 'arg2': 1}) def test_only_kw_arg1_pos_no_arg2(self): """Test that docval parses arguments when only keyword arguments exist, and only first argument is specified as positional argument """ kwargs = self.test_obj.basic_only_kw('b') self.assertDictEqual(kwargs, {'arg1': 'b', 'arg2': 1}) def test_only_kw_arg2_no_arg1(self): """Test that docval parses arguments when only keyword arguments exist, and only second argument is specified as key-value """ kwargs = self.test_obj.basic_only_kw(arg2=2) self.assertDictEqual(kwargs, {'arg1': 'a', 'arg2': 2}) def test_only_kw_arg1_arg2(self): """Test that docval parses arguments when only keyword arguments exist, and both arguments are specified as key-value """ kwargs = self.test_obj.basic_only_kw(arg1='b', arg2=2) self.assertDictEqual(kwargs, {'arg1': 'b', 'arg2': 2}) def test_only_kw_arg1_arg2_pos(self): """Test that docval parses arguments when only keyword arguments exist, and both arguments are specified as positional arguments """ kwargs = self.test_obj.basic_only_kw('b', 2) self.assertDictEqual(kwargs, {'arg1': 'b', 'arg2': 2}) def test_extra_kwarg(self): """Test that docval raises an error when an extra keyword argument that is not in the docval spec is passed """ with self.assertRaises(TypeError): self.test_obj.basic_add2_kw('a string', 100, bar=1000) def test_extra_args_pos_only(self): """Test that docval raises an error if too many positional arguments are 
specified """ msg = ("MyTestClass.basic_add2_kw: Expected at most 3 arguments ['arg1', 'arg2', 'arg3'], got 4: 4 positional " "and 0 keyword []") with self.assertRaisesWith(TypeError, msg): self.test_obj.basic_add2_kw('a string', 100, True, 'extra') def test_extra_args_pos_kw(self): """Test that docval raises an error if too many positional arguments are specified and a keyword arg is specified """ msg = ("MyTestClass.basic_add2_kw: Expected at most 3 arguments ['arg1', 'arg2', 'arg3'], got 4: 3 positional " "and 1 keyword ['arg3']") with self.assertRaisesWith(TypeError, msg): self.test_obj.basic_add2_kw('a string', 'extra', 100, arg3=True) def test_extra_kwargs_pos_kw(self): """Test that docval raises an error if extra keyword arguments are specified """ msg = ("MyTestClass.basic_add2_kw: Expected at most 3 arguments ['arg1', 'arg2', 'arg3'], got 4: 2 positional " "and 2 keyword ['arg3', 'extra']") with self.assertRaisesWith(TypeError, msg): self.test_obj.basic_add2_kw('a string', 100, extra='extra', arg3=True) def test_extra_args_pos_only_ok(self): """Test that docval raises an error if too many positional arguments are specified even if allow_extra is True """ msg = ("MyTestClass.basic_add2_kw_allow_extra: Expected at most 3 arguments ['arg1', 'arg2', 'arg3'], got " "4 positional") with self.assertRaisesWith(TypeError, msg): self.test_obj.basic_add2_kw_allow_extra('a string', 100, True, 'extra', extra='extra') def test_extra_args_pos_kw_ok(self): """Test that docval does not raise an error if too many keyword arguments are specified and allow_extra is True """ kwargs = self.test_obj.basic_add2_kw_allow_extra('a string', 100, True, extra='extra') self.assertDictEqual(kwargs, {'arg1': 'a string', 'arg2': 100, 'arg3': True, 'extra': 'extra'}) def test_dup_kw(self): """Test that docval raises an error if a keyword argument captures a positional argument before all positional arguments have been resolved """ with self.assertRaisesWith(TypeError, "MyTestClass.basic_add2_kw: got multiple values for argument 'arg1'"): self.test_obj.basic_add2_kw('a string', 100, arg1='extra') def test_extra_args_dup_kw(self): """Test that docval raises an error if a keyword argument captures a positional argument before all positional arguments have been resolved and allow_extra is True """ msg = "MyTestClass.basic_add2_kw_allow_extra: got multiple values for argument 'arg1'" with self.assertRaisesWith(TypeError, msg): self.test_obj.basic_add2_kw_allow_extra('a string', 100, True, arg1='extra') def test_unsupported_docval_term(self): """Test that docval does not allow setting of arguments marked as unsupported """ msg = "docval for arg1: keys ['unsupported'] are not supported by docval" with self.assertRaisesWith(Exception, msg): @docval({'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape', 'unsupported': 'hi!'}) def method(self, **kwargs): pass def test_catch_dup_names(self): """Test that docval does not allow duplicate argument names """ @docval({'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape'}, {'name': 'arg1', 'type': 'array_data', 'doc': 'this is a bad shape2'}) def method(self, **kwargs): pass msg = "TestDocValidator.test_catch_dup_names..method: The following names are duplicated: ['arg1']" with self.assertRaisesWith(ValueError, msg): method(self, arg1=[1]) def test_get_docval_all(self): """Test that get_docval returns a tuple of the docval arguments """ args = get_docval(self.test_obj.basic_add2) self.assertTupleEqual(args, ({'name': 'arg1', 'type': str, 'doc': 
'argument1 is a str'}, {'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'})) def test_get_docval_one_arg(self): """Test that get_docval returns the matching docval argument """ arg = get_docval(self.test_obj.basic_add2, 'arg2') self.assertTupleEqual(arg, ({'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'},)) def test_get_docval_two_args(self): """Test that get_docval returns the matching docval arguments in order """ args = get_docval(self.test_obj.basic_add2, 'arg2', 'arg1') self.assertTupleEqual(args, ({'name': 'arg2', 'type': int, 'doc': 'argument2 is a int'}, {'name': 'arg1', 'type': str, 'doc': 'argument1 is a str'})) def test_get_docval_missing_arg(self): """Test that get_docval throws error if the matching docval argument is not found """ with self.assertRaisesWith(ValueError, "Function basic_add2 does not have docval argument 'arg3'"): get_docval(self.test_obj.basic_add2, 'arg3') def test_get_docval_missing_args(self): """Test that get_docval throws error if the matching docval arguments is not found """ with self.assertRaisesWith(ValueError, "Function basic_add2 does not have docval argument 'arg3'"): get_docval(self.test_obj.basic_add2, 'arg3', 'arg4') def test_get_docval_missing_arg_of_many_ok(self): """Test that get_docval throws error if the matching docval arguments is not found """ with self.assertRaisesWith(ValueError, "Function basic_add2 does not have docval argument 'arg3'"): get_docval(self.test_obj.basic_add2, 'arg2', 'arg3') def test_get_docval_none(self): """Test that get_docval returns an empty tuple if there is no docval """ args = get_docval(self.test_obj.__init__) self.assertTupleEqual(args, tuple()) def test_get_docval_none_arg(self): """Test that get_docval throws error if there is no docval and an argument name is passed """ with self.assertRaisesWith(ValueError, 'Function __init__ has no docval arguments'): get_docval(self.test_obj.__init__, 'arg3') def test_bool_type(self): @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool'}) def method(self, **kwargs): return popargs('arg1', kwargs) res = method(self, arg1=True) self.assertEqual(res, True) self.assertIsInstance(res, bool) res = method(self, arg1=np.bool_(True)) self.assertEqual(res, np.bool_(True)) self.assertIsInstance(res, np.bool_) def test_bool_string_type(self): @docval({'name': 'arg1', 'type': 'bool', 'doc': 'this is a bool'}) def method(self, **kwargs): return popargs('arg1', kwargs) res = method(self, arg1=True) self.assertEqual(res, True) self.assertIsInstance(res, bool) res = method(self, arg1=np.bool_(True)) self.assertEqual(res, np.bool_(True)) self.assertIsInstance(res, np.bool_) def test_uint_type(self): """Test that docval type specification of np.uint32 works as expected.""" @docval({'name': 'arg1', 'type': np.uint32, 'doc': 'this is a uint'}) def method(self, **kwargs): return popargs('arg1', kwargs) res = method(self, arg1=np.uint32(1)) self.assertEqual(res, np.uint32(1)) self.assertIsInstance(res, np.uint32) msg = ("TestDocValidator.test_uint_type..method: incorrect type for 'arg1' (got 'uint8', expected " "'uint32')") with self.assertRaisesWith(TypeError, msg): method(self, arg1=np.uint8(1)) msg = ("TestDocValidator.test_uint_type..method: incorrect type for 'arg1' (got 'uint64', expected " "'uint32')") with self.assertRaisesWith(TypeError, msg): method(self, arg1=np.uint64(1)) def test_uint_string_type(self): """Test that docval type specification of string 'uint' matches np.uint of all available precisions.""" @docval({'name': 'arg1', 'type': 'uint', 'doc': 
'this is a uint'}) def method(self, **kwargs): return popargs('arg1', kwargs) res = method(self, arg1=np.uint(1)) self.assertEqual(res, np.uint(1)) self.assertIsInstance(res, np.uint) res = method(self, arg1=np.uint8(1)) self.assertEqual(res, np.uint8(1)) self.assertIsInstance(res, np.uint8) res = method(self, arg1=np.uint16(1)) self.assertEqual(res, np.uint16(1)) self.assertIsInstance(res, np.uint16) res = method(self, arg1=np.uint32(1)) self.assertEqual(res, np.uint32(1)) self.assertIsInstance(res, np.uint32) res = method(self, arg1=np.uint64(1)) self.assertEqual(res, np.uint64(1)) self.assertIsInstance(res, np.uint64) def test_allow_positional_warn(self): @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool'}, allow_positional=AllowPositional.WARNING) def method(self, **kwargs): return popargs('arg1', kwargs) # check that supplying a keyword arg is OK res = method(self, arg1=True) self.assertEqual(res, True) self.assertIsInstance(res, bool) # check that supplying a positional arg raises a warning msg = ('TestDocValidator.test_allow_positional_warn..method: ' 'Positional arguments are discouraged and may be forbidden in a future release.') with self.assertWarnsWith(FutureWarning, msg): method(self, True) def test_allow_positional_error(self): @docval({'name': 'arg1', 'type': bool, 'doc': 'this is a bool'}, allow_positional=AllowPositional.ERROR) def method(self, **kwargs): return popargs('arg1', kwargs) # check that supplying a keyword arg is OK res = method(self, arg1=True) self.assertEqual(res, True) self.assertIsInstance(res, bool) # check that supplying a positional arg raises an error msg = ('TestDocValidator.test_allow_positional_error..method: ' 'Only keyword arguments (e.g., func(argname=value, ...)) are allowed.') with self.assertRaisesWith(SyntaxError, msg): method(self, True) def test_enum_str(self): """Test that the basic usage of an enum check on strings works""" @docval({'name': 'arg1', 'type': str, 'doc': 'an arg', 'enum': ['a', 'b']}) # also use enum: list def method(self, **kwargs): return popargs('arg1', kwargs) self.assertEqual(method(self, 'a'), 'a') self.assertEqual(method(self, 'b'), 'b') msg = ("TestDocValidator.test_enum_str..method: " "forbidden value for 'arg1' (got 'c', expected ['a', 'b'])") with self.assertRaisesWith(ValueError, msg): method(self, 'c') def test_enum_int(self): """Test that the basic usage of an enum check on ints works""" @docval({'name': 'arg1', 'type': int, 'doc': 'an arg', 'enum': (1, 2)}) def method(self, **kwargs): return popargs('arg1', kwargs) self.assertEqual(method(self, 1), 1) self.assertEqual(method(self, 2), 2) msg = ("TestDocValidator.test_enum_int..method: " "forbidden value for 'arg1' (got 3, expected (1, 2))") with self.assertRaisesWith(ValueError, msg): method(self, 3) def test_enum_uint(self): """Test that the basic usage of an enum check on uints works""" @docval({'name': 'arg1', 'type': np.uint, 'doc': 'an arg', 'enum': (np.uint(1), np.uint(2))}) def method(self, **kwargs): return popargs('arg1', kwargs) self.assertEqual(method(self, np.uint(1)), np.uint(1)) self.assertEqual(method(self, np.uint(2)), np.uint(2)) msg = ("TestDocValidator.test_enum_uint..method: " "forbidden value for 'arg1' (got 3, expected (1, 2))") with self.assertRaisesWith(ValueError, msg): method(self, np.uint(3)) def test_enum_float(self): """Test that the basic usage of an enum check on floats works""" @docval({'name': 'arg1', 'type': float, 'doc': 'an arg', 'enum': (3.14, )}) def method(self, **kwargs): return popargs('arg1', kwargs) 
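# --- Editor's illustrative sketch (not part of the original test suite). The tests in this
# class exercise docval's 'enum' restriction and the allow_positional option; the minimal
# example below combines the two. The class _EnumDemo, its method open, and the argument
# 'mode' are hypothetical names, and the snippet assumes docval, getargs, and AllowPositional
# are importable from hdmf.utils, as they are used throughout this file.
from hdmf.utils import AllowPositional, docval, getargs

class _EnumDemo:
    @docval({'name': 'mode', 'type': str, 'doc': 'a restricted string argument', 'enum': ['r', 'w']},
            allow_positional=AllowPositional.ERROR)
    def open(self, **kwargs):
        # docval has already checked that 'mode' is a str and one of 'r' / 'w'
        return getargs('mode', kwargs)

_EnumDemo().open(mode='r')    # returns 'r'
# _EnumDemo().open('r')       # would raise SyntaxError: only keyword arguments are allowed
# _EnumDemo().open(mode='x')  # would raise ValueError: forbidden value for 'mode'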
self.assertEqual(method(self, 3.14), 3.14) msg = ("TestDocValidator.test_enum_float..method: " "forbidden value for 'arg1' (got 3.0, expected (3.14,))") with self.assertRaisesWith(ValueError, msg): method(self, 3.) def test_enum_bool_mixed(self): """Test that the basic usage of an enum check on a tuple of bool, int, float, and string works""" @docval({'name': 'arg1', 'type': (bool, int, float, str, np.uint), 'doc': 'an arg', 'enum': (True, 1, 1.0, 'true', np.uint(1))}) def method(self, **kwargs): return popargs('arg1', kwargs) self.assertEqual(method(self, True), True) self.assertEqual(method(self, 1), 1) self.assertEqual(method(self, 1.0), 1.0) self.assertEqual(method(self, 'true'), 'true') self.assertEqual(method(self, np.uint(1)), np.uint(1)) msg = ("TestDocValidator.test_enum_bool_mixed..method: " "forbidden value for 'arg1' (got 0, expected (True, 1, 1.0, 'true', 1))") with self.assertRaisesWith(ValueError, msg): method(self, 0) def test_enum_bad_type(self): """Test that docval with an enum check where the arg type includes an invalid enum type fails""" msg = ("docval for arg1: enum checking cannot be used with arg type (, , " ", , )") with self.assertRaisesWith(Exception, msg): @docval({'name': 'arg1', 'type': (bool, int, str, np.float64, object), 'doc': 'an arg', 'enum': (1, 2)}) def method(self, **kwargs): return popargs('arg1', kwargs) def test_enum_none_type(self): """Test that the basic usage of an enum check on None works""" msg = ("docval for arg1: enum checking cannot be used with arg type None") with self.assertRaisesWith(Exception, msg): @docval({'name': 'arg1', 'type': None, 'doc': 'an arg', 'enum': (True, 1, 'true')}) def method(self, **kwargs): pass def test_enum_single_allowed(self): """Test that docval with an enum check on a single value fails""" msg = ("docval for arg1: enum value must be a list or tuple (received )") with self.assertRaisesWith(Exception, msg): @docval({'name': 'arg1', 'type': str, 'doc': 'an arg', 'enum': 'only one value'}) def method(self, **kwargs): pass def test_enum_str_default(self): """Test that docval with an enum check on strings and a default value works""" @docval({'name': 'arg1', 'type': str, 'doc': 'an arg', 'default': 'a', 'enum': ['a', 'b']}) def method(self, **kwargs): return popargs('arg1', kwargs) self.assertEqual(method(self), 'a') msg = ("TestDocValidator.test_enum_str_default..method: " "forbidden value for 'arg1' (got 'c', expected ['a', 'b'])") with self.assertRaisesWith(ValueError, msg): method(self, 'c') def test_enum_str_none_default(self): """Test that docval with an enum check on strings and a None default value works""" @docval({'name': 'arg1', 'type': str, 'doc': 'an arg', 'default': None, 'enum': ['a', 'b']}) def method(self, **kwargs): return popargs('arg1', kwargs) self.assertIsNone(method(self)) def test_enum_forbidden_values(self): """Test that docval with enum values that include a forbidden type fails""" msg = ("docval for arg1: enum values are of types not allowed by arg type " "(got [, ], expected )") with self.assertRaisesWith(Exception, msg): @docval({'name': 'arg1', 'type': bool, 'doc': 'an arg', 'enum': (True, [])}) def method(self, **kwargs): pass class TestDocValidatorChain(TestCase): def setUp(self): self.obj1 = MyChainClass('base', [[1, 2], [3, 4], [5, 6]], [[10, 20]]) # note that self.obj1.arg3 == [[1, 2], [3, 4], [5, 6]] def test_type_arg(self): """Test that passing an object for an argument that allows a specific type works""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]]) 
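# --- Editor's note (added comments, not part of the original file). The chain tests in this
# class depend on a docval behavior that is easy to miss: when the value passed for a
# shape-checked argument is itself an instance of the accepted type (here MyChainClass),
# docval validates the shape of the attribute with the same name as the argument (e.g.
# value.arg3 for 'arg3') rather than the shape of the object itself. The properties defined
# on MyChainClass above unwrap those attributes, which is what lets one instance stand in
# for a raw (N, 2) array in the tests below.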
self.assertEqual(obj2.arg1, 'base') def test_type_arg_wrong_type(self): """Test that passing an object for an argument that does not match a specific type raises an error""" err_msg = "MyChainClass.__init__: incorrect type for 'arg1' (got 'object', expected 'str or MyChainClass')" with self.assertRaisesWith(TypeError, err_msg): MyChainClass(object(), [[10, 20], [30, 40], [50, 60]], [[10, 20]]) def test_shape_valid_unpack(self): """Test that passing an object for an argument with required shape tests the shape of object.argument""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]]) obj3 = MyChainClass(self.obj1, obj2, [[100, 200]]) self.assertListEqual(obj3.arg3, obj2.arg3) def test_shape_invalid_unpack(self): """Test that passing an object for an argument with required shape and object.argument has an invalid shape raises an error""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]]) # change arg3 of obj2 to fail the required shape - contrived, but could happen because datasets can change # shape after an object is initialized obj2.arg3 = [10, 20, 30] err_msg = "MyChainClass.__init__: incorrect shape for 'arg3' (got '(3,)', expected '(None, 2)')" with self.assertRaisesWith(ValueError, err_msg): MyChainClass(self.obj1, obj2, [[100, 200]]) def test_shape_none_unpack(self): """Test that passing an object for an argument with required shape and object.argument is None is OK""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]]) obj2.arg3 = None obj3 = MyChainClass(self.obj1, obj2, [[100, 200]]) self.assertIsNone(obj3.arg3) def test_shape_other_unpack(self): """Test that passing an object for an argument with required shape and object.argument is an object without an argument attribute raises an error""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], [[10, 20]]) obj2.arg3 = object() err_msg = (r"cannot check shape of object '' for argument 'arg3' " r"\(expected shape '\(None, 2\)'\)") with self.assertRaisesRegex(ValueError, err_msg): MyChainClass(self.obj1, obj2, [[100, 200]]) def test_shape_valid_unpack_default(self): """Test that passing an object for an argument with required shape and a default value tests the shape of object.argument""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], arg4=[[10, 20]]) obj3 = MyChainClass(self.obj1, [[100, 200], [300, 400], [500, 600]], arg4=obj2) self.assertListEqual(obj3.arg4, obj2.arg4) def test_shape_invalid_unpack_default(self): """Test that passing an object for an argument with required shape and a default value and object.argument has an invalid shape raises an error""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], arg4=[[10, 20]]) # change arg3 of obj2 to fail the required shape - contrived, but could happen because datasets can change # shape after an object is initialized obj2.arg4 = [10, 20, 30] err_msg = "MyChainClass.__init__: incorrect shape for 'arg4' (got '(3,)', expected '(None, 2)')" with self.assertRaisesWith(ValueError, err_msg): MyChainClass(self.obj1, [[100, 200], [300, 400], [500, 600]], arg4=obj2) def test_shape_none_unpack_default(self): """Test that passing an object for an argument with required shape and a default value and object.argument is an object without an argument attribute raises an error""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], arg4=[[10, 20]]) # change arg3 of obj2 to fail the required shape - contrived, but could happen because datasets can change # shape after an object is 
initialized obj2.arg4 = None obj3 = MyChainClass(self.obj1, [[100, 200], [300, 400], [500, 600]], arg4=obj2) self.assertIsNone(obj3.arg4) def test_shape_other_unpack_default(self): """Test that passing an object for an argument with required shape and a default value and object.argument is None is OK""" obj2 = MyChainClass(self.obj1, [[10, 20], [30, 40], [50, 60]], arg4=[[10, 20]]) # change arg3 of obj2 to fail the required shape - contrived, but could happen because datasets can change # shape after an object is initialized obj2.arg4 = object() err_msg = (r"cannot check shape of object '' for argument 'arg4' " r"\(expected shape '\(None, 2\)'\)") with self.assertRaisesRegex(ValueError, err_msg): MyChainClass(self.obj1, [[100, 200], [300, 400], [500, 600]], arg4=obj2) class TestGetargs(TestCase): """Test the getargs function and its error conditions.""" def test_one_arg_first(self): kwargs = {'a': 1, 'b': None} expected_kwargs = kwargs.copy() res = getargs('a', kwargs) self.assertEqual(res, 1) self.assertDictEqual(kwargs, expected_kwargs) def test_one_arg_second(self): kwargs = {'a': 1, 'b': None} expected_kwargs = kwargs.copy() res = getargs('b', kwargs) self.assertEqual(res, None) self.assertDictEqual(kwargs, expected_kwargs) def test_many_args_get_some(self): kwargs = {'a': 1, 'b': None, 'c': 3} expected_kwargs = kwargs.copy() res = getargs('a', 'c', kwargs) self.assertListEqual(res, [1, 3]) self.assertDictEqual(kwargs, expected_kwargs) def test_many_args_get_all(self): kwargs = {'a': 1, 'b': None, 'c': 3} expected_kwargs = kwargs.copy() res = getargs('a', 'b', 'c', kwargs) self.assertListEqual(res, [1, None, 3]) self.assertDictEqual(kwargs, expected_kwargs) def test_many_args_reverse(self): kwargs = {'a': 1, 'b': None, 'c': 3} expected_kwargs = kwargs.copy() res = getargs('c', 'b', 'a', kwargs) self.assertListEqual(res, [3, None, 1]) self.assertDictEqual(kwargs, expected_kwargs) def test_many_args_unpack(self): kwargs = {'a': 1, 'b': None, 'c': 3} expected_kwargs = kwargs.copy() res1, res2, res3 = getargs('a', 'b', 'c', kwargs) self.assertEqual(res1, 1) self.assertEqual(res2, None) self.assertEqual(res3, 3) self.assertDictEqual(kwargs, expected_kwargs) def test_too_few_args(self): kwargs = {'a': 1, 'b': None} msg = 'Must supply at least one key and a dict' with self.assertRaisesWith(ValueError, msg): getargs(kwargs) def test_last_arg_not_dict(self): kwargs = {'a': 1, 'b': None} msg = 'Last argument must be a dict' with self.assertRaisesWith(ValueError, msg): getargs(kwargs, 'a') def test_arg_not_found_one_arg(self): kwargs = {'a': 1, 'b': None} msg = "Argument not found in dict: 'c'" with self.assertRaisesWith(ValueError, msg): getargs('c', kwargs) def test_arg_not_found_many_args(self): kwargs = {'a': 1, 'b': None} msg = "Argument not found in dict: 'c'" with self.assertRaisesWith(ValueError, msg): getargs('a', 'c', kwargs) class TestPopargs(TestCase): """Test the popargs function and its error conditions.""" def test_one_arg_first(self): kwargs = {'a': 1, 'b': None} res = popargs('a', kwargs) self.assertEqual(res, 1) self.assertDictEqual(kwargs, {'b': None}) def test_one_arg_second(self): kwargs = {'a': 1, 'b': None} res = popargs('b', kwargs) self.assertEqual(res, None) self.assertDictEqual(kwargs, {'a': 1}) def test_many_args_pop_some(self): kwargs = {'a': 1, 'b': None, 'c': 3} res = popargs('a', 'c', kwargs) self.assertListEqual(res, [1, 3]) self.assertDictEqual(kwargs, {'b': None}) def test_many_args_pop_all(self): kwargs = {'a': 1, 'b': None, 'c': 3} res = popargs('a', 'b', 'c', 
kwargs) self.assertListEqual(res, [1, None, 3]) self.assertDictEqual(kwargs, {}) def test_many_args_reverse(self): kwargs = {'a': 1, 'b': None, 'c': 3} res = popargs('c', 'b', 'a', kwargs) self.assertListEqual(res, [3, None, 1]) self.assertDictEqual(kwargs, {}) def test_many_args_unpack(self): kwargs = {'a': 1, 'b': None, 'c': 3} res1, res2, res3 = popargs('a', 'b', 'c', kwargs) self.assertEqual(res1, 1) self.assertEqual(res2, None) self.assertEqual(res3, 3) self.assertDictEqual(kwargs, {}) def test_too_few_args(self): kwargs = {'a': 1, 'b': None} msg = 'Must supply at least one key and a dict' with self.assertRaisesWith(ValueError, msg): popargs(kwargs) def test_last_arg_not_dict(self): kwargs = {'a': 1, 'b': None} msg = 'Last argument must be a dict' with self.assertRaisesWith(ValueError, msg): popargs(kwargs, 'a') def test_arg_not_found_one_arg(self): kwargs = {'a': 1, 'b': None} msg = "Argument not found in dict: 'c'" with self.assertRaisesWith(ValueError, msg): popargs('c', kwargs) def test_arg_not_found_many_args(self): kwargs = {'a': 1, 'b': None} msg = "Argument not found in dict: 'c'" with self.assertRaisesWith(ValueError, msg): popargs('a', 'c', kwargs) class TestMacro(TestCase): def test_macro(self): self.assertTrue(isinstance(get_docval_macro(), dict)) self.assertSetEqual(set(get_docval_macro().keys()), {'array_data', 'scalar_data', 'data'}) self.assertTupleEqual(get_docval_macro('scalar_data'), (str, int, float, bytes, bool)) @docval_macro('scalar_data') class Dummy1: pass self.assertTupleEqual(get_docval_macro('scalar_data'), (str, int, float, bytes, bool, Dummy1)) @docval_macro('dummy') class Dummy2: pass self.assertTupleEqual(get_docval_macro('dummy'), (Dummy2, )) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/utils_test/test_labelleddict.py0000644000655200065520000002466200000000000023134 0ustar00circlecicirclecifrom hdmf.testing import TestCase from hdmf.utils import LabelledDict class MyTestClass: def __init__(self, prop1, prop2): self._prop1 = prop1 self._prop2 = prop2 @property def prop1(self): return self._prop1 @property def prop2(self): return self._prop2 class TestLabelledDict(TestCase): def test_constructor(self): """Test that constructor sets arguments properly.""" ld = LabelledDict(label='all_objects', key_attr='prop1') self.assertEqual(ld.label, 'all_objects') self.assertEqual(ld.key_attr, 'prop1') def test_constructor_default(self): """Test that constructor sets default key attribute.""" ld = LabelledDict(label='all_objects') self.assertEqual(ld.key_attr, 'name') def test_set_key_attr(self): """Test that the key attribute cannot be set after initialization.""" ld = LabelledDict(label='all_objects') with self.assertRaisesWith(AttributeError, "can't set attribute"): ld.key_attr = 'another_name' def test_getitem_unknown_val(self): """Test that dict[unknown_key] where the key unknown_key is not in the dict raises an error.""" ld = LabelledDict(label='all_objects', key_attr='prop1') with self.assertRaisesWith(KeyError, "'unknown_key'"): ld['unknown_key'] def test_getitem_eqeq_unknown_val(self): """Test that dict[unknown_attr == val] where there are no query matches returns an empty set.""" ld = LabelledDict(label='all_objects', key_attr='prop1') self.assertSetEqual(ld['unknown_attr == val'], set()) def test_getitem_eqeq_other_key(self): """Test that dict[other_attr == val] where there are no query matches returns an empty set.""" ld = LabelledDict(label='all_objects', key_attr='prop1') 
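# --- Editor's illustrative sketch (not part of the original test suite). A compact usage
# example of the LabelledDict behavior exercised in this file; the Node class and its
# attributes are hypothetical stand-ins for the MyTestClass fixture above.
from hdmf.utils import LabelledDict

class Node:
    def __init__(self, name, color):
        self.name = name
        self.color = color

nodes = LabelledDict(label='nodes', key_attr='name')
n1, n2 = Node('a', 'red'), Node('b', 'red')
nodes.add(n1)                              # keyed automatically by n1.name
nodes.add(n2)
assert nodes['a'] is n1                    # plain key lookup
assert nodes['name == a'] is n1            # querying on key_attr behaves like a key lookup
assert nodes['color == red'] == {n1, n2}   # other attributes return the set of matches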
self.assertSetEqual(ld['prop2 == val'], set()) def test_getitem_eqeq_no_key_attr(self): """Test that dict[key_attr == val] raises an error if key_attr is not given.""" ld = LabelledDict(label='all_objects', key_attr='prop1') with self.assertRaisesWith(ValueError, "An attribute name is required before '=='."): ld[' == unknown_key'] def test_getitem_eqeq_no_val(self): """Test that dict[key_attr == val] raises an error if val is not given.""" ld = LabelledDict(label='all_objects', key_attr='prop1') with self.assertRaisesWith(ValueError, "A value is required after '=='."): ld['prop1 == '] def test_getitem_eqeq_no_key_attr_no_val(self): """Test that dict[key_attr == val] raises an error if key_attr is not given and val is not given.""" ld = LabelledDict(label='all_objects', key_attr='prop1') with self.assertRaisesWith(ValueError, "An attribute name is required before '=='."): ld[' == '] def test_add_basic(self): """Test add method on object with correct key_attr.""" ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') ld.add(obj1) self.assertIs(ld['a'], obj1) def test_add_value_missing_key(self): """Test that add raises an error if the value being set does not have the attribute key_attr.""" ld = LabelledDict(label='all_objects', key_attr='unknown_key') obj1 = MyTestClass('a', 'b') err_msg = r"Cannot set value '<.*>' in LabelledDict\. Value must have attribute 'unknown_key'\." with self.assertRaisesRegex(ValueError, err_msg): ld.add(obj1) def test_setitem_getitem_basic(self): """Test that setitem and getitem properly set and get the object.""" ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') ld.add(obj1) self.assertIs(ld['a'], obj1) def test_setitem_value_missing_key(self): """Test that setitem raises an error if the value being set does not have the attribute key_attr.""" ld = LabelledDict(label='all_objects', key_attr='unknown_key') obj1 = MyTestClass('a', 'b') err_msg = r"Cannot set value '<.*>' in LabelledDict\. Value must have attribute 'unknown_key'\." with self.assertRaisesRegex(ValueError, err_msg): ld['a'] = obj1 def test_setitem_value_inconsistent_key(self): """Test that setitem raises an error if the value being set has an inconsistent key.""" ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') err_msg = r"Key 'b' must equal attribute 'prop1' of '<.*>'\." with self.assertRaisesRegex(KeyError, err_msg): ld['b'] = obj1 def test_setitem_value_duplicate_key(self): """Test that setitem raises an error if the key already exists in the dict.""" ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') obj2 = MyTestClass('a', 'c') ld['a'] = obj1 err_msg = "Key 'a' is already in this dict. Cannot reset items in a LabelledDict." 
with self.assertRaisesWith(TypeError, err_msg): ld['a'] = obj2 def test_add_callable(self): """Test that add properly adds the object and calls the add_callable function.""" self.signal = None def func(v): self.signal = v ld = LabelledDict(label='all_objects', key_attr='prop1', add_callable=func) obj1 = MyTestClass('a', 'b') ld.add(obj1) self.assertIs(ld['a'], obj1) self.assertIs(self.signal, obj1) def test_setitem_callable(self): """Test that setitem properly sets the object and calls the add_callable function.""" self.signal = None def func(v): self.signal = v ld = LabelledDict(label='all_objects', key_attr='prop1', add_callable=func) obj1 = MyTestClass('a', 'b') ld['a'] = obj1 self.assertIs(ld['a'], obj1) self.assertIs(self.signal, obj1) def test_getitem_eqeq_nonempty(self): """Test that dict[key_attr == val] returns the single matching object.""" ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') ld.add(obj1) self.assertIs(ld['prop1 == a'], obj1) def test_getitem_eqeq_nonempty_key_attr_no_match(self): """Test that dict[key_attr == unknown_val] where a matching value is not found raises a KeyError.""" ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') ld.add(obj1) with self.assertRaisesWith(KeyError, "'unknown_val'"): ld['prop1 == unknown_val'] # same as ld['unknown_val'] def test_getitem_eqeq_nonempty_unknown_attr(self): """Test that dict[unknown_attr == val] where unknown_attr is not a field on the values raises an error.""" ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') ld['a'] = obj1 self.assertSetEqual(ld['unknown_attr == unknown_val'], set()) def test_getitem_nonempty_other_key(self): """Test that dict[other_key == val] returns a set of matching objects.""" ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') obj2 = MyTestClass('d', 'b') obj3 = MyTestClass('f', 'e') ld.add(obj1) ld.add(obj2) ld.add(obj3) self.assertSetEqual(ld['prop2 == b'], {obj1, obj2}) def test_pop_nocallback(self): ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') ld.add(obj1) ret = ld.pop('a') self.assertEqual(ret, obj1) self.assertEqual(ld, dict()) def test_pop_callback(self): self.signal = None def func(v): self.signal = v ld = LabelledDict(label='all_objects', key_attr='prop1', remove_callable=func) obj1 = MyTestClass('a', 'b') ld.add(obj1) ret = ld.pop('a') self.assertEqual(ret, obj1) self.assertEqual(self.signal, obj1) self.assertEqual(ld, dict()) def test_popitem_nocallback(self): ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') ld.add(obj1) ret = ld.popitem() self.assertEqual(ret, ('a', obj1)) self.assertEqual(ld, dict()) def test_popitem_callback(self): self.signal = None def func(v): self.signal = v ld = LabelledDict(label='all_objects', key_attr='prop1', remove_callable=func) obj1 = MyTestClass('a', 'b') ld.add(obj1) ret = ld.popitem() self.assertEqual(ret, ('a', obj1)) self.assertEqual(self.signal, obj1) self.assertEqual(ld, dict()) def test_clear_nocallback(self): ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') obj2 = MyTestClass('d', 'b') ld.add(obj1) ld.add(obj2) ld.clear() self.assertEqual(ld, dict()) def test_clear_callback(self): self.signal = set() def func(v): self.signal.add(v) ld = LabelledDict(label='all_objects', key_attr='prop1', remove_callable=func) obj1 = MyTestClass('a', 'b') obj2 = MyTestClass('d', 'b') ld.add(obj1) ld.add(obj2) 
ld.clear() self.assertSetEqual(self.signal, {obj2, obj1}) self.assertEqual(ld, dict()) def test_delitem_nocallback(self): ld = LabelledDict(label='all_objects', key_attr='prop1') obj1 = MyTestClass('a', 'b') ld.add(obj1) del ld['a'] self.assertEqual(ld, dict()) def test_delitem_callback(self): self.signal = None def func(v): self.signal = v ld = LabelledDict(label='all_objects', key_attr='prop1', remove_callable=func) obj1 = MyTestClass('a', 'b') ld.add(obj1) del ld['a'] self.assertEqual(self.signal, obj1) self.assertEqual(ld, dict()) def test_update_callback(self): ld = LabelledDict(label='all_objects', key_attr='prop1') with self.assertRaisesWith(TypeError, "update is not supported for LabelledDict"): ld.update(object()) def test_setdefault_callback(self): ld = LabelledDict(label='all_objects', key_attr='prop1') with self.assertRaisesWith(TypeError, "setdefault is not supported for LabelledDict"): ld.setdefault(object()) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/utils_test/test_utils.py0000644000655200065520000001560700000000000021663 0ustar00circlecicircleciimport os import h5py import numpy as np from hdmf.data_utils import DataChunkIterator, DataIO from hdmf.testing import TestCase from hdmf.utils import get_data_shape, to_uint_array class TestGetDataShape(TestCase): def test_h5dataset(self): """Test get_data_shape on h5py.Datasets of various shapes and maxshape.""" path = 'test_get_data_shape.h5' with h5py.File(path, 'w') as f: dset = f.create_dataset('data', data=((1, 2), (3, 4), (5, 6))) res = get_data_shape(dset) self.assertTupleEqual(res, (3, 2)) dset = f.create_dataset('shape', shape=(3, 2)) res = get_data_shape(dset) self.assertTupleEqual(res, (3, 2)) # test that maxshape takes priority dset = f.create_dataset('shape_maxshape', shape=(3, 2), maxshape=(None, 100)) res = get_data_shape(dset) self.assertTupleEqual(res, (None, 100)) os.remove(path) def test_dci(self): """Test get_data_shape on DataChunkIterators of various shapes and maxshape.""" dci = DataChunkIterator(dtype=np.dtype(int)) res = get_data_shape(dci) self.assertIsNone(res) dci = DataChunkIterator(data=[1, 2]) res = get_data_shape(dci) self.assertTupleEqual(res, (2, )) dci = DataChunkIterator(data=[[1, 2], [3, 4], [5, 6]]) res = get_data_shape(dci) self.assertTupleEqual(res, (3, 2)) # test that maxshape takes priority dci = DataChunkIterator(data=[[1, 2], [3, 4], [5, 6]], maxshape=(None, 100)) res = get_data_shape(dci) self.assertTupleEqual(res, (None, 100)) def test_dataio(self): """Test get_data_shape on DataIO of various shapes and maxshape.""" dio = DataIO(data=[1, 2]) res = get_data_shape(dio) self.assertTupleEqual(res, (2, )) dio = DataIO(data=[[1, 2], [3, 4], [5, 6]]) res = get_data_shape(dio) self.assertTupleEqual(res, (3, 2)) dio = DataIO(data=np.array([[1, 2], [3, 4], [5, 6]])) res = get_data_shape(dio) self.assertTupleEqual(res, (3, 2)) def test_list(self): """Test get_data_shape on lists of various shapes.""" res = get_data_shape(list()) self.assertTupleEqual(res, (0, )) res = get_data_shape([1, 2]) self.assertTupleEqual(res, (2, )) res = get_data_shape([[1, 2], [3, 4], [5, 6]]) self.assertTupleEqual(res, (3, 2)) def test_tuple(self): """Test get_data_shape on tuples of various shapes.""" res = get_data_shape(tuple()) self.assertTupleEqual(res, (0, )) res = get_data_shape((1, 2)) self.assertTupleEqual(res, (2, )) res = get_data_shape(((1, 2), (3, 4), (5, 6))) self.assertTupleEqual(res, (3, 2)) def test_nparray(self): """Test 
get_data_shape on numpy arrays of various shapes.""" res = get_data_shape(np.empty([])) self.assertTupleEqual(res, tuple()) res = get_data_shape(np.array([])) self.assertTupleEqual(res, (0, )) res = get_data_shape(np.array([1, 2])) self.assertTupleEqual(res, (2, )) res = get_data_shape(np.array([[1, 2], [3, 4], [5, 6]])) self.assertTupleEqual(res, (3, 2)) def test_other(self): """Test get_data_shape on miscellaneous edge cases.""" res = get_data_shape(dict()) self.assertIsNone(res) res = get_data_shape(None) self.assertIsNone(res) res = get_data_shape([None, None]) self.assertTupleEqual(res, (2, )) res = get_data_shape(object()) self.assertIsNone(res) res = get_data_shape([object(), object()]) self.assertTupleEqual(res, (2, )) def test_string(self): """Test get_data_shape on strings and collections of strings.""" res = get_data_shape('abc') self.assertIsNone(res) res = get_data_shape(('a', 'b')) self.assertTupleEqual(res, (2, )) res = get_data_shape((('a', 'b'), ('c', 'd'), ('e', 'f'))) self.assertTupleEqual(res, (3, 2)) def test_set(self): """Test get_data_shape on sets, which have __len__ but are not subscriptable.""" res = get_data_shape(set()) self.assertTupleEqual(res, (0, )) res = get_data_shape({1, 2}) self.assertTupleEqual(res, (2, )) def test_arbitrary_iterable_with_len(self): """Test get_data_shape with strict_no_data_load=True on an arbitrary iterable object with __len__.""" class MyIterable: """Iterable class without shape or maxshape, where loading the first element raises an error.""" def __len__(self): return 10 def __iter__(self): return self def __next__(self): raise DataLoadedError() class DataLoadedError(Exception): pass data = MyIterable() with self.assertRaises(DataLoadedError): get_data_shape(data) # test that data is loaded res = get_data_shape(data, strict_no_data_load=True) # no error raised means data was not loaded self.assertIsNone(res) def test_strict_no_data_load(self): """Test get_data_shape with strict_no_data_load=True on nested lists/tuples is the same as when it is False.""" res = get_data_shape([[1, 2], [3, 4], [5, 6]], strict_no_data_load=True) self.assertTupleEqual(res, (3, 2)) res = get_data_shape(((1, 2), (3, 4), (5, 6)), strict_no_data_load=True) self.assertTupleEqual(res, (3, 2)) class TestToUintArray(TestCase): def test_ndarray_uint(self): arr = np.array([0, 1, 2], dtype=np.uint32) res = to_uint_array(arr) np.testing.assert_array_equal(res, arr) def test_ndarray_int(self): arr = np.array([0, 1, 2], dtype=np.int32) res = to_uint_array(arr) np.testing.assert_array_equal(res, arr) def test_ndarray_int_neg(self): arr = np.array([0, -1, 2], dtype=np.int32) with self.assertRaisesWith(ValueError, 'Cannot convert negative integer values to uint.'): to_uint_array(arr) def test_ndarray_float(self): arr = np.array([0, 1, 2], dtype=np.float64) with self.assertRaisesWith(ValueError, 'Cannot convert array of dtype float64 to uint.'): to_uint_array(arr) def test_list_int(self): arr = [0, 1, 2] res = to_uint_array(arr) expected = np.array([0, 1, 2], dtype=np.uint32) np.testing.assert_array_equal(res, expected) def test_list_int_neg(self): arr = [0, -1, 2] with self.assertRaisesWith(ValueError, 'Cannot convert negative integer values to uint.'): to_uint_array(arr) def test_list_float(self): arr = [0., 1., 2.] 
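# --- Editor's illustrative sketch (not part of the original test suite). A few representative
# calls matching the behavior verified in the tests above; the values in the trailing comments
# mirror the assertions in this file.
from hdmf.utils import get_data_shape, to_uint_array

get_data_shape([[1, 2], [3, 4], [5, 6]])                    # (3, 2)
get_data_shape('abc')                                       # None: strings are treated as scalars
get_data_shape(((1, 2), (3, 4)), strict_no_data_load=True)  # (2, 2), same as the default for nested tuples
to_uint_array([0, 1, 2])                                    # array([0, 1, 2], dtype=uint32)
# to_uint_array([0, -1, 2])                                 # raises ValueError: negative values cannot be converted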
with self.assertRaisesWith(ValueError, 'Cannot convert array of dtype float64 to uint.'): to_uint_array(arr) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1627603655.1886272 hdmf-3.1.1/tests/unit/validator_tests/0000755000655200065520000000000000000000000020111 5ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/validator_tests/__init__.py0000644000655200065520000000000000000000000022210 0ustar00circlecicircleci././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/validator_tests/test_errors.py0000644000655200065520000000423300000000000023040 0ustar00circlecicirclecifrom unittest import TestCase from hdmf.validate.errors import Error class TestErrorEquality(TestCase): def test_self_equality(self): """Verify that one error equals itself""" error = Error('foo', 'bad thing', 'a.b.c') self.assertEqual(error, error) def test_equality_with_same_field_values(self): """Verify that two errors with the same field values are equal""" err1 = Error('foo', 'bad thing', 'a.b.c') err2 = Error('foo', 'bad thing', 'a.b.c') self.assertEqual(err1, err2) def test_not_equal_with_different_reason(self): """Verify that two errors with a different reason are not equal""" err1 = Error('foo', 'bad thing', 'a.b.c') err2 = Error('foo', 'something else', 'a.b.c') self.assertNotEqual(err1, err2) def test_not_equal_with_different_name(self): """Verify that two errors with a different name are not equal""" err1 = Error('foo', 'bad thing', 'a.b.c') err2 = Error('bar', 'bad thing', 'a.b.c') self.assertNotEqual(err1, err2) def test_not_equal_with_different_location(self): """Verify that two errors with a different location are not equal""" err1 = Error('foo', 'bad thing', 'a.b.c') err2 = Error('foo', 'bad thing', 'd.e.f') self.assertNotEqual(err1, err2) def test_equal_with_no_location(self): """Verify that two errors with no location but the same name are equal""" err1 = Error('foo', 'bad thing') err2 = Error('foo', 'bad thing') self.assertEqual(err1, err2) def test_not_equal_with_overlapping_name_when_no_location(self): """Verify that two errors with an overlapping name but no location are not equal """ err1 = Error('foo', 'bad thing') err2 = Error('x/y/foo', 'bad thing') self.assertNotEqual(err1, err2) def test_equal_with_overlapping_name_when_location_present(self): """Verify that two errors with an overlapping name and a location are equal""" err1 = Error('foo', 'bad thing', 'a.b.c') err2 = Error('x/y/foo', 'bad thing', 'a.b.c') self.assertEqual(err1, err2) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tests/unit/validator_tests/test_validate.py0000644000655200065520000014736700000000000023335 0ustar00circlecicirclecifrom abc import ABCMeta, abstractmethod from datetime import datetime from unittest import mock, skip import numpy as np from dateutil.tz import tzlocal from hdmf.build import GroupBuilder, DatasetBuilder, LinkBuilder, ReferenceBuilder, TypeMap, BuildManager from hdmf.spec import (GroupSpec, AttributeSpec, DatasetSpec, SpecCatalog, SpecNamespace, LinkSpec, RefSpec, NamespaceCatalog, DtypeSpec) from hdmf.spec.spec import ONE_OR_MANY, ZERO_OR_MANY, ZERO_OR_ONE from hdmf.testing import TestCase, remove_test_file from hdmf.validate import ValidatorMap from hdmf.validate.errors import (DtypeError, MissingError, ExpectedArrayError, MissingDataType, IncorrectQuantityError, 
IllegalLinkError) from hdmf.backends.hdf5 import HDF5IO CORE_NAMESPACE = 'test_core' class ValidatorTestBase(TestCase, metaclass=ABCMeta): def setUp(self): spec_catalog = SpecCatalog() for spec in self.getSpecs(): spec_catalog.register_spec(spec, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) self.vmap = ValidatorMap(self.namespace) @abstractmethod def getSpecs(self): pass def assertValidationError(self, error, type_, name=None, reason=None): """Assert that a validation Error matches expectations""" self.assertIsInstance(error, type_) if name is not None: self.assertEqual(error.name, name) if reason is not None: self.assertEqual(error.reason, reason) class TestEmptySpec(ValidatorTestBase): def getSpecs(self): return (GroupSpec('A test group specification with a data type', data_type_def='Bar'),) def test_valid(self): builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar'}) validator = self.vmap.get_validator('Bar') result = validator.validate(builder) self.assertEqual(len(result), 0) def test_invalid_missing_req_type(self): builder = GroupBuilder('my_bar') err_msg = r"builder must have data type defined with attribute '[A-Za-z_]+'" with self.assertRaisesRegex(ValueError, err_msg): self.vmap.validate(builder) class TestBasicSpec(ValidatorTestBase): def getSpecs(self): ret = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', 'int', name='data', attributes=[AttributeSpec( 'attr2', 'an example integer attribute', 'int')])], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) return (ret,) def test_invalid_missing(self): builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar'}) validator = self.vmap.get_validator('Bar') result = validator.validate(builder) self.assertEqual(len(result), 2) self.assertValidationError(result[0], MissingError, name='Bar/attr1') self.assertValidationError(result[1], MissingError, name='Bar/data') def test_invalid_incorrect_type_get_validator(self): builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 10}) validator = self.vmap.get_validator('Bar') result = validator.validate(builder) self.assertEqual(len(result), 2) self.assertValidationError(result[0], DtypeError, name='Bar/attr1') self.assertValidationError(result[1], MissingError, name='Bar/data') def test_invalid_incorrect_type_validate(self): builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 10}) result = self.vmap.validate(builder) self.assertEqual(len(result), 2) self.assertValidationError(result[0], DtypeError, name='Bar/attr1') self.assertValidationError(result[1], MissingError, name='Bar/data') def test_valid(self): builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10})]) validator = self.vmap.get_validator('Bar') result = validator.validate(builder) self.assertEqual(len(result), 0) class TestDateTimeInSpec(ValidatorTestBase): def getSpecs(self): ret = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', 'int', name='data', attributes=[AttributeSpec( 'attr2', 'an example integer attribute', 'int')]), DatasetSpec('an example time dataset', 'isodatetime', name='time'), DatasetSpec('an array of times', 'isodatetime', name='time_array', dims=('num_times',), shape=(None,))], 
attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) return (ret,) def test_valid_isodatetime(self): builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10}), DatasetBuilder('time', datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())), DatasetBuilder('time_array', [datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())])]) validator = self.vmap.get_validator('Bar') result = validator.validate(builder) self.assertEqual(len(result), 0) def test_invalid_isodatetime(self): builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10}), DatasetBuilder('time', 100), DatasetBuilder('time_array', [datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())])]) validator = self.vmap.get_validator('Bar') result = validator.validate(builder) self.assertEqual(len(result), 1) self.assertValidationError(result[0], DtypeError, name='Bar/time') def test_invalid_isodatetime_array(self): builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10}), DatasetBuilder('time', datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal())), DatasetBuilder('time_array', datetime(2017, 5, 1, 12, 0, 0, tzinfo=tzlocal()))]) validator = self.vmap.get_validator('Bar') result = validator.validate(builder) self.assertEqual(len(result), 1) self.assertValidationError(result[0], ExpectedArrayError, name='Bar/time_array') class TestNestedTypes(ValidatorTestBase): def getSpecs(self): baz = DatasetSpec('A dataset with a data type', 'int', data_type_def='Baz', attributes=[AttributeSpec('attr2', 'an example integer attribute', 'int')]) bar = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', data_type_inc='Baz')], attributes=[AttributeSpec('attr1', 'an example string attribute', 'text')]) foo = GroupSpec('A test group that contains a data type', data_type_def='Foo', groups=[GroupSpec('A Bar group for Foos', name='my_bar', data_type_inc='Bar')], attributes=[AttributeSpec('foo_attr', 'a string attribute specified as text', 'text', required=False)]) return (bar, foo, baz) def test_invalid_missing_named_req_group(self): """Test that a MissingDataType is returned when a required named nested data type is missing.""" foo_builder = GroupBuilder('my_foo', attributes={'data_type': 'Foo', 'foo_attr': 'example Foo object'}) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 1) self.assertValidationError(results[0], MissingDataType, name='Foo', reason='missing data type Bar (my_bar)') def test_invalid_wrong_name_req_type(self): """Test that a MissingDataType is returned when a required nested data type is given the wrong name.""" bar_builder = GroupBuilder('bad_bar_name', attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, datasets=[DatasetBuilder('data', 100, attributes={'attr2': 10})]) foo_builder = GroupBuilder('my_foo', attributes={'data_type': 'Foo', 'foo_attr': 'example Foo object'}, groups=[bar_builder]) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 1) self.assertValidationError(results[0], MissingDataType, name='Foo') self.assertEqual(results[0].data_type, 'Bar') def test_invalid_missing_unnamed_req_group(self): """Test that a MissingDataType is returned when a required unnamed nested data type is missing.""" 
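# --- Editor's illustrative sketch (not part of the original test suite). The spec/builder/
# validator plumbing repeated in every setUp of this file, condensed into one place; the
# names 'Demo', 'demo_ns', and 'demo.yaml' are hypothetical.
from hdmf.build import GroupBuilder
from hdmf.spec import AttributeSpec, GroupSpec, SpecCatalog, SpecNamespace
from hdmf.validate import ValidatorMap

catalog = SpecCatalog()
demo_spec = GroupSpec('A demo group', data_type_def='Demo',
                      attributes=[AttributeSpec('attr1', 'a required text attribute', 'text')])
catalog.register_spec(demo_spec, 'demo.yaml')
demo_ns = SpecNamespace('a demo namespace', 'demo_ns', [{'source': 'demo.yaml'}],
                        version='0.1.0', catalog=catalog)
vmap = ValidatorMap(demo_ns)

builder = GroupBuilder('my_demo', attributes={'data_type': 'Demo'})  # 'attr1' is omitted
errors = vmap.validate(builder)  # expect a single MissingError for 'Demo/attr1'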
bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}) foo_builder = GroupBuilder('my_foo', attributes={'data_type': 'Foo', 'foo_attr': 'example Foo object'}, groups=[bar_builder]) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 1) self.assertValidationError(results[0], MissingDataType, name='Bar', reason='missing data type Baz') def test_valid(self): """Test that no errors are returned when nested data types are correctly built.""" bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, datasets=[DatasetBuilder('data', 100, attributes={'data_type': 'Baz', 'attr2': 10})]) foo_builder = GroupBuilder('my_foo', attributes={'data_type': 'Foo', 'foo_attr': 'example Foo object'}, groups=[bar_builder]) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 0) def test_valid_wo_opt_attr(self): """"Test that no errors are returned when an optional attribute is omitted from a group.""" bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': 'a string attribute'}, datasets=[DatasetBuilder('data', 100, attributes={'data_type': 'Baz', 'attr2': 10})]) foo_builder = GroupBuilder('my_foo', attributes={'data_type': 'Foo'}, groups=[bar_builder]) results = self.vmap.validate(foo_builder) self.assertEqual(len(results), 0) class TestQuantityValidation(TestCase): def create_test_specs(self, q_groups, q_datasets, q_links): bar = GroupSpec('A test group', data_type_def='Bar') baz = DatasetSpec('A test dataset', 'int', data_type_def='Baz') qux = GroupSpec('A group to link', data_type_def='Qux') foo = GroupSpec('A group containing a quantity of tests and datasets', data_type_def='Foo', groups=[GroupSpec('A bar', data_type_inc='Bar', quantity=q_groups)], datasets=[DatasetSpec('A baz', data_type_inc='Baz', quantity=q_datasets)], links=[LinkSpec('A qux', target_type='Qux', quantity=q_links)],) return (bar, foo, baz, qux) def configure_specs(self, specs): spec_catalog = SpecCatalog() for spec in specs: spec_catalog.register_spec(spec, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) self.vmap = ValidatorMap(self.namespace) def get_test_builder(self, n_groups, n_datasets, n_links): child_groups = [GroupBuilder(f'bar_{n}', attributes={'data_type': 'Bar'}) for n in range(n_groups)] child_datasets = [DatasetBuilder(f'baz_{n}', n, attributes={'data_type': 'Baz'}) for n in range(n_datasets)] child_links = [LinkBuilder(GroupBuilder(f'qux_{n}', attributes={'data_type': 'Qux'}), f'qux_{n}_link') for n in range(n_links)] return GroupBuilder('my_foo', attributes={'data_type': 'Foo'}, groups=child_groups, datasets=child_datasets, links=child_links) def test_valid_zero_or_many(self): """"Verify that groups/datasets/links with ZERO_OR_MANY and a valid quantity correctly pass validation""" specs = self.create_test_specs(q_groups=ZERO_OR_MANY, q_datasets=ZERO_OR_MANY, q_links=ZERO_OR_MANY) self.configure_specs(specs) for n in [0, 1, 2, 5]: with self.subTest(quantity=n): builder = self.get_test_builder(n_groups=n, n_datasets=n, n_links=n) results = self.vmap.validate(builder) self.assertEqual(len(results), 0) def test_valid_one_or_many(self): """"Verify that groups/datasets/links with ONE_OR_MANY and a valid quantity correctly pass validation""" specs = self.create_test_specs(q_groups=ONE_OR_MANY, q_datasets=ONE_OR_MANY, q_links=ONE_OR_MANY) self.configure_specs(specs) for n in [1, 2, 
5]: with self.subTest(quantity=n): builder = self.get_test_builder(n_groups=n, n_datasets=n, n_links=n) results = self.vmap.validate(builder) self.assertEqual(len(results), 0) def test_valid_zero_or_one(self): """"Verify that groups/datasets/links with ZERO_OR_ONE and a valid quantity correctly pass validation""" specs = self.create_test_specs(q_groups=ZERO_OR_ONE, q_datasets=ZERO_OR_ONE, q_links=ZERO_OR_ONE) self.configure_specs(specs) for n in [0, 1]: with self.subTest(quantity=n): builder = self.get_test_builder(n_groups=n, n_datasets=n, n_links=n) results = self.vmap.validate(builder) self.assertEqual(len(results), 0) def test_valid_fixed_quantity(self): """"Verify that groups/datasets/links with a correct fixed quantity correctly pass validation""" self.configure_specs(self.create_test_specs(q_groups=2, q_datasets=3, q_links=5)) builder = self.get_test_builder(n_groups=2, n_datasets=3, n_links=5) results = self.vmap.validate(builder) self.assertEqual(len(results), 0) def test_missing_one_or_many_should_not_return_incorrect_quantity_error(self): """Verify that missing ONE_OR_MANY groups/datasets/links should not return an IncorrectQuantityError NOTE: a MissingDataType error should be returned instead """ specs = self.create_test_specs(q_groups=ONE_OR_MANY, q_datasets=ONE_OR_MANY, q_links=ONE_OR_MANY) self.configure_specs(specs) builder = self.get_test_builder(n_groups=0, n_datasets=0, n_links=0) results = self.vmap.validate(builder) self.assertFalse(any(isinstance(e, IncorrectQuantityError) for e in results)) def test_missing_fixed_quantity_should_not_return_incorrect_quantity_error(self): """Verify that missing groups/datasets/links should not return an IncorrectQuantityError""" self.configure_specs(self.create_test_specs(q_groups=5, q_datasets=3, q_links=2)) builder = self.get_test_builder(0, 0, 0) results = self.vmap.validate(builder) self.assertFalse(any(isinstance(e, IncorrectQuantityError) for e in results)) def test_incorrect_fixed_quantity_should_return_incorrect_quantity_error(self): """Verify that an incorrect quantity of groups/datasets/links should return an IncorrectQuantityError""" self.configure_specs(self.create_test_specs(q_groups=5, q_datasets=5, q_links=5)) for n in [1, 2, 10]: with self.subTest(quantity=n): builder = self.get_test_builder(n_groups=n, n_datasets=n, n_links=n) results = self.vmap.validate(builder) self.assertEqual(len(results), 3) self.assertTrue(all(isinstance(e, IncorrectQuantityError) for e in results)) def test_incorrect_zero_or_one_quantity_should_return_incorrect_quantity_error(self): """Verify that an incorrect ZERO_OR_ONE quantity of groups/datasets/links should return an IncorrectQuantityError """ specs = self.create_test_specs(q_groups=ZERO_OR_ONE, q_datasets=ZERO_OR_ONE, q_links=ZERO_OR_ONE) self.configure_specs(specs) builder = self.get_test_builder(n_groups=2, n_datasets=2, n_links=2) results = self.vmap.validate(builder) self.assertEqual(len(results), 3) self.assertTrue(all(isinstance(e, IncorrectQuantityError) for e in results)) def test_incorrect_quantity_error_message(self): """Verify that an IncorrectQuantityError includes the expected information in the message""" specs = self.create_test_specs(q_groups=2, q_datasets=ZERO_OR_MANY, q_links=ZERO_OR_MANY) self.configure_specs(specs) builder = self.get_test_builder(n_groups=7, n_datasets=0, n_links=0) results = self.vmap.validate(builder) self.assertEqual(len(results), 1) self.assertIsInstance(results[0], IncorrectQuantityError) message = str(results[0]) self.assertTrue('expected a 
quantity of 2' in message) self.assertTrue('received 7' in message) class TestDtypeValidation(TestCase): def set_up_spec(self, dtype): spec_catalog = SpecCatalog() spec = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', dtype, name='data')], attributes=[AttributeSpec('attr1', 'an example attribute', dtype)]) spec_catalog.register_spec(spec, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) self.vmap = ValidatorMap(self.namespace) def test_ascii_for_utf8(self): """Test that validator allows ASCII data where UTF8 is specified.""" self.set_up_spec('text') value = b'an ascii string' bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) def test_utf8_for_ascii(self): """Test that validator does not allow UTF8 where ASCII is specified.""" self.set_up_spec('bytes') value = 'a utf8 string' bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) result_strings = set([str(s) for s in results]) expected_errors = {"Bar/attr1 (my_bar.attr1): incorrect type - expected 'bytes', got 'utf'", "Bar/data (my_bar/data): incorrect type - expected 'bytes', got 'utf'"} self.assertEqual(result_strings, expected_errors) def test_int64_for_int8(self): """Test that validator allows int64 data where int8 is specified.""" self.set_up_spec('int8') value = np.int64(1) bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) def test_int8_for_int64(self): """Test that validator does not allow int8 data where int64 is specified.""" self.set_up_spec('int64') value = np.int8(1) bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) result_strings = set([str(s) for s in results]) expected_errors = {"Bar/attr1 (my_bar.attr1): incorrect type - expected 'int64', got 'int8'", "Bar/data (my_bar/data): incorrect type - expected 'int64', got 'int8'"} self.assertEqual(result_strings, expected_errors) def test_int64_for_numeric(self): """Test that validator allows int64 data where numeric is specified.""" self.set_up_spec('numeric') value = np.int64(1) bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) def test_bool_for_numeric(self): """Test that validator does not allow bool data where numeric is specified.""" self.set_up_spec('numeric') value = True bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) result_strings = set([str(s) for s in results]) expected_errors = {"Bar/attr1 (my_bar.attr1): incorrect type - expected 'numeric', got 'bool'", "Bar/data (my_bar/data): incorrect type - expected 'numeric', got 'bool'"} self.assertEqual(result_strings, expected_errors) def test_np_bool_for_bool(self): """Test that validator allows np.bool_ data where bool is specified.""" 
self.set_up_spec('bool') value = np.bool_(True) bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) class Test1DArrayValidation(TestCase): def set_up_spec(self, dtype): spec_catalog = SpecCatalog() spec = GroupSpec('A test group specification with a data type', data_type_def='Bar', datasets=[DatasetSpec('an example dataset', dtype, name='data', shape=(None, ))], attributes=[AttributeSpec('attr1', 'an example attribute', dtype, shape=(None, ))]) spec_catalog.register_spec(spec, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) self.vmap = ValidatorMap(self.namespace) def test_scalar(self): """Test that validator does not allow a scalar where an array is specified.""" self.set_up_spec('text') value = 'a string' bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) result_strings = set([str(s) for s in results]) expected_errors = {("Bar/attr1 (my_bar.attr1): incorrect shape - expected an array of shape '(None,)', " "got non-array data 'a string'"), ("Bar/data (my_bar/data): incorrect shape - expected an array of shape '(None,)', " "got non-array data 'a string'")} self.assertEqual(result_strings, expected_errors) def test_empty_list(self): """Test that validator allows an empty list where an array is specified.""" self.set_up_spec('text') value = [] bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) def test_empty_nparray(self): """Test that validator allows an empty numpy array where an array is specified.""" self.set_up_spec('text') value = np.array([]) # note: dtype is float64 bar_builder = GroupBuilder('my_bar', attributes={'data_type': 'Bar', 'attr1': value}, datasets=[DatasetBuilder('data', value)]) results = self.vmap.validate(bar_builder) self.assertEqual(len(results), 0) # TODO test shape validation more completely class TestLinkable(TestCase): def set_up_spec(self): spec_catalog = SpecCatalog() typed_dataset_spec = DatasetSpec('A typed dataset', data_type_def='Foo') typed_group_spec = GroupSpec('A typed group', data_type_def='Bar') spec = GroupSpec('A test group specification with a data type', data_type_def='Baz', datasets=[ DatasetSpec('A linkable child dataset', name='untyped_linkable_ds', linkable=True, quantity=ZERO_OR_ONE), DatasetSpec('A non-linkable child dataset', name='untyped_nonlinkable_ds', linkable=False, quantity=ZERO_OR_ONE), DatasetSpec('A linkable child dataset', data_type_inc='Foo', name='typed_linkable_ds', linkable=True, quantity=ZERO_OR_ONE), DatasetSpec('A non-linkable child dataset', data_type_inc='Foo', name='typed_nonlinkable_ds', linkable=False, quantity=ZERO_OR_ONE), ], groups=[ GroupSpec('A linkable child group', name='untyped_linkable_group', linkable=True, quantity=ZERO_OR_ONE), GroupSpec('A non-linkable child group', name='untyped_nonlinkable_group', linkable=False, quantity=ZERO_OR_ONE), GroupSpec('A linkable child group', data_type_inc='Bar', name='typed_linkable_group', linkable=True, quantity=ZERO_OR_ONE), GroupSpec('A non-linkable child group', data_type_inc='Bar', name='typed_nonlinkable_group', linkable=False, quantity=ZERO_OR_ONE), ]) 
spec_catalog.register_spec(spec, 'test.yaml') spec_catalog.register_spec(typed_dataset_spec, 'test.yaml') spec_catalog.register_spec(typed_group_spec, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) self.vmap = ValidatorMap(self.namespace) def validate_linkability(self, link, expect_error): """Execute a linkability test and assert whether or not an IllegalLinkError is returned""" self.set_up_spec() builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, links=[link]) result = self.vmap.validate(builder) if expect_error: self.assertEqual(len(result), 1) self.assertIsInstance(result[0], IllegalLinkError) else: self.assertEqual(len(result), 0) def test_untyped_linkable_dataset_accepts_link(self): """Test that the validator accepts a link when the spec has an untyped linkable dataset""" link = LinkBuilder(name='untyped_linkable_ds', builder=DatasetBuilder('foo')) self.validate_linkability(link, expect_error=False) def test_untyped_nonlinkable_dataset_does_not_accept_link(self): """Test that the validator returns an IllegalLinkError when the spec has an untyped non-linkable dataset""" link = LinkBuilder(name='untyped_nonlinkable_ds', builder=DatasetBuilder('foo')) self.validate_linkability(link, expect_error=True) def test_typed_linkable_dataset_accepts_link(self): """Test that the validator accepts a link when the spec has a typed linkable dataset""" link = LinkBuilder(name='typed_linkable_ds', builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})) self.validate_linkability(link, expect_error=False) def test_typed_nonlinkable_dataset_does_not_accept_link(self): """Test that the validator returns an IllegalLinkError when the spec has a typed non-linkable dataset""" link = LinkBuilder(name='typed_nonlinkable_ds', builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})) self.validate_linkability(link, expect_error=True) def test_untyped_linkable_group_accepts_link(self): """Test that the validator accepts a link when the spec has an untyped linkable group""" link = LinkBuilder(name='untyped_linkable_group', builder=GroupBuilder('foo')) self.validate_linkability(link, expect_error=False) def test_untyped_nonlinkable_group_does_not_accept_link(self): """Test that the validator returns an IllegalLinkError when the spec has an untyped non-linkable group""" link = LinkBuilder(name='untyped_nonlinkable_group', builder=GroupBuilder('foo')) self.validate_linkability(link, expect_error=True) def test_typed_linkable_group_accepts_link(self): """Test that the validator accepts a link when the spec has a typed linkable group""" link = LinkBuilder(name='typed_linkable_group', builder=GroupBuilder('foo', attributes={'data_type': 'Bar'})) self.validate_linkability(link, expect_error=False) def test_typed_nonlinkable_group_does_not_accept_link(self): """Test that the validator returns an IllegalLinkError when the spec has a typed non-linkable group""" link = LinkBuilder(name='typed_nonlinkable_group', builder=GroupBuilder('foo', attributes={'data_type': 'Bar'})) self.validate_linkability(link, expect_error=True) @mock.patch("hdmf.validate.validator.DatasetValidator.validate") def test_should_not_validate_illegally_linked_objects(self, mock_validator): """Test that an illegally linked child dataset is not validated Note: this behavior is expected to change in the future: https://github.com/hdmf-dev/hdmf/issues/516 """ self.set_up_spec() typed_link = LinkBuilder(name='typed_nonlinkable_ds', 
builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})) untyped_link = LinkBuilder(name='untyped_nonlinkable_ds', builder=DatasetBuilder('foo')) builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, links=[typed_link, untyped_link]) _ = self.vmap.validate(builder) assert not mock_validator.called class TestMultipleNamedChildrenOfSameType(TestCase): """When a group has multiple named children of the same type (such as X, Y, and Z VectorData), they all need to be validated. """ def set_up_spec(self): spec_catalog = SpecCatalog() dataset_spec = DatasetSpec('A dataset', data_type_def='Foo') group_spec = GroupSpec('A group', data_type_def='Bar') spec = GroupSpec('A test group specification with a data type', data_type_def='Baz', datasets=[ DatasetSpec('Child Dataset A', name='a', data_type_inc='Foo'), DatasetSpec('Child Dataset B', name='b', data_type_inc='Foo'), ], groups=[ GroupSpec('Child Group X', name='x', data_type_inc='Bar'), GroupSpec('Child Group Y', name='y', data_type_inc='Bar'), ]) spec_catalog.register_spec(spec, 'test.yaml') spec_catalog.register_spec(dataset_spec, 'test.yaml') spec_catalog.register_spec(group_spec, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) self.vmap = ValidatorMap(self.namespace) def validate_multiple_children(self, dataset_names, group_names): """Utility function to validate a builder with the specified named dataset and group children""" self.set_up_spec() datasets = [DatasetBuilder(ds, attributes={'data_type': 'Foo'}) for ds in dataset_names] groups = [GroupBuilder(gr, attributes={'data_type': 'Bar'}) for gr in group_names] builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets, groups=groups) return self.vmap.validate(builder) def test_missing_first_dataset_should_return_error(self): """Test that the validator returns a MissingDataType error if the first dataset is missing""" result = self.validate_multiple_children(['b'], ['x', 'y']) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_missing_last_dataset_should_return_error(self): """Test that the validator returns a MissingDataType error if the last dataset is missing""" result = self.validate_multiple_children(['a'], ['x', 'y']) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_missing_first_group_should_return_error(self): """Test that the validator returns a MissingDataType error if the first group is missing""" result = self.validate_multiple_children(['a', 'b'], ['y']) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_missing_last_group_should_return_error(self): """Test that the validator returns a MissingDataType error if the last group is missing""" result = self.validate_multiple_children(['a', 'b'], ['x']) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_no_errors_when_all_children_satisfied(self): """Test that the validator does not return an error if all child specs are satisfied""" result = self.validate_multiple_children(['a', 'b'], ['x', 'y']) self.assertEqual(len(result), 0) class TestLinkAndChildMatchingDataType(TestCase): """If a link and a child dataset/group have the same specified data type, both the link and the child need to be validated """ def set_up_spec(self): spec_catalog = SpecCatalog() dataset_spec = DatasetSpec('A dataset', data_type_def='Foo') group_spec = 
GroupSpec('A group', data_type_def='Bar') spec = GroupSpec('A test group specification with a data type', data_type_def='Baz', datasets=[ DatasetSpec('Child Dataset', name='dataset', data_type_inc='Foo'), ], groups=[ GroupSpec('Child Group', name='group', data_type_inc='Bar'), ], links=[ LinkSpec('Linked Dataset', name='dataset_link', target_type='Foo'), LinkSpec('Linked Dataset', name='group_link', target_type='Bar') ]) spec_catalog.register_spec(spec, 'test.yaml') spec_catalog.register_spec(dataset_spec, 'test.yaml') spec_catalog.register_spec(group_spec, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) self.vmap = ValidatorMap(self.namespace) def validate_matching_link_data_type_case(self, datasets, groups, links): """Execute validation against a group builder using the provided group children and verify that a MissingDataType error is returned """ self.set_up_spec() builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets, groups=groups, links=links) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_error_on_missing_child_dataset(self): """Test that a MissingDataType is returned when the child dataset is missing""" datasets = [] groups = [GroupBuilder('group', attributes={'data_type': 'Bar'})] links = [ LinkBuilder(name='dataset_link', builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})), LinkBuilder(name='group_link', builder=GroupBuilder('bar', attributes={'data_type': 'Bar'})) ] self.validate_matching_link_data_type_case(datasets, groups, links) def test_error_on_missing_linked_dataset(self): """Test that a MissingDataType is returned when the linked dataset is missing""" datasets = [DatasetBuilder('dataset', attributes={'data_type': 'Foo'})] groups = [GroupBuilder('group', attributes={'data_type': 'Bar'})] links = [ LinkBuilder(name='group_link', builder=GroupBuilder('bar', attributes={'data_type': 'Bar'})) ] self.validate_matching_link_data_type_case(datasets, groups, links) def test_error_on_missing_group(self): """Test that a MissingDataType is returned when the child group is missing""" self.set_up_spec() datasets = [DatasetBuilder('dataset', attributes={'data_type': 'Foo'})] groups = [] links = [ LinkBuilder(name='dataset_link', builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})), LinkBuilder(name='group_link', builder=GroupBuilder('bar', attributes={'data_type': 'Bar'})) ] self.validate_matching_link_data_type_case(datasets, groups, links) def test_error_on_missing_linked_group(self): """Test that a MissingDataType is returned when the linked group is missing""" self.set_up_spec() datasets = [DatasetBuilder('dataset', attributes={'data_type': 'Foo'})] groups = [GroupBuilder('group', attributes={'data_type': 'Bar'})] links = [ LinkBuilder(name='dataset_link', builder=DatasetBuilder('foo', attributes={'data_type': 'Foo'})) ] self.validate_matching_link_data_type_case(datasets, groups, links) class TestMultipleChildrenAtDifferentLevelsOfInheritance(TestCase): """When multiple children can satisfy multiple specs due to data_type inheritance, the validation needs to carefully match builders against specs """ def set_up_spec(self): spec_catalog = SpecCatalog() dataset_spec = DatasetSpec('A dataset', data_type_def='Foo') sub_dataset_spec = DatasetSpec('An Inheriting Dataset', data_type_def='Bar', data_type_inc='Foo') spec = GroupSpec('A test group 
specification with a data type', data_type_def='Baz', datasets=[ DatasetSpec('Child Dataset', data_type_inc='Foo'), DatasetSpec('Child Dataset', data_type_inc='Bar'), ]) spec_catalog.register_spec(spec, 'test.yaml') spec_catalog.register_spec(dataset_spec, 'test.yaml') spec_catalog.register_spec(sub_dataset_spec, 'test.yaml') self.namespace = SpecNamespace( 'a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) self.vmap = ValidatorMap(self.namespace) def test_error_returned_when_child_at_highest_level_missing(self): """Test that a MissingDataType error is returned when the dataset at the highest level of the inheritance hierarchy is missing """ self.set_up_spec() datasets = [ DatasetBuilder('bar', attributes={'data_type': 'Bar'}) ] builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_error_returned_when_child_at_lowest_level_missing(self): """Test that a MissingDataType error is returned when the dataset at the lowest level of the inheritance hierarchy is missing """ self.set_up_spec() datasets = [ DatasetBuilder('foo', attributes={'data_type': 'Foo'}) ] builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) self.assertIsInstance(result[0], MissingDataType) def test_both_levels_of_hierarchy_validated(self): """Test that when both required children at separate levels of inheritance hierarchy are present, both child specs are satisfied """ self.set_up_spec() datasets = [ DatasetBuilder('foo', attributes={'data_type': 'Foo'}), DatasetBuilder('bar', attributes={'data_type': 'Bar'}) ] builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets) result = self.vmap.validate(builder) self.assertEqual(len(result), 0) @skip("Functionality not yet supported") def test_both_levels_of_hierarchy_validated_inverted_order(self): """Test that when both required children at separate levels of inheritance hierarchy are present, both child specs are satisfied. This should work no matter what the order of the builders. """ self.set_up_spec() datasets = [ DatasetBuilder('bar', attributes={'data_type': 'Bar'}), DatasetBuilder('foo', attributes={'data_type': 'Foo'}) ] builder = GroupBuilder('my_baz', attributes={'data_type': 'Baz'}, datasets=datasets) result = self.vmap.validate(builder) self.assertEqual(len(result), 0) class TestExtendedIncDataTypes(TestCase): """Test validation against specs where a data type is included via data_type_inc and modified by adding new fields or constraining existing fields but is not defined as a new type via data_type_inc. For the purpose of this test class: we are calling a data type which is nested inside a group an "inner" data type. When an inner data type inherits from a data type via data_type_inc and has fields that are either added or modified from the base data type, we are labeling that data type as an "extension". When the inner data type extension does not define a new data type via data_type_def we say that it is an "anonymous extension". Anonymous data type extensions should be avoided in for new specs, but it does occur in existing nwb specs, so we need to allow and validate against it. 
One example is the `Units.spike_times` dataset attached to Units in the `core` nwb namespace, which extends `VectorData` via neurodata_type_inc but adds a new attribute named `resolution` without defining a new data type via neurodata_type_def. """ def setup_spec(self): """Prepare a set of specs for tests which includes an anonymous data type extension""" spec_catalog = SpecCatalog() attr_foo = AttributeSpec(name='foo', doc='an attribute', dtype='text') attr_bar = AttributeSpec(name='bar', doc='an attribute', dtype='numeric') d1_spec = DatasetSpec(doc='type D1', data_type_def='D1', dtype='numeric', attributes=[attr_foo]) d2_spec = DatasetSpec(doc='type D2', data_type_def='D2', data_type_inc=d1_spec) g1_spec = GroupSpec(doc='type G1', data_type_def='G1', datasets=[DatasetSpec(doc='D1 extension', data_type_inc=d1_spec, attributes=[attr_foo, attr_bar])]) for spec in [d1_spec, d2_spec, g1_spec]: spec_catalog.register_spec(spec, 'test.yaml') self.namespace = SpecNamespace('a test namespace', CORE_NAMESPACE, [{'source': 'test.yaml'}], version='0.1.0', catalog=spec_catalog) self.vmap = ValidatorMap(self.namespace) def test_missing_additional_attribute_on_anonymous_data_type_extension(self): """Verify that a MissingError is returned when a required attribute from an anonymous extension is not present """ self.setup_spec() dataset = DatasetBuilder('test_d1', 42.0, attributes={'data_type': 'D1', 'foo': 'xyz'}) builder = GroupBuilder('test_g1', attributes={'data_type': 'G1'}, datasets=[dataset]) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) error = result[0] self.assertIsInstance(error, MissingError) self.assertTrue('G1/D1/bar' in str(error)) def test_validate_child_type_against_anonymous_data_type_extension(self): """Verify that a MissingError is returned when a required attribute from an anonymous extension is not present on a data type which inherits from the data type included in the anonymous extension. """ self.setup_spec() dataset = DatasetBuilder('test_d2', 42.0, attributes={'data_type': 'D2', 'foo': 'xyz'}) builder = GroupBuilder('test_g1', attributes={'data_type': 'G1'}, datasets=[dataset]) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) error = result[0] self.assertIsInstance(error, MissingError) self.assertTrue('G1/D1/bar' in str(error)) def test_redundant_attribute_in_spec(self): """Test that only one MissingError is returned when an attribute is missing which is redundantly defined in both a base data type and an inner data type """ self.setup_spec() dataset = DatasetBuilder('test_d2', 42.0, attributes={'data_type': 'D2', 'bar': 5}) builder = GroupBuilder('test_g1', attributes={'data_type': 'G1'}, datasets=[dataset]) result = self.vmap.validate(builder) self.assertEqual(len(result), 1) class TestReferenceDatasetsRoundTrip(ValidatorTestBase): """Test that no errors occur when when datasets containing references either in an array or as part of a compound type are written out to file, read back in, and then validated. In order to support lazy reading on loading, datasets containing references are wrapped in lazy-loading ReferenceResolver objects. These tests verify that the validator can work with these ReferenceResolver objects. 
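In outline, the round trip exercised here (a condensed sketch of the steps implemented in `runBuilderRoundTrip` below, assuming a `namespace`, a `vmap` built from it, and a group `builder` like the ones these tests construct) is:

    ns_catalog = NamespaceCatalog()
    ns_catalog.add_namespace(namespace.name, namespace)
    manager = BuildManager(TypeMap(ns_catalog))
    with HDF5IO('test_ref_dataset.h5', manager=manager, mode='w') as write_io:
        write_io.write_builder(builder)
    with HDF5IO('test_ref_dataset.h5', manager=manager, mode='r') as read_io:
        errors = vmap.validate(read_io.read_builder())  # expect no errors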
""" def setUp(self): self.filename = 'test_ref_dataset.h5' super().setUp() def tearDown(self): remove_test_file(self.filename) super().tearDown() def getSpecs(self): qux_spec = DatasetSpec( doc='a simple scalar dataset', data_type_def='Qux', dtype='int', shape=None ) baz_spec = DatasetSpec( doc='a dataset with a compound datatype that includes a reference', data_type_def='Baz', dtype=[ DtypeSpec('x', doc='x-value', dtype='int'), DtypeSpec('y', doc='y-ref', dtype=RefSpec('Qux', reftype='object')) ], shape=None ) bar_spec = DatasetSpec( doc='a dataset of an array of references', dtype=RefSpec('Qux', reftype='object'), data_type_def='Bar', shape=(None,) ) foo_spec = GroupSpec( doc='a base group for containing test datasets', data_type_def='Foo', datasets=[ DatasetSpec(doc='optional Bar', data_type_inc=bar_spec, quantity=ZERO_OR_ONE), DatasetSpec(doc='optional Baz', data_type_inc=baz_spec, quantity=ZERO_OR_ONE), DatasetSpec(doc='multiple qux', data_type_inc=qux_spec, quantity=ONE_OR_MANY) ] ) return (foo_spec, bar_spec, baz_spec, qux_spec) def runBuilderRoundTrip(self, builder): """Executes a round-trip test for a builder 1. First writes the builder to file, 2. next reads a new builder from disk 3. and finally runs the builder through the validator. The test is successful if there are no validation errors.""" ns_catalog = NamespaceCatalog() ns_catalog.add_namespace(self.namespace.name, self.namespace) typemap = TypeMap(ns_catalog) self.manager = BuildManager(typemap) with HDF5IO(self.filename, manager=self.manager, mode='w') as write_io: write_io.write_builder(builder) with HDF5IO(self.filename, manager=self.manager, mode='r') as read_io: read_builder = read_io.read_builder() errors = self.vmap.validate(read_builder) self.assertEqual(len(errors), 0, errors) def test_round_trip_validation_of_reference_dataset_array(self): """Verify that a dataset builder containing an array of references passes validation after a round trip""" qux1 = DatasetBuilder('q1', 5, attributes={'data_type': 'Qux'}) qux2 = DatasetBuilder('q2', 10, attributes={'data_type': 'Qux'}) bar = DatasetBuilder( name='bar', data=[ReferenceBuilder(qux1), ReferenceBuilder(qux2)], attributes={'data_type': 'Bar'}, dtype='object' ) foo = GroupBuilder( name='foo', datasets=[bar, qux1, qux2], attributes={'data_type': 'Foo'} ) self.runBuilderRoundTrip(foo) def test_round_trip_validation_of_compound_dtype_with_reference(self): """Verify that a dataset builder containing data with a compound dtype containing a reference passes validation after a round trip""" qux1 = DatasetBuilder('q1', 5, attributes={'data_type': 'Qux'}) qux2 = DatasetBuilder('q2', 10, attributes={'data_type': 'Qux'}) baz = DatasetBuilder( name='baz', data=[(10, ReferenceBuilder(qux1))], dtype=[ DtypeSpec('x', doc='x-value', dtype='int'), DtypeSpec('y', doc='y-ref', dtype=RefSpec('Qux', reftype='object')) ], attributes={'data_type': 'Baz'} ) foo = GroupBuilder( name='foo', datasets=[baz, qux1, qux2], attributes={'data_type': 'Foo'} ) self.runBuilderRoundTrip(foo) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/tox.ini0000644000655200065520000000752700000000000014107 0ustar00circlecicircleci# Tox (https://tox.readthedocs.io/) is a tool for running tests # in multiple virtualenvs. This configuration file will run the # test suite on all supported python versions. To use it, "pip install tox" # and then run "tox" from this directory. 
[tox] envlist = py37, py38, py39 [testenv] usedevelop = True setenv = PYTHONDONTWRITEBYTECODE = 1 install_command = pip install -U {opts} {packages} deps = -rrequirements-dev.txt -rrequirements.txt commands = pip check # Check for conflicting packages python test.py -v # Env to create coverage report locally [testenv:localcoverage] basepython = python3.9 commands = python -m coverage run test.py -u coverage html -d tests/coverage/htmlcov # Test with python 3.9, pinned dev reqs, and upgraded run requirements [testenv:py39-upgrade-dev] basepython = python3.9 install_command = pip install -U -e . {opts} {packages} deps = -rrequirements-dev.txt commands = {[testenv]commands} # Test with python 3.9, pinned dev reqs, and pre-release run requirements [testenv:py39-upgrade-dev-pre] basepython = python3.9 install_command = pip install -U --pre -e . {opts} {packages} deps = -rrequirements-dev.txt commands = {[testenv]commands} # Test with python 3.7, pinned dev reqs, and minimum run requirements [testenv:py37-min-req] basepython = python3.7 deps = -rrequirements-dev.txt -rrequirements-min.txt commands = {[testenv]commands} # Envs that builds wheels and source distribution [testenv:build] commands = python setup.py sdist python setup.py bdist_wheel [testenv:build-py37] basepython = python3.7 commands = {[testenv:build]commands} [testenv:build-py38] basepython = python3.8 commands = {[testenv:build]commands} [testenv:build-py39] basepython = python3.9 commands = {[testenv:build]commands} [testenv:build-py39-upgrade-dev] basepython = python3.9 install_command = pip install -U -e . {opts} {packages} deps = -rrequirements-dev.txt commands = {[testenv:build]commands} [testenv:build-py39-upgrade-dev-pre] basepython = python3.9 install_command = pip install -U --pre -e . {opts} {packages} deps = -rrequirements-dev.txt commands = {[testenv:build]commands} [testenv:build-py37-min-req] basepython = python3.7 deps = -rrequirements-dev.txt -rrequirements-min.txt commands = {[testenv:build]commands} # Envs that will test installation from a wheel [testenv:wheelinstall] deps = null commands = python -c "import hdmf" # Envs that will execute gallery tests [testenv:gallery] install_command = pip install -U {opts} {packages} deps = -rrequirements-dev.txt -rrequirements.txt -rrequirements-doc.txt commands = python test.py --example [testenv:gallery-py37] basepython = python3.7 deps = {[testenv:gallery]deps} commands = {[testenv:gallery]commands} [testenv:gallery-py38] basepython = python3.8 deps = {[testenv:gallery]deps} commands = {[testenv:gallery]commands} [testenv:gallery-py39] basepython = python3.9 deps = {[testenv:gallery]deps} commands = {[testenv:gallery]commands} # Test with python 3.9, pinned dev and doc reqs, and upgraded run requirements [testenv:gallery-py39-upgrade-dev] basepython = python3.9 install_command = pip install -U -e . {opts} {packages} deps = -rrequirements-dev.txt -rrequirements-doc.txt commands = {[testenv:gallery]commands} # Test with python 3.9, pinned dev and doc reqs, and pre-release run requirements [testenv:gallery-py39-upgrade-dev-pre] basepython = python3.9 install_command = pip install -U --pre -e . 
{opts} {packages} deps = -rrequirements-dev.txt -rrequirements-doc.txt commands = {[testenv:gallery]commands} # Test with python 3.7, pinned dev reqs, and minimum run requirements [testenv:gallery-py37-min-req] basepython = python3.7 deps = -rrequirements-dev.txt -rrequirements-min.txt -rrequirements-doc.txt commands = {[testenv:gallery]commands} ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1627603598.0 hdmf-3.1.1/versioneer.py0000644000655200065520000021321100000000000015314 0ustar00circlecicircleci# flake8: noqa: C901 # Version: 0.18 """The Versioneer - like a rocketeer, but for versions. The Versioneer ============== * like a rocketeer, but for versions! * https://github.com/warner/python-versioneer * Brian Warner * License: Public Domain * Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy * [![Latest Version] (https://pypip.in/version/versioneer/badge.svg?style=flat) ](https://pypi.python.org/pypi/versioneer/) * [![Build Status] (https://travis-ci.org/warner/python-versioneer.png?branch=master) ](https://travis-ci.org/warner/python-versioneer) This is a tool for managing a recorded version number in distutils-based python projects. The goal is to remove the tedious and error-prone "update the embedded version string" step from your release process. Making a new release should be as easy as recording a new tag in your version-control system, and maybe making new tarballs. ## Quick Install * `pip install versioneer` to somewhere to your $PATH * add a `[versioneer]` section to your setup.cfg (see below) * run `versioneer install` in your source tree, commit the results ## Version Identifiers Source trees come from a variety of places: * a version-control system checkout (mostly used by developers) * a nightly tarball, produced by build automation * a snapshot tarball, produced by a web-based VCS browser, like github's "tarball from tag" feature * a release tarball, produced by "setup.py sdist", distributed through PyPI Within each source tree, the version identifier (either a string or a number, this tool is format-agnostic) can come from a variety of places: * ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows about recent "tags" and an absolute revision-id * the name of the directory into which the tarball was unpacked * an expanded VCS keyword ($Id$, etc) * a `_version.py` created by some earlier build step For released software, the version identifier is closely related to a VCS tag. Some projects use tag names that include more than just the version string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool needs to strip the tag prefix to extract the version identifier. For unreleased software (between tags), the version identifier should provide enough information to help developers recreate the same tree, while also giving them an idea of roughly how old the tree is (after version 1.2, before version 1.3). Many VCS systems can report a description that captures this, for example `git describe --tags --dirty --always` reports things like "0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the 0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has uncommitted changes. 
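For illustration, a describe string of that form can be split into its tag, distance, short hash, and dirty flag with the same kind of regular expression used later in this file by `git_pieces_from_vcs()` (a minimal sketch, not part of Versioneer's public API):

    import re

    describe = "0.7-1-g574ab98-dirty"
    dirty = describe.endswith("-dirty")
    if dirty:
        describe = describe[:describe.rindex("-dirty")]
    mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', describe)
    tag, distance, short = mo.group(1), int(mo.group(2)), mo.group(3)
    # tag == "0.7", distance == 1, short == "574ab98", dirty is True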
The version identifier is used for multiple purposes: * to allow the module to self-identify its version: `myproject.__version__` * to choose a name and prefix for a 'setup.py sdist' tarball ## Theory of Operation Versioneer works by adding a special `_version.py` file into your source tree, where your `__init__.py` can import it. This `_version.py` knows how to dynamically ask the VCS tool for version information at import time. `_version.py` also contains `$Revision$` markers, and the installation process marks `_version.py` to have this marker rewritten with a tag name during the `git archive` command. As a result, generated tarballs will contain enough information to get the proper version. To allow `setup.py` to compute a version too, a `versioneer.py` is added to the top level of your source tree, next to `setup.py` and the `setup.cfg` that configures it. This overrides several distutils/setuptools commands to compute the version when invoked, and changes `setup.py build` and `setup.py sdist` to replace `_version.py` with a small static file that contains just the generated version data. ## Installation See [INSTALL.md](./INSTALL.md) for detailed installation instructions. ## Version-String Flavors Code which uses Versioneer can learn about its version string at runtime by importing `_version` from your main `__init__.py` file and running the `get_versions()` function. From the "outside" (e.g. in `setup.py`), you can import the top-level `versioneer.py` and run `get_versions()`. Both functions return a dictionary with different flavors of version information: * `['version']`: A condensed version string, rendered using the selected style. This is the most commonly used value for the project's version string. The default "pep440" style yields strings like `0.11`, `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section below for alternative styles. * `['full-revisionid']`: detailed revision identifier. For Git, this is the full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". * `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the commit date in ISO 8601 format. This will be None if the date is not available. * `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that this is only accurate if run in a VCS checkout, otherwise it is likely to be False or None * `['error']`: if the version string could not be computed, this will be set to a string describing the problem, otherwise it will be None. It may be useful to throw an exception in setup.py if this is set, to avoid e.g. creating tarballs with a version string of "unknown". Some variants are more useful than others. Including `full-revisionid` in a bug report should allow developers to reconstruct the exact code being tested (or indicate the presence of local changes that should be shared with the developers). `version` is suitable for display in an "about" box or a CLI `--version` output: it can be easily compared against release notes and lists of bugs fixed in various releases. The installer adds the following text to your `__init__.py` to place a basic version in `YOURPROJECT.__version__`: from ._version import get_versions __version__ = get_versions()['version'] del get_versions ## Styles The setup.cfg `style=` configuration controls how the VCS information is rendered into a version string. 
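For example, given "pieces" describing a checkout that is two commits past the "0.11" tag and has uncommitted changes (a hypothetical illustration; the rendering helpers appear further down in this file):

    pieces = {"closest-tag": "0.11", "distance": 2, "short": "1076c97",
              "long": "1076c978a8d3cfc70f408fe5974aa6c092c949ac",
              "dirty": True, "error": None, "date": None}
    # style = "pep440"       -> "0.11+2.g1076c97.dirty"
    # style = "git-describe" -> "0.11-2-g1076c97-dirty"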
The default style, "pep440", produces a PEP440-compliant string, equal to the un-prefixed tag name for actual releases, and containing an additional "local version" section with more detail for in-between builds. For Git, this is TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags --dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and that this commit is two revisions ("+2") beyond the "0.11" tag. For released software (exactly equal to a known tag), the identifier will only contain the stripped tag, e.g. "0.11". Other styles are available. See [details.md](details.md) in the Versioneer source tree for descriptions. ## Debugging Versioneer tries to avoid fatal errors: if something goes wrong, it will tend to return a version of "0+unknown". To investigate the problem, run `setup.py version`, which will run the version-lookup code in a verbose mode, and will display the full contents of `get_versions()` (including the `error` string, which may help identify what went wrong). ## Known Limitations Some situations are known to cause problems for Versioneer. This details the most significant ones. More can be found on Github [issues page](https://github.com/warner/python-versioneer/issues). ### Subprojects Versioneer has limited support for source trees in which `setup.py` is not in the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are two common reasons why `setup.py` might not be in the root: * Source trees which contain multiple subprojects, such as [Buildbot](https://github.com/buildbot/buildbot), which contains both "master" and "slave" subprojects, each with their own `setup.py`, `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI distributions (and upload multiple independently-installable tarballs). * Source trees whose main purpose is to contain a C library, but which also provide bindings to Python (and perhaps other languages) in subdirectories. Versioneer will look for `.git` in parent directories, and most operations should get the right version string. However `pip` and `setuptools` have bugs and implementation details which frequently cause `pip install .` from a subproject directory to fail to find a correct version string (so it usually defaults to `0+unknown`). `pip install --editable .` should work correctly. `setup.py install` might work too. Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in some later version. [Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking this issue. The discussion in [PR #61](https://github.com/warner/python-versioneer/pull/61) describes the issue from the Versioneer side in more detail. [pip PR#3176](https://github.com/pypa/pip/pull/3176) and [pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve pip to let Versioneer work correctly. Versioneer-0.16 and earlier only looked for a `.git` directory next to the `setup.cfg`, so subprojects were completely unsupported with those releases. ### Editable installs with setuptools <= 18.5 `setup.py develop` and `pip install --editable .` allow you to install a project into a virtualenv once, then continue editing the source code (and test) without re-installing after every change. "Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a convenient way to specify executable scripts that should be installed along with the python package. 
These both work as expected when using modern setuptools. When using setuptools-18.5 or earlier, however, certain operations will cause `pkg_resources.DistributionNotFound` errors when running the entrypoint script, which must be resolved by re-installing the package. This happens when the install happens with one version, then the egg_info data is regenerated while a different version is checked out. Many setup.py commands cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into a different virtualenv), so this can be surprising. [Bug #83](https://github.com/warner/python-versioneer/issues/83) describes this one, but upgrading to a newer version of setuptools should probably resolve it. ### Unicode version strings While Versioneer works (and is continually tested) with both Python 2 and Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. Newer releases probably generate unicode version strings on py2. It's not clear that this is wrong, but it may be surprising for applications when then write these strings to a network connection or include them in bytes-oriented APIs like cryptographic checksums. [Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates this question. ## Updating Versioneer To upgrade your project to a new release of Versioneer, do the following: * install the new Versioneer (`pip install -U versioneer` or equivalent) * edit `setup.cfg`, if necessary, to include any new configuration settings indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. * re-run `versioneer install` in your source tree, to replace `SRC/_version.py` * commit any changed files ## Future Directions This tool is designed to make it easily extended to other version-control systems: all VCS-specific components are in separate directories like src/git/ . The top-level `versioneer.py` script is assembled from these components by running make-versioneer.py . In the future, make-versioneer.py will take a VCS name as an argument, and will construct a version of `versioneer.py` that is specific to the given VCS. It might also take the configuration arguments that are currently provided manually during installation by editing setup.py . Alternatively, it might go the other direction and include code from all supported VCS systems, reducing the number of intermediate scripts. ## License To make Versioneer easier to embed, all its code is dedicated to the public domain. The `_version.py` that it creates is also in the public domain. Specifically, both are released under the Creative Commons "Public Domain Dedication" license (CC0-1.0), as described in https://creativecommons.org/publicdomain/zero/1.0/ . """ from __future__ import print_function try: import configparser except ImportError: import ConfigParser as configparser import errno import fnmatch # HDMF import json import os import re import subprocess import sys class VersioneerConfig: """Container for Versioneer configuration parameters.""" def get_root(): """Get the project root directory. We require that all commands are run from the project root, i.e. the directory that contains setup.py, setup.cfg, and versioneer.py . 
""" root = os.path.realpath(os.path.abspath(os.getcwd())) setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): # allow 'python path/to/setup.py COMMAND' root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): err = ("Versioneer was unable to run the project root directory. " "Versioneer requires setup.py to be executed from " "its immediate directory (like 'python setup.py COMMAND'), " "or in a way that lets it use sys.argv[0] to find the root " "(like 'python path/to/setup.py COMMAND').") raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools # tree) execute all dependencies in a single python process, so # "versioneer" may be imported multiple times, and python's shared # module-import table will cache the first one. So we can't use # os.path.dirname(__file__), as that will find whichever # versioneer.py was first imported, even in later projects. me = os.path.realpath(os.path.abspath(__file__)) me_dir = os.path.normcase(os.path.splitext(me)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir: print("Warning: build in %s is using versioneer.py from %s" % (os.path.dirname(me), versioneer_py)) except NameError: pass return root def get_config_from_root(root): """Read the project setup.cfg file to determine Versioneer config.""" # This might raise EnvironmentError (if setup.cfg is missing), or # configparser.NoSectionError (if it lacks a [versioneer] section), or # configparser.NoOptionError (if it lacks "VCS="). See the docstring at # the top of versioneer.py for instructions on writing your setup.cfg . 
setup_cfg = os.path.join(root, "setup.cfg") parser = configparser.SafeConfigParser() with open(setup_cfg, "r") as f: parser.readfp(f) VCS = parser.get("versioneer", "VCS") # mandatory def get(parser, name): if parser.has_option("versioneer", name): return parser.get("versioneer", name) return None cfg = VersioneerConfig() cfg.VCS = VCS cfg.style = get(parser, "style") or "" cfg.versionfile_source = get(parser, "versionfile_source") cfg.versionfile_build = get(parser, "versionfile_build") cfg.tag_prefix = get(parser, "tag_prefix") if cfg.tag_prefix in ("''", '""'): cfg.tag_prefix = "" cfg.parentdir_prefix = get(parser, "parentdir_prefix") cfg.verbose = get(parser, "verbose") return cfg class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" # these dictionaries contain VCS-specific tools LONG_VERSION_PY = {} HANDLERS = {} def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f return decorate def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None for c in commands: try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git p = subprocess.Popen([c] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None)) break except EnvironmentError: e = sys.exc_info()[1] if e.errno == errno.ENOENT: continue if verbose: print("unable to run %s" % dispcmd) print(e) return None, None else: if verbose: print("unable to find command, tried %s" % (commands,)) return None, None stdout = p.communicate()[0].strip() if sys.version_info[0] >= 3: stdout = stdout.decode() if p.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) return None, p.returncode return stdout, p.returncode LONG_VERSION_PY['git'] = ''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. # This file is released into the public domain. Generated by # versioneer-0.18 (https://github.com/warner/python-versioneer) """Git implementation of _version.py.""" import errno import fnmatch # HDMF import os import re import subprocess import sys def get_keywords(): """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. # setup.py/versioneer.py will grep for the variable names, so they must # each be defined on a line of their own. _version.py will just call # get_keywords(). 
git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} return keywords class VersioneerConfig: """Container for Versioneer configuration parameters.""" def get_config(): """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py cfg = VersioneerConfig() cfg.VCS = "git" cfg.style = "%(STYLE)s" cfg.tag_prefix = "%(TAG_PREFIX)s" cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" cfg.verbose = False return cfg class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" LONG_VERSION_PY = {} HANDLERS = {} def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f return decorate def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None for c in commands: try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git p = subprocess.Popen([c] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None)) break except EnvironmentError: e = sys.exc_info()[1] if e.errno == errno.ENOENT: continue if verbose: print("unable to run %%s" %% dispcmd) print(e) return None, None else: if verbose: print("unable to find command, tried %%s" %% (commands,)) return None, None stdout = p.communicate()[0].strip() if sys.version_info[0] >= 3: stdout = stdout.decode() if p.returncode != 0: if verbose: print("unable to run %%s (error)" %% dispcmd) print("stdout was %%s" %% stdout) return None, p.returncode return stdout, p.returncode def versions_from_parentdir(parentdir_prefix, root, verbose): """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both the project name and a version string. We will also support searching up two directory levels for an appropriately named parent directory """ rootdirs = [] for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: print("Tried directories %%s but none started with prefix %%s" %% (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @register_vcs_handler("git", "get_keywords") def git_get_keywords(versionfile_abs): """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
keywords = {} try: f = open(versionfile_abs, "r") for line in f.readlines(): if line.strip().startswith("git_refnames ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["refnames"] = mo.group(1) if line.strip().startswith("git_full ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["full"] = mo.group(1) if line.strip().startswith("git_date ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["date"] = mo.group(1) f.close() except EnvironmentError: pass return keywords @register_vcs_handler("git", "keywords") def git_versions_from_keywords(keywords, tag_prefix, verbose): """Get version information from git keywords.""" if not keywords: raise NotThisMethod("no keywords at all, weird") date = keywords.get("date") if date is not None: # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because # it's been around since git-1.5.3, and it's too difficult to # discover which version we're using, or to work around using an # older one. date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) refnames = keywords["refnames"].strip() if refnames.startswith("$Format"): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") refs = set([r.strip() for r in refnames.strip("()").split(",")]) # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %%d # expansion behaves like git log --decorate=short and strips out the # refs/heads/ and refs/tags/ prefixes that would let us distinguish # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". tags = set([r for r in refs if re.search(r'\d', r)]) if verbose: print("discarding '%%s', no digits" %% ",".join(refs - tags)) if verbose: print("likely tags: %%s" %% ",".join(sorted(tags))) for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" # HDMF: Support tag_prefix specified as a glob pattern tag_is_glob_pattern = "*" in tag_prefix if tag_is_glob_pattern: if fnmatch.fnmatch(ref, tag_prefix): r = ref if verbose: print("picking %s" % r) return {"version": r, "full-revisionid": keywords["full"].strip(), "dirty": False, "error": None, "date": date} else: if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] if verbose: print("picking %s" % r) return {"version": r, "full-revisionid": keywords["full"].strip(), "dirty": False, "error": None, "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") return {"version": "0+unknown", "full-revisionid": keywords["full"].strip(), "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): """Get version from 'git describe' in the root of the source tree. 
This only gets called if the git-archive 'subst' keywords were *not* expanded, and _version.py hasn't already been rewritten with a short version string, meaning we're inside a checked out source tree. """ GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %%s not under git control" %% root) raise NotThisMethod("'git rev-parse --git-dir' returned error") # HDMF: Support tag_prefix specified as a glob pattern tag_is_glob_pattern = "*" in tag_prefix match_argument = tag_prefix if not tag_is_glob_pattern: match_argument = tag_prefix + "*" # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", "--always", "--long", "--match", "%s" % match_argument], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() pieces = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out # look for -dirty suffix dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%%s'" %% describe_out) return pieces # tag full_tag = mo.group(1) # HDMF: Support tag_prefix specified as a glob pattern if tag_is_glob_pattern: if not fnmatch.fnmatch(full_tag, tag_prefix): if verbose: fmt = "tag '%%s' doesn't match glob pattern '%%s'" print(fmt %% (full_tag, tag_prefix)) pieces["error"] = ("tag '%%s' doesn't match glob pattern '%%s'" %% (full_tag, tag_prefix)) return pieces pieces["closest-tag"] = full_tag else: if not full_tag.startswith(tag_prefix): if verbose: fmt = "tag '%%s' doesn't start with prefix '%%s'" print(fmt %% (full_tag, tag_prefix)) pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" %% (full_tag, tag_prefix)) return pieces pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) # commit: short hex revision ID pieces["short"] = mo.group(3) else: # HEX: no tags pieces["closest-tag"] = None count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces def plus_or_dot(pieces): """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" def render_pep440(pieces): """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . 
Note that if you get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty Exceptions: 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += plus_or_dot(pieces) rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" else: # exception #1 rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered def render_pep440_pre(pieces): """TAG[.post.devDISTANCE] -- No -dirty. Exceptions: 1: no tags. 0.post.devDISTANCE """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"]: rendered += ".post.dev%%d" %% pieces["distance"] else: # exception #1 rendered = "0.post.dev%%d" %% pieces["distance"] return rendered def render_pep440_post(pieces): """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards (a dirty tree will appear "older" than the corresponding clean one), but you shouldn't be releasing software with -dirty anyways. Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%%d" %% pieces["distance"] if pieces["dirty"]: rendered += ".dev0" rendered += plus_or_dot(pieces) rendered += "g%%s" %% pieces["short"] else: # exception #1 rendered = "0.post%%d" %% pieces["distance"] if pieces["dirty"]: rendered += ".dev0" rendered += "+g%%s" %% pieces["short"] return rendered def render_pep440_old(pieces): """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. Eexceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%%d" %% pieces["distance"] if pieces["dirty"]: rendered += ".dev0" else: # exception #1 rendered = "0.post%%d" %% pieces["distance"] if pieces["dirty"]: rendered += ".dev0" return rendered def render_git_describe(pieces): """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. Exceptions: 1: no tags. HEX[-dirty] (note: no 'g' prefix) """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"]: rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) else: # exception #1 rendered = pieces["short"] if pieces["dirty"]: rendered += "-dirty" return rendered def render_git_describe_long(pieces): """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. The distance/hash is unconditional. Exceptions: 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) else: # exception #1 rendered = pieces["short"] if pieces["dirty"]: rendered += "-dirty" return rendered def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", "full-revisionid": pieces.get("long"), "dirty": None, "error": pieces["error"], "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": rendered = render_git_describe(pieces) elif style == "git-describe-long": rendered = render_git_describe_long(pieces) else: raise ValueError("unknown style '%%s'" %% style) return {"version": rendered, "full-revisionid": pieces["long"], "dirty": pieces["dirty"], "error": None, "date": pieces.get("date")} def get_versions(): """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which # case we can only use expanded keywords. cfg = get_config() verbose = cfg.verbose try: return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass try: root = os.path.realpath(__file__) # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. for i in cfg.versionfile_source.split('/'): root = os.path.dirname(root) except NameError: return {"version": "0+unknown", "full-revisionid": None, "dirty": None, "error": "unable to find root of source tree", "date": None} try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) return render(pieces, cfg.style) except NotThisMethod: pass try: if cfg.parentdir_prefix: return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) except NotThisMethod: pass return {"version": "0+unknown", "full-revisionid": None, "dirty": None, "error": "unable to compute version", "date": None} ''' @register_vcs_handler("git", "get_keywords") def git_get_keywords(versionfile_abs): """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. 
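    # Illustrative sketch of the expanded keyword lines this regexp scan is
    # looking for in _version.py (values invented, not from a real build):
    #
    #     git_refnames = " (HEAD -> dev, tag: 3.1.1)"
    #     git_full = "0123456789abcdef0123456789abcdef01234567"
    #     git_date = "2021-07-29 16:26:38 -0700"
    #
    # In a plain checkout these are still unexpanded "$Format...$"
    # placeholders, and git_versions_from_keywords() below rejects them.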
keywords = {} try: f = open(versionfile_abs, "r") for line in f.readlines(): if line.strip().startswith("git_refnames ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["refnames"] = mo.group(1) if line.strip().startswith("git_full ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["full"] = mo.group(1) if line.strip().startswith("git_date ="): mo = re.search(r'=\s*"(.*)"', line) if mo: keywords["date"] = mo.group(1) f.close() except EnvironmentError: pass return keywords @register_vcs_handler("git", "keywords") def git_versions_from_keywords(keywords, tag_prefix, verbose): """Get version information from git keywords.""" if not keywords: raise NotThisMethod("no keywords at all, weird") date = keywords.get("date") if date is not None: # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because # it's been around since git-1.5.3, and it's too difficult to # discover which version we're using, or to work around using an # older one. date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) refnames = keywords["refnames"].strip() if refnames.startswith("$Format"): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") refs = set([r.strip() for r in refnames.strip("()").split(",")]) # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d # expansion behaves like git log --decorate=short and strips out the # refs/heads/ and refs/tags/ prefixes that would let us distinguish # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". tags = set([r for r in refs if re.search(r'\d', r)]) if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: print("likely tags: %s" % ",".join(sorted(tags))) for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" # HDMF: Support tag_prefix specified as a glob pattern tag_is_glob_pattern = "*" in tag_prefix if tag_is_glob_pattern: if fnmatch.fnmatch(ref, tag_prefix): r = ref if verbose: print("picking %s" % r) return {"version": r, "full-revisionid": keywords["full"].strip(), "dirty": False, "error": None, "date": date} else: if ref.startswith(tag_prefix): r = ref[len(tag_prefix):] if verbose: print("picking %s" % r) return {"version": r, "full-revisionid": keywords["full"].strip(), "dirty": False, "error": None, "date": date} # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") return {"version": "0+unknown", "full-revisionid": keywords["full"].strip(), "dirty": False, "error": "no suitable tags", "date": None} @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): """Get version from 'git describe' in the root of the source tree. 
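    As a rough illustration (tag and hash invented, not from this
    repository), the describe output parsed below looks like
    "3.1.1-2-g1234abc" for a clean tree two commits past a 3.1.1 tag,
    "3.1.1-2-g1234abc-dirty" when there are local modifications, and a bare
    "1234abc" when no matching tag exists.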
This only gets called if the git-archive 'subst' keywords were *not* expanded, and _version.py hasn't already been rewritten with a short version string, meaning we're inside a checked out source tree. """ GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) raise NotThisMethod("'git rev-parse --git-dir' returned error") # HDMF: Support tag_prefix specified as a glob pattern tag_is_glob_pattern = "*" in tag_prefix match_argument = tag_prefix if not tag_is_glob_pattern: match_argument = tag_prefix + "*" # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", "--always", "--long", "--match", "%s" % match_argument], cwd=root) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() pieces = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out # look for -dirty suffix dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: git_describe = git_describe[:git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? pieces["error"] = ("unable to parse git-describe output: '%s'" % describe_out) return pieces # tag full_tag = mo.group(1) # HDMF: Support tag_prefix specified as a glob pattern if tag_is_glob_pattern: if not fnmatch.fnmatch(full_tag, tag_prefix): if verbose: fmt = "tag '%s' doesn't match glob pattern '%s'" print(fmt % (full_tag, tag_prefix)) pieces["error"] = ("tag '%s' doesn't match glob pattern '%s'" % (full_tag, tag_prefix)) return pieces pieces["closest-tag"] = full_tag else: if not full_tag.startswith(tag_prefix): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" % (full_tag, tag_prefix)) return pieces pieces["closest-tag"] = full_tag[len(tag_prefix):] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) # commit: short hex revision ID pieces["short"] = mo.group(3) else: # HEX: no tags pieces["closest-tag"] = None count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces def do_vcs_install(manifest_in, versionfile_source, ipy): """Git-specific installation logic for Versioneer. For Git, this means creating/changing .gitattributes to mark _version.py for export-subst keyword substitution. 
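    As a sketch, assuming versionfile_source is set to src/hdmf/_version.py
    for this package, the entry appended to .gitattributes would read:

        src/hdmf/_version.py export-subst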
""" GITS = ["git"] if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] files = [manifest_in, versionfile_source] if ipy: files.append(ipy) try: me = __file__ if me.endswith(".pyc") or me.endswith(".pyo"): me = os.path.splitext(me)[0] + ".py" versioneer_file = os.path.relpath(me) except NameError: versioneer_file = "versioneer.py" files.append(versioneer_file) present = False try: f = open(".gitattributes", "r") for line in f.readlines(): if line.strip().startswith(versionfile_source): if "export-subst" in line.strip().split()[1:]: present = True f.close() except EnvironmentError: pass if not present: f = open(".gitattributes", "a+") f.write("%s export-subst\n" % versionfile_source) f.close() files.append(".gitattributes") run_command(GITS, ["add", "--"] + files) def versions_from_parentdir(parentdir_prefix, root, verbose): """Try to determine the version from the parent directory name. Source tarballs conventionally unpack into a directory that includes both the project name and a version string. We will also support searching up two directory levels for an appropriately named parent directory """ rootdirs = [] for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return {"version": dirname[len(parentdir_prefix):], "full-revisionid": None, "dirty": False, "error": None, "date": None} else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: print("Tried directories %s but none started with prefix %s" % (str(rootdirs), parentdir_prefix)) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") SHORT_VERSION_PY = """ # This file was generated by 'versioneer.py' (0.18) from # revision-control system data, or from the parent directory name of an # unpacked source archive. Distribution tarballs contain a pre-generated copy # of this file. import json version_json = ''' %s ''' # END VERSION_JSON def get_versions(): return json.loads(version_json) """ def versions_from_file(filename): """Try to determine the version from _version.py if present.""" try: with open(filename) as f: contents = f.read() except EnvironmentError: raise NotThisMethod("unable to read _version.py") mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) if not mo: mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) def write_to_version_file(filename, versions): """Write the given version number to the given _version.py file.""" os.unlink(filename) contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) print("set %s to '%s'" % (filename, versions["version"])) def plus_or_dot(pieces): """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" def render_pep440(pieces): """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty Exceptions: 1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += plus_or_dot(pieces) rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" else: # exception #1 rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered def render_pep440_pre(pieces): """TAG[.post.devDISTANCE] -- No -dirty. Exceptions: 1: no tags. 0.post.devDISTANCE """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"]: rendered += ".post.dev%d" % pieces["distance"] else: # exception #1 rendered = "0.post.dev%d" % pieces["distance"] return rendered def render_pep440_post(pieces): """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards (a dirty tree will appear "older" than the corresponding clean one), but you shouldn't be releasing software with -dirty anyways. Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%d" % pieces["distance"] if pieces["dirty"]: rendered += ".dev0" rendered += plus_or_dot(pieces) rendered += "g%s" % pieces["short"] else: # exception #1 rendered = "0.post%d" % pieces["distance"] if pieces["dirty"]: rendered += ".dev0" rendered += "+g%s" % pieces["short"] return rendered def render_pep440_old(pieces): """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. Eexceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"] or pieces["dirty"]: rendered += ".post%d" % pieces["distance"] if pieces["dirty"]: rendered += ".dev0" else: # exception #1 rendered = "0.post%d" % pieces["distance"] if pieces["dirty"]: rendered += ".dev0" return rendered def render_git_describe(pieces): """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. Exceptions: 1: no tags. HEX[-dirty] (note: no 'g' prefix) """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] if pieces["distance"]: rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) else: # exception #1 rendered = pieces["short"] if pieces["dirty"]: rendered += "-dirty" return rendered def render_git_describe_long(pieces): """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. The distance/hash is unconditional. Exceptions: 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) else: # exception #1 rendered = pieces["short"] if pieces["dirty"]: rendered += "-dirty" return rendered def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: return {"version": "unknown", "full-revisionid": pieces.get("long"), "dirty": None, "error": pieces["error"], "date": None} if not style or style == "default": style = "pep440" # the default if style == "pep440": rendered = render_pep440(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": rendered = render_git_describe(pieces) elif style == "git-describe-long": rendered = render_git_describe_long(pieces) else: raise ValueError("unknown style '%s'" % style) return {"version": rendered, "full-revisionid": pieces["long"], "dirty": pieces["dirty"], "error": None, "date": pieces.get("date")} class VersioneerBadRootError(Exception): """The project root directory is unknown or missing key files.""" def get_versions(verbose=False): """Get the project version from whatever source is available. Returns dict with two keys: 'version' and 'full'. """ if "versioneer" in sys.modules: # see the discussion in cmdclass.py:get_cmdclass() del sys.modules["versioneer"] root = get_root() cfg = get_config_from_root(root) assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or cfg.verbose assert cfg.versionfile_source is not None, \ "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) # extract version from first of: _version.py, VCS command (e.g. 'git # describe'), parentdir. This is meant to work for developers using a # source checkout, for users of a tarball created by 'setup.py sdist', # and for users of a tarball/zipball created by 'git archive' or github's # download-from-tag feature or the equivalent in other VCSes. 
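    # The lookups below try each source in order -- expanded git-archive
    # keywords, an already-written _version.py, 'git describe' on a checkout,
    # then the parent directory name -- falling through on NotThisMethod.
    # A successful lookup returns a dict shaped roughly like this (values
    # illustrative, not from a real build):
    #
    #     {"version": "3.1.1", "full-revisionid": "<40-hex-char sha>",
    #      "dirty": False, "error": None, "date": "2021-07-29T16:26:38-0700"}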
get_keywords_f = handlers.get("get_keywords") from_keywords_f = handlers.get("keywords") if get_keywords_f and from_keywords_f: try: keywords = get_keywords_f(versionfile_abs) ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) if verbose: print("got version from expanded keyword %s" % ver) return ver except NotThisMethod: pass try: ver = versions_from_file(versionfile_abs) if verbose: print("got version from file %s %s" % (versionfile_abs, ver)) return ver except NotThisMethod: pass from_vcs_f = handlers.get("pieces_from_vcs") if from_vcs_f: try: pieces = from_vcs_f(cfg.tag_prefix, root, verbose) ver = render(pieces, cfg.style) if verbose: print("got version from VCS %s" % ver) return ver except NotThisMethod: pass try: if cfg.parentdir_prefix: ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) if verbose: print("got version from parentdir %s" % ver) return ver except NotThisMethod: pass if verbose: print("unable to compute version") return {"version": "0+unknown", "full-revisionid": None, "dirty": None, "error": "unable to compute version", "date": None} def get_version(): """Get the short version string for this project.""" return get_versions()["version"] def get_cmdclass(): """Get the custom setuptools/distutils subclasses used by Versioneer.""" if "versioneer" in sys.modules: del sys.modules["versioneer"] # this fixes the "python setup.py develop" case (also 'install' and # 'easy_install .'), in which subdependencies of the main project are # built (using setup.py bdist_egg) in the same python process. Assume # a main project A and a dependency B, which use different versions # of Versioneer. A's setup.py imports A's Versioneer, leaving it in # sys.modules by the time B's setup.py is executed, causing B to run # with the wrong versioneer. Setuptools wraps the sub-dep builds in a # sandbox that restores sys.modules to it's pre-build state, so the # parent is protected against the child's "import versioneer". By # removing ourselves from sys.modules here, before the child build # happens, we protect the child from the parent's versioneer too. # Also see https://github.com/warner/python-versioneer/issues/52 cmds = {} # we add "version" to both distutils and setuptools from distutils.core import Command class cmd_version(Command): description = "report generated version string" user_options = [] boolean_options = [] def initialize_options(self): pass def finalize_options(self): pass def run(self): vers = get_versions(verbose=True) print("Version: %s" % vers["version"]) print(" full-revisionid: %s" % vers.get("full-revisionid")) print(" dirty: %s" % vers.get("dirty")) print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) cmds["version"] = cmd_version # we override "build_py" in both distutils and setuptools # # most invocation pathways end up running build_py: # distutils/build -> build_py # distutils/install -> distutils/build ->.. # setuptools/bdist_wheel -> distutils/install ->.. # setuptools/bdist_egg -> distutils/install_lib -> build_py # setuptools/install -> bdist_egg ->.. # setuptools/develop -> ? # pip install: # copies source tree to a tempdir before running egg_info/etc # if .git isn't copied too, 'git describe' will fail # then does setup.py bdist_wheel, or sometimes setup.py install # setup.py egg_info -> ? 
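    # For reference, a project wires these commands into setup.py roughly
    # like this (a sketch; CONFIG_ERROR below carries the canonical text):
    #
    #     import versioneer
    #     setup(version=versioneer.get_version(),
    #           cmdclass=versioneer.get_cmdclass(),
    #           ...)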
# we override different "build_py" commands for both environments if "setuptools" in sys.modules: from setuptools.command.build_py import build_py as _build_py else: from distutils.command.build_py import build_py as _build_py class cmd_build_py(_build_py): def run(self): root = get_root() cfg = get_config_from_root(root) versions = get_versions() _build_py.run(self) # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) cmds["build_py"] = cmd_build_py if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION # "product_version": versioneer.get_version(), # ... class cmd_build_exe(_build_exe): def run(self): root = get_root() cfg = get_config_from_root(root) versions = get_versions() target_versionfile = cfg.versionfile_source print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) _build_exe.run(self) os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] f.write(LONG % {"DOLLAR": "$", "STYLE": cfg.style, "TAG_PREFIX": cfg.tag_prefix, "PARENTDIR_PREFIX": cfg.parentdir_prefix, "VERSIONFILE_SOURCE": cfg.versionfile_source, }) cmds["build_exe"] = cmd_build_exe del cmds["build_py"] if 'py2exe' in sys.modules: # py2exe enabled? try: from py2exe.distutils_buildexe import py2exe as _py2exe # py3 except ImportError: from py2exe.build_exe import py2exe as _py2exe # py2 class cmd_py2exe(_py2exe): def run(self): root = get_root() cfg = get_config_from_root(root) versions = get_versions() target_versionfile = cfg.versionfile_source print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) _py2exe.run(self) os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] f.write(LONG % {"DOLLAR": "$", "STYLE": cfg.style, "TAG_PREFIX": cfg.tag_prefix, "PARENTDIR_PREFIX": cfg.parentdir_prefix, "VERSIONFILE_SOURCE": cfg.versionfile_source, }) cmds["py2exe"] = cmd_py2exe # we override different "sdist" commands for both environments if "setuptools" in sys.modules: from setuptools.command.sdist import sdist as _sdist else: from distutils.command.sdist import sdist as _sdist class cmd_sdist(_sdist): def run(self): versions = get_versions() self._versioneer_generated_versions = versions # unless we update this, the command will keep using the old # version self.distribution.metadata.version = versions["version"] return _sdist.run(self) def make_release_tree(self, base_dir, files): root = get_root() cfg = get_config_from_root(root) _sdist.make_release_tree(self, base_dir, files) # now locate _version.py in the new base_dir directory # (remembering that it may be a hardlink) and replace it with an # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, self._versioneer_generated_versions) cmds["sdist"] = cmd_sdist return cmds CONFIG_ERROR = """ setup.cfg is missing the necessary Versioneer configuration. 
You need a section like: [versioneer] VCS = git style = pep440 versionfile_source = src/myproject/_version.py versionfile_build = myproject/_version.py tag_prefix = parentdir_prefix = myproject- You will also need to edit your setup.py to use the results: import versioneer setup(version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), ...) Please read the docstring in ./versioneer.py for configuration instructions, edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. """ SAMPLE_CONFIG = """ # See the docstring in versioneer.py for instructions. Note that you must # re-run 'versioneer.py setup' after changing this section, and commit the # resulting files. [versioneer] #VCS = git #style = pep440 #versionfile_source = #versionfile_build = #tag_prefix = #parentdir_prefix = """ INIT_PY_SNIPPET = """ from ._version import get_versions __version__ = get_versions()['version'] del get_versions """ def do_setup(): """Main VCS-independent setup function for installing Versioneer.""" root = get_root() try: cfg = get_config_from_root(root) except (EnvironmentError, configparser.NoSectionError, configparser.NoOptionError) as e: if isinstance(e, (EnvironmentError, configparser.NoSectionError)): print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) return 1 print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] f.write(LONG % {"DOLLAR": "$", "STYLE": cfg.style, "TAG_PREFIX": cfg.tag_prefix, "PARENTDIR_PREFIX": cfg.parentdir_prefix, "VERSIONFILE_SOURCE": cfg.versionfile_source, }) ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") if os.path.exists(ipy): try: with open(ipy, "r") as f: old = f.read() except EnvironmentError: old = "" if INIT_PY_SNIPPET not in old: print(" appending to %s" % ipy) with open(ipy, "a") as f: f.write(INIT_PY_SNIPPET) else: print(" %s unmodified" % ipy) else: print(" %s doesn't exist, ok" % ipy) ipy = None # Make sure both the top-level "versioneer.py" and versionfile_source # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so # they'll be copied into source distributions. Pip won't be able to # install the package without this. manifest_in = os.path.join(root, "MANIFEST.in") simple_includes = set() try: with open(manifest_in, "r") as f: for line in f: if line.startswith("include "): for include in line.split()[1:]: simple_includes.add(include) except EnvironmentError: pass # That doesn't cover everything MANIFEST.in can do # (http://docs.python.org/2/distutils/sourcedist.html#commands), so # it might give some false negatives. Appending redundant 'include' # lines is safe, though. if "versioneer.py" not in simple_includes: print(" appending 'versioneer.py' to MANIFEST.in") with open(manifest_in, "a") as f: f.write("include versioneer.py\n") else: print(" 'versioneer.py' already in MANIFEST.in") if cfg.versionfile_source not in simple_includes: print(" appending versionfile_source ('%s') to MANIFEST.in" % cfg.versionfile_source) with open(manifest_in, "a") as f: f.write("include %s\n" % cfg.versionfile_source) else: print(" versionfile_source already in MANIFEST.in") # Make VCS-specific changes. For git, this means creating/changing # .gitattributes to mark _version.py for export-subst keyword # substitution. 
    do_vcs_install(manifest_in, cfg.versionfile_source, ipy)
    return 0


def scan_setup_py():
    """Validate the contents of setup.py against Versioneer's expectations."""
    found = set()
    setters = False
    errors = 0
    with open("setup.py", "r") as f:
        for line in f.readlines():
            if "import versioneer" in line:
                found.add("import")
            if "versioneer.get_cmdclass()" in line:
                found.add("cmdclass")
            if "versioneer.get_version()" in line:
                found.add("get_version")
            if "versioneer.VCS" in line:
                setters = True
            if "versioneer.versionfile_source" in line:
                setters = True
    if len(found) != 3:
        print("")
        print("Your setup.py appears to be missing some important items")
        print("(but I might be wrong). Please make sure it has something")
        print("roughly like the following:")
        print("")
        print(" import versioneer")
        print(" setup( version=versioneer.get_version(),")
        print("        cmdclass=versioneer.get_cmdclass(),  ...)")
        print("")
        errors += 1
    if setters:
        print("You should remove lines like 'versioneer.VCS = ' and")
        print("'versioneer.versionfile_source = ' . This configuration")
        print("now lives in setup.cfg, and should be removed from setup.py")
        print("")
        errors += 1
    return errors


if __name__ == "__main__":
    cmd = sys.argv[1]
    if cmd == "setup":
        errors = do_setup()
        errors += scan_setup_py()
        if errors:
            sys.exit(1)
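# Typical command-line usage (a sketch; 'setup' is the only subcommand
# handled above):
#
#     python versioneer.py setup   # install/refresh the Versioneer support files
#     python setup.py version      # report the generated version string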