pax_global_header00006660000000000000000000000064142610405340014511gustar00rootroot0000000000000052 comment=e5fe944127cab62472a9657f0a5bb232e0303fc8 catalogue-2.1.0/000077500000000000000000000000001426104053400134555ustar00rootroot00000000000000catalogue-2.1.0/.gitignore000066400000000000000000000015061426104053400154470ustar00rootroot00000000000000tmp/ .pytest_cache .vscode .mypy_cache .prettierrc .python-version # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python .env/ env/ build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ *.egg-info/ .installed.cfg *.egg # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *,cover .hypothesis/ # Translations *.mo *.pot # Django stuff: *.log # Sphinx documentation docs/_build/ # PyBuilder target/ #Ipython Notebook .ipynb_checkpoints catalogue-2.1.0/LICENSE000066400000000000000000000020611426104053400144610ustar00rootroot00000000000000MIT License Copyright (c) 2019 ExplosionAI GmbH Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. catalogue-2.1.0/MANIFEST.in000066400000000000000000000000201426104053400152030ustar00rootroot00000000000000include LICENSE catalogue-2.1.0/README.md000066400000000000000000000717361426104053400147520ustar00rootroot00000000000000 # catalogue: Lightweight function registries and configurations for your library `catalogue` is a small library that - makes it easy to **add function (or object) registries** to your code - offers a **configuration system** letting you conveniently describe arbitrary trees of objects. _Function registries_ are helpful when you have objects that need to be both easily serializable and fully customizable. Instead of passing a function into your object, you pass in an identifier name, which the object can use to lookup the function from the registry. This makes the object easy to serialize, because the name is a simple string. If you instead saved the function, you'd have to use Pickle for serialization, which has many drawbacks. _Configuration_ is a huge challenge for machine-learning code because you may want to expose almost any detail of any function as a hyperparameter. The setting you want to expose might be arbitrarily far down in your call stack, so it might need to pass all the way through the CLI or REST API, through any number of intermediate functions, affecting the interface of everything along the way. And then once those settings are added, they become hard to remove later. Default values also become hard to change without breaking backwards compatibility. To solve this problem, `catalogue` offers a config system that lets you easily describe arbitrary trees of objects. The objects can be created via function calls you register using a simple decorator syntax. You can even version the functions you create, allowing you to make improvements without breaking backwards compatibility. The most similar config system we’re aware of is [Gin](https://github.com/google/gin-config), which uses a similar syntax, and also allows you to link the configuration system to functions in your code using a decorator. `catalogue`'s config system is simpler and emphasizes a different workflow via a subset of Gin’s functionality. [![Azure Pipelines](https://img.shields.io/azure-devops/build/explosion-ai/public/14/master.svg?logo=azure-pipelines&style=flat-square&label=build)](https://dev.azure.com/explosion-ai/public/_build?definitionId=14) [![Current Release Version](https://img.shields.io/github/v/release/explosion/catalogue.svg?style=flat-square&include_prereleases&logo=github)](https://github.com/explosion/catalogue/releases) [![pypi Version](https://img.shields.io/pypi/v/catalogue.svg?style=flat-square&logo=pypi&logoColor=white)](https://pypi.org/project/catalogue/) [![conda Version](https://img.shields.io/conda/vn/conda-forge/catalogue.svg?style=flat-square&logo=conda-forge&logoColor=white)](https://anaconda.org/conda-forge/catalogue) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square)](https://github.com/ambv/black) ## ⏳ Installation ```bash pip install catalogue ``` ```bash conda install -c conda-forge catalogue ``` > ⚠️ **Important note:** `catalogue` v2.0+ is only compatible with Python 3.6+. > For Python 2.7+ compatibility, use `catalogue` v1.x. ## 👩‍💻 Usage ### Function registry Let's imagine you're developing a Python package that needs to load data somewhere. You've already implemented some loader functions for the most common data types, but you want to allow the user to easily add their own. Using `catalogue.create` you can create a new registry under the namespace `your_package` → `loaders`. ```python # YOUR PACKAGE import catalogue loaders = catalogue.create("your_package", "loaders") ``` This gives you a `loaders.register` decorator that your users can import and decorate their custom loader functions with. ```python # USER CODE from your_package import loaders @loaders.register("custom_loader") def custom_loader(data): # Load something here... return data ``` The decorated function will be registered automatically and in your package, you'll be able to access all loaders by calling `loaders.get_all`. ```python # YOUR PACKAGE def load_data(data, loader_id): print("All loaders:", loaders.get_all()) # {"custom_loader": } loader = loaders.get(loader_id) return loader(data) ``` The user can now refer to their custom loader using only its string name (`"custom_loader"`) and your application will know what to do and will use their custom function. ```python # USER CODE from your_package import load_data load_data(data, loader_id="custom_loader") ``` ### Configurations The configuration system parses a `.cfg` file like ```ini [training] patience = 10 dropout = 0.2 use_vectors = false [training.logging] level = "INFO" [nlp] # This uses the value of training.use_vectors use_vectors = ${training.use_vectors} lang = "en" ``` and resolves it to a `Dict`: ```json { "training": { "patience": 10, "dropout": 0.2, "use_vectors": false, "logging": { "level": "INFO" } }, "nlp": { "use_vectors": false, "lang": "en" } } ``` The config is divided into sections, with the section name in square brackets – for example, `[training]`. Within the sections, config values can be assigned to keys using `=`. Values can also be referenced from other sections using the dot notation and placeholders indicated by the dollar sign and curly braces. For example, `${training.use_vectors}` will receive the value of use_vectors in the training block. This is useful for settings that are shared across components. The config format has three main differences from Python’s built-in configparser: 1. JSON-formatted values. `catalogue` passes all values through `json.loads` to interpret them. You can use atomic values like strings, floats, integers or booleans, or you can use complex objects such as lists or maps. 2. Structured sections. `catalogue` uses a dot notation to build nested sections. If you have a section named `[section.subsection]`, `catalogue` will parse that into a nested structure, placing subsection within section. 3. References to registry functions. If a key starts with `@`, `catalogue` will interpret its value as the name of a function registry, load the function registered for that name and pass in the rest of the block as arguments. If type hints are available on the function, the argument values (and return value of the function) will be validated against them. This lets you express complex configurations, like a training pipeline where `batch_size` is populated by a function that yields floats. There’s no pre-defined scheme you have to follow; how you set up the top-level sections is up to you. At the end of it, you’ll receive a dictionary with the values that you can use in your script – whether it’s complete initialized functions, or just basic settings. For instance, let’s say you want to define a new optimizer. You'd define its arguments in `config.cfg` like so: ```ini [optimizer] @optimizers = "my_cool_optimizer.v1" learn_rate = 0.001 gamma = 1e-8 ``` To load and parse this configuration: ```python import dataclasses from typing import Union, Iterable from catalogue import catalogue_registry, Config # Create a new registry. catalogue_registry.create("optimizers") # Define a dummy optimizer class. @dataclasses.dataclass class MyCoolOptimizer: learn_rate: float gamma: float @catalogue_registry.optimizers.register("my_cool_optimizer.v1") def make_my_optimizer(learn_rate: Union[float, Iterable[float]], gamma: float): return MyCoolOptimizer(learn_rate, gamma) # Load the config file from disk, resolve it and fetch the instantiated optimizer object. config = Config().from_disk("./config.cfg") resolved = catalogue_registry.resolve(config) optimizer = resolved["optimizer"] # MyCoolOptimizer(learn_rate=0.001, gamma=1e-08) ``` Under the hood, `catalogue` will look up the `"my_cool_optimizer.v1"` function in the "optimizers" registry and then call it with the arguments `learn_rate` and `gamma`. If the function has type annotations, it will also validate the input. For instance, if `learn_rate` is annotated as a float and the config defines a string, `catalogue` will raise an error. The Thinc documentation offers further information on the configuration system: - [recursive blocks](https://thinc.ai/docs/usage-config#registry-recursive) - [defining variable positional arguments](https://thinc.ai/docs/usage-config#registries-args) - [using interpolation](https://thinc.ai/docs/usage-config#config-interpolation) - [using custom registries](https://thinc.ai/docs/usage-config#registries-custom) - [advanced type annotations with Pydantic](https://thinc.ai/docs/usage-config#advanced-types) - [using base schemas](https://thinc.ai/docs/usage-config#advanced-types-base-schema) - [filling a configuration with defaults](https://thinc.ai/docs/usage-config#advanced-types-fill-defaults) ## ❓ FAQ #### But can't the user just pass in the `custom_loader` function directly? Sure, that's the more classic callback approach. Instead of a string ID, `load_data` could also take a function, in which case you wouldn't need a package like this. `catalogue` helps you when you need to produce a serializable record of which functions were passed in. For instance, you might want to write a log message, or save a config to load back your object later. With `catalogue`, your functions can be parameterized by strings, so logging and serialization remains easy – while still giving you full extensibility. #### How do I make sure all of the registration decorators have run? Decorators normally run when modules are imported. Relying on this side-effect can sometimes lead to confusion, especially if there's no other reason the module would be imported. One solution is to use [entry points](https://packaging.python.org/specifications/entry-points/). For instance, in [spaCy](https://spacy.io) we're starting to use function registries to make the pipeline components much more customizable. Let's say one user, Jo, develops a better tagging model using new machine learning research. End-users of Jo's package should be able to write `spacy.load("jo_tagging_model")`. They shouldn't need to remember to write `import jos_tagged_model` first, just to run the function registries as a side-effect. With entry points, the registration happens at install time – so you don't need to rely on the import side-effects. ## 🎛 API ### Registry #### function `catalogue.create` Create a new registry for a given namespace. Returns a setter function that can be used as a decorator or called with a name and `func` keyword argument. If `entry_points=True` is set, the registry will check for [Python entry points](https://packaging.python.org/tutorials/packaging-projects/#entry-points) advertised for the given namespace, e.g. the entry point group `spacy_architectures` for the namespace `"spacy", "architectures"`, in `Registry.get` and `Registry.get_all`. This allows other packages to auto-register functions. | Argument | Type | Description | | -------------- | ---------- | ---------------------------------------------------------------------------------------------- | | `*namespace` | `str` | The namespace, e.g. `"spacy"` or `"spacy", "architectures"`. | | `entry_points` | `bool` | Whether to check for entry points of the given namespace and pre-populate the global registry. | | **RETURNS** | `Registry` | The `Registry` object with methods to register and retrieve functions. | ```python architectures = catalogue.create("spacy", "architectures") # Use as decorator @architectures.register("custom_architecture") def custom_architecture(): pass # Use as regular function architectures.register("custom_architecture", func=custom_architecture) ``` #### class `Registry` The registry object that can be used to register and retrieve functions. It's usually created internally when you call `catalogue.create`. ##### method `Registry.__init__` Initialize a new registry. If `entry_points=True` is set, the registry will check for [Python entry points](https://packaging.python.org/tutorials/packaging-projects/#entry-points) advertised for the given namespace, e.g. the entry point group `spacy_architectures` for the namespace `"spacy", "architectures"`, in `Registry.get` and `Registry.get_all`. | Argument | Type | Description | | -------------- | ------------ | -------------------------------------------------------------------------------- | | `namespace` | `Tuple[str]` | The namespace, e.g. `"spacy"` or `"spacy", "architectures"`. | | `entry_points` | `bool` | Whether to check for entry points of the given namespace in `get` and `get_all`. | | **RETURNS** | `Registry` | The newly created object. | ```python # User-facing API architectures = catalogue.create("spacy", "architectures") # Internal API architectures = Registry(("spacy", "architectures")) ``` ##### method `Registry.__contains__` Check whether a name is in the registry. | Argument | Type | Description | | ----------- | ------ | ------------------------------------ | | `name` | `str` | The name to check. | | **RETURNS** | `bool` | Whether the name is in the registry. | ```python architectures = catalogue.create("spacy", "architectures") @architectures.register("custom_architecture") def custom_architecture(): pass assert "custom_architecture" in architectures ``` ##### method `Registry.__call__` Register a function in the registry's namespace. Can be used as a decorator or called as a function with the `func` keyword argument supplying the function to register. Delegates to `Registry.register`. ##### method `Registry.register` Register a function in the registry's namespace. Can be used as a decorator or called as a function with the `func` keyword argument supplying the function to register. | Argument | Type | Description | | ----------- | ---------- | --------------------------------------------------------- | | `name` | `str` | The name to register under the namespace. | | `func` | `Any` | Optional function to register (if not used as decorator). | | **RETURNS** | `Callable` | The decorator that takes one argument, the name. | ```python architectures = catalogue.create("spacy", "architectures") # Use as decorator @architectures.register("custom_architecture") def custom_architecture(): pass # Use as regular function architectures.register("custom_architecture", func=custom_architecture) ``` ##### method `Registry.get` Get a function registered in the namespace. | Argument | Type | Description | | ----------- | ----- | ------------------------ | | `name` | `str` | The name. | | **RETURNS** | `Any` | The registered function. | ```python custom_architecture = architectures.get("custom_architecture") ``` ##### method `Registry.get_all` Get all functions in the registry's namespace. | Argument | Type | Description | | ----------- | ---------------- | ---------------------------------------- | | **RETURNS** | `Dict[str, Any]` | The registered functions, keyed by name. | ```python all_architectures = architectures.get_all() # {"custom_architecture": } ``` ##### method `Registry.get_entry_points` Get registered entry points from other packages for this namespace. The name of the entry point group is the namespace joined by `_`. | Argument | Type | Description | | ----------- | ---------------- | --------------------------------------- | | **RETURNS** | `Dict[str, Any]` | The loaded entry points, keyed by name. | ```python architectures = catalogue.create("spacy", "architectures", entry_points=True) # Will get all entry points of the group "spacy_architectures" all_entry_points = architectures.get_entry_points() ``` ##### method `Registry.get_entry_point` Check if registered entry point is available for a given name in the namespace and load it. Otherwise, return the default value. | Argument | Type | Description | | ----------- | ----- | ------------------------------------------------ | | `name` | `str` | Name of entry point to load. | | `default` | `Any` | The default value to return. Defaults to `None`. | | **RETURNS** | `Any` | The loaded entry point or the default value. | ```python architectures = catalogue.create("spacy", "architectures", entry_points=True) # Will get entry point "custom_architecture" of the group "spacy_architectures" custom_architecture = architectures.get_entry_point("custom_architecture") ``` ##### method `Registry.find` Find the information about a registered function, including the module and path to the file it's defined in, the line number and the docstring, if available. | Argument | Type | Description | | ----------- | ---------------------------- | ----------------------------------- | | `name` | `str` | Name of the registered function. | | **RETURNS** | `Dict[str, Union[str, int]]` | The information about the function. | ```python import catalogue architectures = catalogue.create("spacy", "architectures", entry_points=True) @architectures("my_architecture") def my_architecture(): """This is an architecture""" pass info = architectures.find("my_architecture") # {'module': 'your_package.architectures', # 'file': '/path/to/your_package/architectures.py', # 'line_no': 5, # 'docstring': 'This is an architecture'} ``` #### function `catalogue.check_exists` Check if a namespace exists. | Argument | Type | Description | | ------------ | ------ | ------------------------------------------------------------ | | `*namespace` | `str` | The namespace, e.g. `"spacy"` or `"spacy", "architectures"`. | | **RETURNS** | `bool` | Whether the namespace exists. | ### Config #### class `Config` This class holds the model and training [configuration](https://thinc.ai/docs/usage-config) and can load and save the INI-style configuration format from/to a string, file or bytes. The `Config` class is a subclass of `dict` and uses Python’s `ConfigParser` under the hood. ##### method `Config.__init__` Initialize a new `Config` object with optional data. ```python from catalogue import Config config = Config({"training": {"patience": 10, "dropout": 0.2}}) ``` | Argument | Type | Description | | ----------------- | ----------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | | `data` | `Optional[Union[Dict[str, Any], Config]]` | Optional data to initialize the config with. | | `section_order` | `Optional[List[str]]` | Top-level section names, in order, used to sort the saved and loaded config. All other sections will be sorted alphabetically. | | `is_interpolated` | `Optional[bool]` | Whether the config is interpolated or whether it contains variables. Read from the `data` if it’s an instance of `Config` and otherwise defaults to `True`. | ##### method `Config.from_str` Load the config from a string. ```python from catalogue import Config config_str = """ [training] patience = 10 dropout = 0.2 """ config = Config().from_str(config_str) print(config["training"]) # {'patience': 10, 'dropout': 0.2}} ``` | Argument | Type | Description | | ------------- | ---------------- | -------------------------------------------------------------------------------------------------------------------- | | `text` | `str` | The string config to load. | | `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. | | `overrides` | `Dict[str, Any]` | Overrides for values and sections. Keys are provided in dot notation, e.g. `"training.dropout"` mapped to the value. | | **RETURNS** | `Config` | The loaded config. | ##### method `Config.to_str` Load the config from a string. ```python from catalogue import Config config = Config({"training": {"patience": 10, "dropout": 0.2}}) print(config.to_str()) # '[training]\npatience = 10\n\ndropout = 0.2' ``` | Argument | Type | Description | | ------------- | ------ | --------------------------------------------------------------------------- | | `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. | | **RETURNS** | `str` | The string config. | ##### method `Config.to_bytes` Serialize the config to a byte string. ```python from catalogue import Config config = Config({"training": {"patience": 10, "dropout": 0.2}}) config_bytes = config.to_bytes() print(config_bytes) # b'[training]\npatience = 10\n\ndropout = 0.2' ``` | Argument | Type | Description | | ------------- | ---------------- | -------------------------------------------------------------------------------------------------------------------- | | `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. | | `overrides` | `Dict[str, Any]` | Overrides for values and sections. Keys are provided in dot notation, e.g. `"training.dropout"` mapped to the value. | | **RETURNS** | `str` | The serialized config. | ##### method `Config.from_bytes` Load the config from a byte string. ```python from catalogue import Config config = Config({"training": {"patience": 10, "dropout": 0.2}}) config_bytes = config.to_bytes() new_config = Config().from_bytes(config_bytes) ``` | Argument | Type | Description | | ------------- | -------- | --------------------------------------------------------------------------- | | `bytes_data` | `bool` | The data to load. | | `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. | | **RETURNS** | `Config` | The loaded config. | ##### method `Config.to_disk` Serialize the config to a file. ```python from catalogue import Config config = Config({"training": {"patience": 10, "dropout": 0.2}}) config.to_disk("./config.cfg") ``` | Argument | Type | Description | | ------------- | ------------------ | --------------------------------------------------------------------------- | | `path` | `Union[Path, str]` | The file path. | | `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. | ##### method `Config.from_disk` Load the config from a file. ```python from catalogue import Config config = Config({"training": {"patience": 10, "dropout": 0.2}}) config.to_disk("./config.cfg") new_config = Config().from_disk("./config.cfg") ``` | Argument | Type | Description | | ------------- | ------------------ | -------------------------------------------------------------------------------------------------------------------- | | `path` | `Union[Path, str]` | The file path. | | `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. | | `overrides` | `Dict[str, Any]` | Overrides for values and sections. Keys are provided in dot notation, e.g. `"training.dropout"` mapped to the value. | | **RETURNS** | `Config` | The loaded config. | ##### method `Config.copy` Deep-copy the config. | Argument | Type | Description | | ----------- | -------- | ------------------ | | **RETURNS** | `Config` | The copied config. | ##### method `Config.interpolate` Interpolate variables like `${section.value}` or `${section.subsection}` and return a copy of the config with interpolated values. Can be used if a config is loaded with `interpolate=False`, e.g. via `Config.from_str`. ```python from catalogue import Config config_str = """ [hyper_params] dropout = 0.2 [training] dropout = ${hyper_params.dropout} """ config = Config().from_str(config_str, interpolate=False) print(config["training"]) # {'dropout': '${hyper_params.dropout}'}} config = config.interpolate() print(config["training"]) # {'dropout': 0.2}} ``` | Argument | Type | Description | | ----------- | -------- | ---------------------------------------------- | | **RETURNS** | `Config` | A copy of the config with interpolated values. | ##### method `Config.merge` Deep-merge two config objects, using the current config as the default. Only merges sections and dictionaries and not other values like lists. Values that are provided in the updates are overwritten in the base config, and any new values or sections are added. If a config value is a variable like `${section.key}` (e.g. if the config was loaded with `interpolate=False)`, **the variable is preferred**, even if the updates provide a different value. This ensures that variable references aren’t destroyed by a merge. > :warning: Note that blocks that refer to registered functions using the `@` syntax are only merged if they are > referring to the same functions. Otherwise, merging could easily produce invalid configs, since different functions > can take different arguments. If a block refers to a different function, it’s overwritten. ```python from catalogue import Config base_config_str = """ [training] patience = 10 dropout = 0.2 """ update_config_str = """ [training] dropout = 0.1 max_epochs = 2000 """ base_config = Config().from_str(base_config_str) update_config = Config().from_str(update_config_str) merged = Config(base_config).merge(update_config) print(merged["training"]) # {'patience': 10, 'dropout': 1.0, 'max_epochs': 2000} ``` | Argument | Type | Description | | ----------- | ------------------------------- | --------------------------------------------------- | | `overrides` | `Union[Dict[str, Any], Config]` | The updates to merge into the config. | | **RETURNS** | `Config` | A new config instance containing the merged config. | #### Config Attributes | Argument | Type | Description | | ----------------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | `is_interpolated` | `bool` | Whether the config values have been interpolated. Defaults to `True` and is set to `False` if a config is loaded with `interpolate=False`, e.g. using `Config.from_str`. | catalogue-2.1.0/azure-pipelines.yml000066400000000000000000000043141426104053400173160ustar00rootroot00000000000000trigger: batch: true branches: include: - '*' jobs: - job: 'Test' strategy: matrix: Python36Linux: imageName: 'ubuntu-latest' python.version: '3.6' Python36Windows: imageName: 'windows-2019' python.version: '3.6' Python36Mac: imageName: 'macos-10.15' python.version: '3.6' Python37Linux: imageName: 'ubuntu-latest' python.version: '3.7' Python37Windows: imageName: 'windows-latest' python.version: '3.7' Python37Mac: imageName: 'macos-latest' python.version: '3.7' Python38Linux: imageName: 'ubuntu-latest' python.version: '3.8' Python38Windows: imageName: 'windows-latest' python.version: '3.8' Python38Mac: imageName: 'macos-latest' python.version: '3.8' Python39Linux: imageName: 'ubuntu-latest' python.version: '3.9' Python39Windows: imageName: 'windows-latest' python.version: '3.9' Python39Mac: imageName: 'macos-latest' python.version: '3.9' Python310Linux: imageName: 'ubuntu-latest' python.version: '3.10' Python310Windows: imageName: 'windows-latest' python.version: '3.10' Python310Mac: imageName: 'macos-latest' python.version: '3.10' maxParallel: 4 pool: vmImage: $(imageName) steps: - task: UsePythonVersion@0 inputs: versionSpec: '$(python.version)' architecture: 'x64' - script: | pip install -U -r requirements.txt pip install numpy pathy python setup.py sdist displayName: 'Build sdist' - script: python -m mypy catalogue displayName: 'Run mypy' - task: DeleteFiles@1 inputs: contents: 'catalogue' displayName: 'Delete source directory' - bash: | SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1) pip install dist/$SDIST displayName: 'Install from sdist' - script: python -m pytest --pyargs catalogue displayName: 'Run tests' - bash: | pip install hypothesis python -c "import catalogue; import hypothesis" displayName: 'Test for conflicts' catalogue-2.1.0/bin/000077500000000000000000000000001426104053400142255ustar00rootroot00000000000000catalogue-2.1.0/bin/push-tags.sh000077500000000000000000000005371426104053400165040ustar00rootroot00000000000000#!/usr/bin/env bash set -e # Insist repository is clean git diff-index --quiet HEAD git checkout $1 git pull origin $1 git push origin $1 version=$(grep "version = " setup.cfg) version=${version/version = } version=${version/\'/} version=${version/\'/} version=${version/\"/} version=${version/\"/} git tag "v$version" git push origin "v$version"catalogue-2.1.0/catalogue/000077500000000000000000000000001426104053400154215ustar00rootroot00000000000000catalogue-2.1.0/catalogue/__init__.py000066400000000000000000000001001426104053400175210ustar00rootroot00000000000000from catalogue.registry import * from catalogue.config import * catalogue-2.1.0/catalogue/_importlib_metadata/000077500000000000000000000000001426104053400214215ustar00rootroot00000000000000catalogue-2.1.0/catalogue/_importlib_metadata/LICENSE000066400000000000000000000010731426104053400224270ustar00rootroot00000000000000Copyright 2017-2019 Jason R. Coombs, Barry Warsaw Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. catalogue-2.1.0/catalogue/_importlib_metadata/__init__.py000066400000000000000000000473541426104053400235470ustar00rootroot00000000000000import os import re import abc import csv import sys import zipp import email import pathlib import operator import functools import itertools import posixpath import collections from ._compat import ( NullFinder, PyPy_repr, install, Protocol, ) from configparser import ConfigParser from contextlib import suppress from importlib import import_module from importlib.abc import MetaPathFinder from itertools import starmap from typing import Any, List, Mapping, TypeVar, Union __all__ = [ 'Distribution', 'DistributionFinder', 'PackageNotFoundError', 'distribution', 'distributions', 'entry_points', 'files', 'metadata', 'requires', 'version', ] class PackageNotFoundError(ModuleNotFoundError): """The package was not found.""" def __str__(self): tmpl = "No package metadata was found for {self.name}" return tmpl.format(**locals()) @property def name(self): (name,) = self.args return name class EntryPoint( PyPy_repr, collections.namedtuple('EntryPointBase', 'name value group') ): """An entry point as defined by Python packaging conventions. See `the packaging docs on entry points `_ for more information. """ pattern = re.compile( r'(?P[\w.]+)\s*' r'(:\s*(?P[\w.]+))?\s*' r'(?P\[.*\])?\s*$' ) """ A regular expression describing the syntax for an entry point, which might look like: - module - package.module - package.module:attribute - package.module:object.attribute - package.module:attr [extra1, extra2] Other combinations are possible as well. The expression is lenient about whitespace around the ':', following the attr, and following any extras. """ def load(self): """Load the entry point from its definition. If only a module is indicated by the value, return that module. Otherwise, return the named object. """ match = self.pattern.match(self.value) module = import_module(match.group('module')) attrs = filter(None, (match.group('attr') or '').split('.')) return functools.reduce(getattr, attrs, module) @property def module(self): match = self.pattern.match(self.value) return match.group('module') @property def attr(self): match = self.pattern.match(self.value) return match.group('attr') @property def extras(self): match = self.pattern.match(self.value) return list(re.finditer(r'\w+', match.group('extras') or '')) @classmethod def _from_config(cls, config): return [ cls(name, value, group) for group in config.sections() for name, value in config.items(group) ] @classmethod def _from_text(cls, text): config = ConfigParser(delimiters='=') # case sensitive: https://stackoverflow.com/q/1611799/812183 config.optionxform = str config.read_string(text) return EntryPoint._from_config(config) def __iter__(self): """ Supply iter so one may construct dicts of EntryPoints easily. """ return iter((self.name, self)) def __reduce__(self): return ( self.__class__, (self.name, self.value, self.group), ) class PackagePath(pathlib.PurePosixPath): """A reference to a path in a package""" def read_text(self, encoding='utf-8'): with self.locate().open(encoding=encoding) as stream: return stream.read() def read_binary(self): with self.locate().open('rb') as stream: return stream.read() def locate(self): """Return a path-like object for this path""" return self.dist.locate_file(self) class FileHash: def __init__(self, spec): self.mode, _, self.value = spec.partition('=') def __repr__(self): return ''.format(self.mode, self.value) _T = TypeVar("_T") class PackageMetadata(Protocol): def __len__(self) -> int: ... # pragma: no cover def __contains__(self, item: str) -> bool: ... # pragma: no cover def __getitem__(self, key: str) -> str: ... # pragma: no cover def get_all(self, name: str, failobj: _T = ...) -> Union[List[Any], _T]: """ Return all values associated with a possibly multi-valued key. """ class Distribution: """A Python distribution package.""" @abc.abstractmethod def read_text(self, filename): """Attempt to load metadata file given by the name. :param filename: The name of the file in the distribution info. :return: The text if found, otherwise None. """ @abc.abstractmethod def locate_file(self, path): """ Given a path to a file in this distribution, return a path to it. """ @classmethod def from_name(cls, name): """Return the Distribution for the given package name. :param name: The name of the distribution package to search for. :return: The Distribution instance (or subclass thereof) for the named package, if found. :raises PackageNotFoundError: When the named package's distribution metadata cannot be found. """ for resolver in cls._discover_resolvers(): dists = resolver(DistributionFinder.Context(name=name)) dist = next(iter(dists), None) if dist is not None: return dist else: raise PackageNotFoundError(name) @classmethod def discover(cls, **kwargs): """Return an iterable of Distribution objects for all packages. Pass a ``context`` or pass keyword arguments for constructing a context. :context: A ``DistributionFinder.Context`` object. :return: Iterable of Distribution objects for all packages. """ context = kwargs.pop('context', None) if context and kwargs: raise ValueError("cannot accept context and kwargs") context = context or DistributionFinder.Context(**kwargs) return itertools.chain.from_iterable( resolver(context) for resolver in cls._discover_resolvers() ) @staticmethod def at(path): """Return a Distribution for the indicated metadata path :param path: a string or path-like object :return: a concrete Distribution instance for the path """ return PathDistribution(pathlib.Path(path)) @staticmethod def _discover_resolvers(): """Search the meta_path for resolvers.""" declared = ( getattr(finder, '_catalogue_find_distributions', None) for finder in sys.meta_path ) return filter(None, declared) @classmethod def _local(cls, root='.'): from pep517 import build, meta system = build.compat_system(root) builder = functools.partial( meta.build, source_dir=root, system=system, ) return PathDistribution(zipp.Path(meta.build_as_zip(builder))) @property def metadata(self) -> PackageMetadata: """Return the parsed metadata for this Distribution. The returned object will have keys that name the various bits of metadata. See PEP 566 for details. """ text = ( self.read_text('METADATA') or self.read_text('PKG-INFO') # This last clause is here to support old egg-info files. Its # effect is to just end up using the PathDistribution's self._path # (which points to the egg-info file) attribute unchanged. or self.read_text('') ) return email.message_from_string(text) @property def version(self): """Return the 'Version' metadata for the distribution package.""" return self.metadata['Version'] @property def entry_points(self): return EntryPoint._from_text(self.read_text('entry_points.txt')) @property def files(self): """Files in this distribution. :return: List of PackagePath for this distribution or None Result is `None` if the metadata file that enumerates files (i.e. RECORD for dist-info or SOURCES.txt for egg-info) is missing. Result may be empty if the metadata exists but is empty. """ file_lines = self._read_files_distinfo() or self._read_files_egginfo() def make_file(name, hash=None, size_str=None): result = PackagePath(name) result.hash = FileHash(hash) if hash else None result.size = int(size_str) if size_str else None result.dist = self return result return file_lines and list(starmap(make_file, csv.reader(file_lines))) def _read_files_distinfo(self): """ Read the lines of RECORD """ text = self.read_text('RECORD') return text and text.splitlines() def _read_files_egginfo(self): """ SOURCES.txt might contain literal commas, so wrap each line in quotes. """ text = self.read_text('SOURCES.txt') return text and map('"{}"'.format, text.splitlines()) @property def requires(self): """Generated requirements specified for this Distribution""" reqs = self._read_dist_info_reqs() or self._read_egg_info_reqs() return reqs and list(reqs) def _read_dist_info_reqs(self): return self.metadata.get_all('Requires-Dist') def _read_egg_info_reqs(self): source = self.read_text('requires.txt') return source and self._deps_from_requires_text(source) @classmethod def _deps_from_requires_text(cls, source): section_pairs = cls._read_sections(source.splitlines()) sections = { section: list(map(operator.itemgetter('line'), results)) for section, results in itertools.groupby( section_pairs, operator.itemgetter('section') ) } return cls._convert_egg_info_reqs_to_simple_reqs(sections) @staticmethod def _read_sections(lines): section = None for line in filter(None, lines): section_match = re.match(r'\[(.*)\]$', line) if section_match: section = section_match.group(1) continue yield locals() @staticmethod def _convert_egg_info_reqs_to_simple_reqs(sections): """ Historically, setuptools would solicit and store 'extra' requirements, including those with environment markers, in separate sections. More modern tools expect each dependency to be defined separately, with any relevant extras and environment markers attached directly to that requirement. This method converts the former to the latter. See _test_deps_from_requires_text for an example. """ def make_condition(name): return name and 'extra == "{name}"'.format(name=name) def parse_condition(section): section = section or '' extra, sep, markers = section.partition(':') if extra and markers: markers = '({markers})'.format(markers=markers) conditions = list(filter(None, [markers, make_condition(extra)])) return '; ' + ' and '.join(conditions) if conditions else '' for section, deps in sections.items(): for dep in deps: yield dep + parse_condition(section) class DistributionFinder(MetaPathFinder): """ A MetaPathFinder capable of discovering installed distributions. """ class Context: """ Keyword arguments presented by the caller to ``distributions()`` or ``Distribution.discover()`` to narrow the scope of a search for distributions in all DistributionFinders. Each DistributionFinder may expect any parameters and should attempt to honor the canonical parameters defined below when appropriate. """ name = None """ Specific name for which a distribution finder should match. A name of ``None`` matches all distributions. """ def __init__(self, **kwargs): vars(self).update(kwargs) @property def path(self): """ The path that a distribution finder should search. Typically refers to Python package paths and defaults to ``sys.path``. """ return vars(self).get('path', sys.path) @abc.abstractmethod def _catalogue_find_distributions(self, context=Context()): """ Find distributions. Return an iterable of all Distribution instances capable of loading the metadata for packages matching the ``context``, a DistributionFinder.Context instance. """ class FastPath: """ Micro-optimized class for searching a path for children. """ def __init__(self, root): self.root = str(root) self.base = os.path.basename(self.root).lower() def joinpath(self, child): return pathlib.Path(self.root, child) def children(self): with suppress(Exception): return os.listdir(self.root or '') with suppress(Exception): return self.zip_children() return [] def zip_children(self): zip_path = zipp.Path(self.root) names = zip_path.root.namelist() self.joinpath = zip_path.joinpath return dict.fromkeys(child.split(posixpath.sep, 1)[0] for child in names) def search(self, name): return ( self.joinpath(child) for child in self.children() if name.matches(child, self.base) ) class Prepared: """ A prepared search for metadata on a possibly-named package. """ normalized = None suffixes = '.dist-info', '.egg-info' exact_matches = [''][:0] def __init__(self, name): self.name = name if name is None: return self.normalized = self.normalize(name) self.exact_matches = [self.normalized + suffix for suffix in self.suffixes] @staticmethod def normalize(name): """ PEP 503 normalization plus dashes as underscores. """ return re.sub(r"[-_.]+", "-", name).lower().replace('-', '_') @staticmethod def legacy_normalize(name): """ Normalize the package name as found in the convention in older packaging tools versions and specs. """ return name.lower().replace('-', '_') def matches(self, cand, base): low = cand.lower() pre, ext = os.path.splitext(low) name, sep, rest = pre.partition('-') return ( low in self.exact_matches or ext in self.suffixes and (not self.normalized or name.replace('.', '_') == self.normalized) # legacy case: or self.is_egg(base) and low == 'egg-info' ) def is_egg(self, base): normalized = self.legacy_normalize(self.name or '') prefix = normalized + '-' if normalized else '' versionless_egg_name = normalized + '.egg' if self.name else '' return ( base == versionless_egg_name or base.startswith(prefix) and base.endswith('.egg') ) @install class MetadataPathFinder(NullFinder, DistributionFinder): """A degenerate finder for distribution packages on the file system. This finder supplies only a find_distributions() method for versions of Python that do not have a PathFinder find_distributions(). """ def _catalogue_find_distributions(self, context=DistributionFinder.Context()): """ Find distributions. Return an iterable of all Distribution instances capable of loading the metadata for packages matching ``context.name`` (or all names if ``None`` indicated) along the paths in the list of directories ``context.path``. """ found = self._search_paths(context.name, context.path) return map(PathDistribution, found) @classmethod def _search_paths(cls, name, paths): """Find metadata directories in paths heuristically.""" return itertools.chain.from_iterable( path.search(Prepared(name)) for path in map(FastPath, paths) ) class PathDistribution(Distribution): def __init__(self, path): """Construct a distribution from a path to the metadata directory. :param path: A pathlib.Path or similar object supporting .joinpath(), __div__, .parent, and .read_text(). """ self._path = path def read_text(self, filename): with suppress( FileNotFoundError, IsADirectoryError, KeyError, NotADirectoryError, PermissionError, ): return self._path.joinpath(filename).read_text(encoding='utf-8') read_text.__doc__ = Distribution.read_text.__doc__ def locate_file(self, path): return self._path.parent / path def distribution(distribution_name): """Get the ``Distribution`` instance for the named package. :param distribution_name: The name of the distribution package as a string. :return: A ``Distribution`` instance (or subclass thereof). """ return Distribution.from_name(distribution_name) def distributions(**kwargs): """Get all ``Distribution`` instances in the current environment. :return: An iterable of ``Distribution`` instances. """ return Distribution.discover(**kwargs) def metadata(distribution_name) -> PackageMetadata: """Get the metadata for the named package. :param distribution_name: The name of the distribution package to query. :return: A PackageMetadata containing the parsed metadata. """ return Distribution.from_name(distribution_name).metadata def version(distribution_name): """Get the version string for the named package. :param distribution_name: The name of the distribution package to query. :return: The version string for the package as defined in the package's "Version" metadata key. """ return distribution(distribution_name).version def entry_points(): """Return EntryPoint objects for all installed packages. :return: EntryPoint objects for all installed packages. """ eps = itertools.chain.from_iterable(dist.entry_points for dist in distributions()) by_group = operator.attrgetter('group') ordered = sorted(eps, key=by_group) grouped = itertools.groupby(ordered, by_group) return {group: tuple(eps) for group, eps in grouped} def files(distribution_name): """Return a list of files for the named package. :param distribution_name: The name of the distribution package to query. :return: List of files composing the distribution. """ return distribution(distribution_name).files def requires(distribution_name): """ Return a list of requirements for the named package. :return: An iterator of requirements, suitable for packaging.requirement.Requirement. """ return distribution(distribution_name).requires def packages_distributions() -> Mapping[str, List[str]]: """ Return a mapping of top-level packages to their distributions. >>> pkgs = packages_distributions() >>> all(isinstance(dist, collections.abc.Sequence) for dist in pkgs.values()) True """ pkg_to_dist = collections.defaultdict(list) for dist in distributions(): for pkg in (dist.read_text('top_level.txt') or '').split(): pkg_to_dist[pkg].append(dist.metadata['Name']) return dict(pkg_to_dist) catalogue-2.1.0/catalogue/_importlib_metadata/_compat.py000066400000000000000000000045461426104053400234260ustar00rootroot00000000000000import sys __all__ = ['install', 'NullFinder', 'PyPy_repr', 'Protocol'] try: from typing import Protocol except ImportError: # pragma: no cover """ pytest-mypy complains here because: error: Incompatible import of "Protocol" (imported name has type "typing_extensions._SpecialForm", local name has type "typing._SpecialForm") """ from typing_extensions import Protocol # type: ignore def install(cls): """ Class decorator for installation on sys.meta_path. Adds the backport DistributionFinder to sys.meta_path and attempts to disable the finder functionality of the stdlib DistributionFinder. """ sys.meta_path.append(cls()) disable_stdlib_finder() return cls def disable_stdlib_finder(): """ Give the backport primacy for discovering path-based distributions by monkey-patching the stdlib O_O. See #91 for more background for rationale on this sketchy behavior. """ def matches(finder): return getattr( finder, '__module__', None ) == '_frozen_importlib_external' and hasattr(finder, '_catalogue_find_distributions') for finder in filter(matches, sys.meta_path): # pragma: nocover del finder._catalogue_find_distributions class NullFinder: """ A "Finder" (aka "MetaClassFinder") that never finds any modules, but may find distributions. """ @staticmethod def find_spec(*args, **kwargs): return None # In Python 2, the import system requires finders # to have a find_module() method, but this usage # is deprecated in Python 3 in favor of find_spec(). # For the purposes of this finder (i.e. being present # on sys.meta_path but having no other import # system functionality), the two methods are identical. find_module = find_spec class PyPy_repr: """ Override repr for EntryPoint objects on PyPy to avoid __iter__ access. Ref #97, #102. """ affected = hasattr(sys, 'pypy_version_info') def __compat_repr__(self): # pragma: nocover def make_param(name): value = getattr(self, name) return '{name}={value!r}'.format(**locals()) params = ', '.join(map(make_param, self._fields)) return 'EntryPoint({params})'.format(**locals()) if affected: # pragma: nocover __repr__ = __compat_repr__ del affected catalogue-2.1.0/catalogue/config/000077500000000000000000000000001426104053400166665ustar00rootroot00000000000000catalogue-2.1.0/catalogue/config/__init__.py000066400000000000000000000000261426104053400207750ustar00rootroot00000000000000from .config import * catalogue-2.1.0/catalogue/config/config.py000066400000000000000000001327271426104053400205210ustar00rootroot00000000000000from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping from typing import Iterable, Sequence, cast from types import GeneratorType from dataclasses import dataclass from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH from configparser import InterpolationMissingOptionError, InterpolationSyntaxError from configparser import NoSectionError, NoOptionError, InterpolationDepthError from configparser import ParsingError from pathlib import Path from pydantic import BaseModel, create_model, ValidationError, Extra from pydantic.main import ModelMetaclass from pydantic.fields import ModelField import srsly import catalogue.registry import inspect import io import numpy import copy import re from .util import Decorator # Field used for positional arguments, e.g. [section.*.xyz]. The alias is # required for the schema (shouldn't clash with user-defined arg names) ARGS_FIELD = "*" ARGS_FIELD_ALIAS = "VARIABLE_POSITIONAL_ARGS" # Aliases for fields that would otherwise shadow pydantic attributes. Can be any # string, so we're using name + space so it looks the same in error messages etc. RESERVED_FIELDS = {"validate": "validate\u0020"} # Internal prefix used to mark section references for custom interpolation SECTION_PREFIX = "__SECTION__:" # Values that shouldn't be loaded during interpolation because it'd cause # even explicit string values to be incorrectly parsed as bools/None etc. JSON_EXCEPTIONS = ("true", "false", "null") # Regex to detect whether a value contains a variable VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}") class CustomInterpolation(ExtendedInterpolation): def before_read(self, parser, section, option, value): # If we're dealing with a quoted string as the interpolation value, # make sure we load and unquote it so we don't end up with '"value"' try: json_value = srsly.json_loads(value) if isinstance(json_value, str) and json_value not in JSON_EXCEPTIONS: value = json_value except Exception: pass return super().before_read(parser, section, option, value) def before_get(self, parser, section, option, value, defaults): # Mostly copy-pasted from the built-in configparser implementation. L = [] self.interpolate(parser, option, L, value, section, defaults, 1) return "".join(L) def interpolate(self, parser, option, accum, rest, section, map, depth): # Mostly copy-pasted from the built-in configparser implementation. # We need to overwrite this method so we can add special handling for # block references :( All values produced here should be strings – # we need to wait until the whole config is interpreted anyways so # filling in incomplete values here is pointless. All we need is the # section reference so we can fetch it later. rawval = parser.get(section, option, raw=True, fallback=rest) if depth > MAX_INTERPOLATION_DEPTH: raise InterpolationDepthError(option, section, rawval) while rest: p = rest.find("$") if p < 0: accum.append(rest) return if p > 0: accum.append(rest[:p]) rest = rest[p:] # p is no longer used c = rest[1:2] if c == "$": accum.append("$") rest = rest[2:] elif c == "{": # We want to treat both ${a:b} and ${a.b} the same m = self._KEYCRE.match(rest) if m is None: err = f"bad interpolation variable reference {rest}" raise InterpolationSyntaxError(option, section, err) orig_var = m.group(1) path = orig_var.replace(":", ".").rsplit(".", 1) rest = rest[m.end() :] sect = section opt = option try: if len(path) == 1: opt = parser.optionxform(path[0]) if opt in map: v = map[opt] else: # We have block reference, store it as a special key section_name = parser[parser.optionxform(path[0])]._name v = self._get_section_name(section_name) elif len(path) == 2: sect = path[0] opt = parser.optionxform(path[1]) fallback = "__FALLBACK__" v = parser.get(sect, opt, raw=True, fallback=fallback) # If a variable doesn't exist, try again and treat the # reference as a section if v == fallback: v = self._get_section_name(parser[f"{sect}.{opt}"]._name) else: err = f"More than one ':' found: {rest}" raise InterpolationSyntaxError(option, section, err) except (KeyError, NoSectionError, NoOptionError): raise InterpolationMissingOptionError( option, section, rawval, orig_var ) from None if "$" in v: new_map = dict(parser.items(sect, raw=True)) self.interpolate(parser, opt, accum, v, sect, new_map, depth + 1) else: accum.append(v) else: err = "'$' must be followed by '$' or '{', " "found: %r" % (rest,) raise InterpolationSyntaxError(option, section, err) def _get_section_name(self, name: str) -> str: """Generate the name of a section. Note that we use a quoted string here so we can use section references within lists and load the list as JSON. Since section references can't be used within strings, we don't need the quoted vs. unquoted distinction like we do for variables. Examples (assuming section = {"foo": 1}): - value: ${section.foo} -> value: 1 - value: "hello ${section.foo}" -> value: "hello 1" - value: ${section} -> value: {"foo": 1} - value: "${section}" -> value: {"foo": 1} - value: "hello ${section}" -> invalid """ return f'"{SECTION_PREFIX}{name}"' def get_configparser(interpolate: bool = True): config = ConfigParser(interpolation=CustomInterpolation() if interpolate else None) # Preserve case of keys: https://stackoverflow.com/a/1611877/6400719 config.optionxform = str # type: ignore return config class Config(dict): """This class holds the model and training configuration and can load and save the TOML-style configuration format from/to a string, file or bytes. The Config class is a subclass of dict and uses Python's ConfigParser under the hood. """ is_interpolated: bool def __init__( self, data: Optional[Union[Dict[str, Any], "ConfigParser", "Config"]] = None, *, is_interpolated: Optional[bool] = None, section_order: Optional[List[str]] = None, ) -> None: """Initialize a new Config object with optional data.""" dict.__init__(self) if data is None: data = {} if not isinstance(data, (dict, Config, ConfigParser)): raise ValueError( f"Can't initialize Config with data. Expected dict, Config or " f"ConfigParser but got: {type(data)}" ) # Whether the config has been interpolated. We can use this to check # whether we need to interpolate again when it's resolved. We assume # that a config is interpolated by default. if is_interpolated is not None: self.is_interpolated = is_interpolated elif isinstance(data, Config): self.is_interpolated = data.is_interpolated else: self.is_interpolated = True if section_order is not None: self.section_order = section_order elif isinstance(data, Config): self.section_order = data.section_order else: self.section_order = [] # Update with data self.update(self._sort(data)) def interpolate(self) -> "Config": """Interpolate a config. Returns a copy of the object.""" # This is currently the most effective way because we need our custom # to_str logic to run in order to re-serialize the values so we can # interpolate them again. ConfigParser.read_dict will just call str() # on all values, which isn't enough. return Config().from_str(self.to_str()) def interpret_config(self, config: "ConfigParser") -> None: """Interpret a config, parse nested sections and parse the values as JSON. Mostly used internally and modifies the config in place. """ self._validate_sections(config) # Sort sections by depth, so that we can iterate breadth-first. This # allows us to check that we're not expanding an undefined block. get_depth = lambda item: len(item[0].split(".")) for section, values in sorted(config.items(), key=get_depth): if section == "DEFAULT": # Skip [DEFAULT] section so it doesn't cause validation error continue parts = section.split(".") node = self for part in parts[:-1]: if part == "*": node = node.setdefault(part, {}) elif part not in node: err_title = f"Error parsing config section. Perhaps a section name is wrong?" err = [{"loc": parts, "msg": f"Section '{part}' is not defined"}] raise ConfigValidationError( config=self, errors=err, title=err_title ) else: node = node[part] if not isinstance(node, dict): # Happens if both value *and* subsection were defined for a key err = [{"loc": parts, "msg": "found conflicting values"}] err_cfg = f"{self}\n{({part: dict(values)})}" raise ConfigValidationError(config=err_cfg, errors=err) # Set the default section node = node.setdefault(parts[-1], {}) if not isinstance(node, dict): # Happens if both value *and* subsection were defined for a key err = [{"loc": parts, "msg": "found conflicting values"}] err_cfg = f"{self}\n{({part: dict(values)})}" raise ConfigValidationError(config=err_cfg, errors=err) try: keys_values = list(values.items()) except InterpolationMissingOptionError as e: raise ConfigValidationError(desc=f"{e}") from None for key, value in keys_values: config_v = config.get(section, key) node[key] = self._interpret_value(config_v) self.replace_section_refs(self) def replace_section_refs( self, config: Union[Dict[str, Any], "Config"], parent: str = "" ) -> None: """Replace references to section blocks in the final config.""" for key, value in config.items(): key_parent = f"{parent}.{key}".strip(".") if isinstance(value, dict): self.replace_section_refs(value, parent=key_parent) elif isinstance(value, list): config[key] = [ self._get_section_ref(v, parent=[parent, key]) for v in value ] else: config[key] = self._get_section_ref(value, parent=[parent, key]) def _interpret_value(self, value: Any) -> Any: """Interpret a single config value.""" result = try_load_json(value) # If value is a string and it contains a variable, use original value # (not interpreted string, which could lead to double quotes: # ${x.y} -> "${x.y}" -> "'${x.y}'"). Make sure to check it's a string, # so we're not keeping lists as strings. # NOTE: This currently can't handle uninterpolated values like [${x.y}]! if isinstance(result, str) and VARIABLE_RE.search(value): result = value if isinstance(result, list): return [self._interpret_value(v) for v in result] return result def _get_section_ref(self, value: Any, *, parent: List[str] = []) -> Any: """Get a single section reference.""" if isinstance(value, str) and value.startswith(f'"{SECTION_PREFIX}'): value = try_load_json(value) if isinstance(value, str) and value.startswith(SECTION_PREFIX): parts = value.replace(SECTION_PREFIX, "").split(".") result = self for item in parts: try: result = result[item] except (KeyError, TypeError): # This should never happen err_title = "Error parsing reference to config section" err_msg = f"Section '{'.'.join(parts)}' is not defined" err = [{"loc": parts, "msg": err_msg}] raise ConfigValidationError( config=self, errors=err, title=err_title ) from None return result elif isinstance(value, str) and SECTION_PREFIX in value: # String value references a section (either a dict or return # value of promise). We can't allow this, since variables are # always interpolated *before* configs are resolved. err_desc = ( "Can't reference whole sections or return values of function " "blocks inside a string or list\n\nYou can change your variable to " "reference a value instead. Keep in mind that it's not " "possible to interpolate the return value of a registered " "function, since variables are interpolated when the config " "is loaded, and registered functions are resolved afterwards." ) err = [{"loc": parent, "msg": "uses section variable in string or list"}] raise ConfigValidationError(errors=err, desc=err_desc) return value def copy(self) -> "Config": """Deepcopy the config.""" try: config = copy.deepcopy(self) except Exception as e: raise ValueError(f"Couldn't deep-copy config: {e}") from e return Config( config, is_interpolated=self.is_interpolated, section_order=self.section_order, ) def merge( self, updates: Union[Dict[str, Any], "Config"], remove_extra: bool = False ) -> "Config": """Deep merge the config with updates, using current as defaults.""" defaults = self.copy() updates = Config(updates).copy() merged = deep_merge_configs(updates, defaults, remove_extra=remove_extra) return Config( merged, is_interpolated=defaults.is_interpolated and updates.is_interpolated, section_order=defaults.section_order, ) def _sort( self, data: Union["Config", "ConfigParser", Dict[str, Any]] ) -> Dict[str, Any]: """Sort sections using the currently defined sort order. Sort sections by index on section order, if available, then alphabetic, and account for subsections, which should always follow their parent. """ sort_map = {section: i for i, section in enumerate(self.section_order)} sort_key = lambda x: ( sort_map.get(x[0].split(".")[0], len(sort_map)), _mask_positional_args(x[0]), ) return dict(sorted(data.items(), key=sort_key)) def _set_overrides(self, config: "ConfigParser", overrides: Dict[str, Any]) -> None: """Set overrides in the ConfigParser before config is interpreted.""" err_title = "Error parsing config overrides" for key, value in overrides.items(): err_msg = "not a section value that can be overwritten" err = [{"loc": key.split("."), "msg": err_msg}] if "." not in key: raise ConfigValidationError(errors=err, title=err_title) section, option = key.rsplit(".", 1) # Check for section and accept if option not in config[section] if section not in config: raise ConfigValidationError(errors=err, title=err_title) config.set(section, option, try_dump_json(value, overrides)) def _validate_sections(self, config: "ConfigParser") -> None: # If the config defines top-level properties that are not sections (e.g. # if config was constructed from dict), those values would be added as # [DEFAULTS] and included in *every other section*. This is usually not # what we want and it can lead to very confusing results. default_section = config.defaults() if default_section: err_title = "Found config values without a top-level section" err_msg = "not part of a section" err = [{"loc": [k], "msg": err_msg} for k in default_section] raise ConfigValidationError(errors=err, title=err_title) def from_str( self, text: str, *, interpolate: bool = True, overrides: Dict[str, Any] = {} ) -> "Config": """Load the config from a string.""" config = get_configparser(interpolate=interpolate) if overrides: config = get_configparser(interpolate=False) try: config.read_string(text) except ParsingError as e: desc = f"Make sure the sections and values are formatted correctly.\n\n{e}" raise ConfigValidationError(desc=desc) from None config._sections = self._sort(config._sections) self._set_overrides(config, overrides) self.clear() self.interpret_config(config) if overrides and interpolate: # do the interpolation. Avoids recursion because the new call from_str call will have overrides as empty self = self.interpolate() self.is_interpolated = interpolate return self def to_str(self, *, interpolate: bool = True) -> str: """Write the config to a string.""" flattened = get_configparser(interpolate=interpolate) queue: List[Tuple[tuple, "Config"]] = [(tuple(), self)] for path, node in queue: section_name = ".".join(path) is_kwarg = path and path[-1] != "*" if is_kwarg and not flattened.has_section(section_name): # Always create sections for non-'*' sections, not only if # they have leaf entries, as we don't want to expand # blocks that are undefined flattened.add_section(section_name) for key, value in node.items(): if hasattr(value, "items"): # Reference to a function with no arguments, serialize # inline as a dict and don't create new section if catalogue_registry.is_promise(value) and len(value) == 1 and is_kwarg: flattened.set(section_name, key, try_dump_json(value, node)) else: queue.append((path + (key,), value)) else: flattened.set(section_name, key, try_dump_json(value, node)) # Order so subsection follow parent (not all sections, then all subs etc.) flattened._sections = self._sort(flattened._sections) self._validate_sections(flattened) string_io = io.StringIO() flattened.write(string_io) return string_io.getvalue().strip() def to_bytes(self, *, interpolate: bool = True) -> bytes: """Serialize the config to a byte string.""" return self.to_str(interpolate=interpolate).encode("utf8") def from_bytes( self, bytes_data: bytes, *, interpolate: bool = True, overrides: Dict[str, Any] = {}, ) -> "Config": """Load the config from a byte string.""" return self.from_str( bytes_data.decode("utf8"), interpolate=interpolate, overrides=overrides ) def to_disk(self, path: Union[str, Path], *, interpolate: bool = True): """Serialize the config to a file.""" path = Path(path) if isinstance(path, str) else path with path.open("w", encoding="utf8") as file_: file_.write(self.to_str(interpolate=interpolate)) def from_disk( self, path: Union[str, Path], *, interpolate: bool = True, overrides: Dict[str, Any] = {}, ) -> "Config": """Load config from a file.""" path = Path(path) if isinstance(path, str) else path with path.open("r", encoding="utf8") as file_: text = file_.read() return self.from_str(text, interpolate=interpolate, overrides=overrides) def _mask_positional_args(name: str) -> List[Optional[str]]: """Create a section name representation that masks names of positional arguments to retain their order in sorts.""" stable_name = cast(List[Optional[str]], name.split(".")) # Remove names of sections that are a positional argument. for i in range(1, len(stable_name)): if stable_name[i - 1] == "*": stable_name[i] = None return stable_name def try_load_json(value: str) -> Any: """Load a JSON string if possible, otherwise default to original value.""" try: return srsly.json_loads(value) except Exception: return value def try_dump_json(value: Any, data: Union[Dict[str, dict], Config, str] = "") -> str: """Dump a config value as JSON and output user-friendly error if it fails.""" # Special case if we have a variable: it's already a string so don't dump # to preserve ${x:y} vs. "${x:y}" if isinstance(value, str) and VARIABLE_RE.search(value): return value if isinstance(value, str) and value.replace(".", "", 1).isdigit(): # Work around values that are strings but numbers value = f'"{value}"' try: return srsly.json_dumps(value) except Exception as e: err_msg = ( f"Couldn't serialize config value of type {type(value)}: {e}. Make " f"sure all values in your config are JSON-serializable. If you want " f"to include Python objects, use a registered function that returns " f"the object instead." ) raise ConfigValidationError(config=data, desc=err_msg) from e def deep_merge_configs( config: Union[Dict[str, Any], Config], defaults: Union[Dict[str, Any], Config], *, remove_extra: bool = False, ) -> Union[Dict[str, Any], Config]: """Deep merge two configs.""" if remove_extra: # Filter out values in the original config that are not in defaults keys = list(config.keys()) for key in keys: if key not in defaults: del config[key] for key, value in defaults.items(): if isinstance(value, dict): node = config.setdefault(key, {}) if not isinstance(node, dict): continue value_promises = [k for k in value if k.startswith("@")] value_promise = value_promises[0] if value_promises else None node_promises = [k for k in node if k.startswith("@")] if node else [] node_promise = node_promises[0] if node_promises else None # We only update the block from defaults if it refers to the same # registered function if ( value_promise and node_promise and ( value_promise in node and node[value_promise] != value[value_promise] ) ): continue if node_promise and ( node_promise not in value or node[node_promise] != value[node_promise] ): continue defaults = deep_merge_configs(node, value, remove_extra=remove_extra) elif key not in config: config[key] = value return config class ConfigValidationError(ValueError): def __init__( self, *, config: Optional[Union[Config, Dict[str, Dict[str, Any]], str]] = None, errors: Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]] = tuple(), title: Optional[str] = "Config validation error", desc: Optional[str] = None, parent: Optional[str] = None, show_config: bool = True, ) -> None: """Custom error for validating configs. config (Union[Config, Dict[str, Dict[str, Any]], str]): The config the validation error refers to. errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]): A list of errors as dicts with keys "loc" (list of strings describing the path of the value), "msg" (validation message to show) and optional "type" (mostly internals). Same format as produced by pydantic's validation error (e.errors()). title (str): The error title. desc (str): Optional error description, displayed below the title. parent (str): Optional parent to use as prefix for all error locations. For example, parent "element" will result in "element -> a -> b". show_config (bool): Whether to print the whole config with the error. ATTRIBUTES: config (Union[Config, Dict[str, Dict[str, Any]], str]): The config. errors (Iterable[Dict[str, Any]]): The errors. error_types (Set[str]): All "type" values defined in the errors, if available. This is most relevant for the pydantic errors that define types like "type_error.integer". This attribute makes it easy to check if a config validation error includes errors of a certain type, e.g. to log additional information or custom help messages. title (str): The title. desc (str): The description. parent (str): The parent. show_config (bool): Whether to show the config. text (str): The formatted error text. """ self.config = config self.errors = errors self.title = title self.desc = desc self.parent = parent self.show_config = show_config self.error_types = set() for error in self.errors: err_type = error.get("type") if err_type: self.error_types.add(err_type) self.text = self._format() ValueError.__init__(self, self.text) @classmethod def from_error( cls, err: "ConfigValidationError", title: Optional[str] = None, desc: Optional[str] = None, parent: Optional[str] = None, show_config: Optional[bool] = None, ) -> "ConfigValidationError": """Create a new ConfigValidationError based on an existing error, e.g. to re-raise it with different settings. If no overrides are provided, the values from the original error are used. err (ConfigValidationError): The original error. title (str): Overwrite error title. desc (str): Overwrite error description. parent (str): Overwrite error parent. show_config (bool): Overwrite whether to show config. RETURNS (ConfigValidationError): The new error. """ return cls( config=err.config, errors=err.errors, title=title if title is not None else err.title, desc=desc if desc is not None else err.desc, parent=parent if parent is not None else err.parent, show_config=show_config if show_config is not None else err.show_config, ) def _format(self) -> str: """Format the error message.""" loc_divider = "->" data = [] for error in self.errors: err_loc = f" {loc_divider} ".join([str(p) for p in error.get("loc", [])]) if self.parent: err_loc = f"{self.parent} {loc_divider} {err_loc}" data.append((err_loc, error.get("msg"))) result = [] if self.title: result.append(self.title) if self.desc: result.append(self.desc) if data: result.append("\n".join([f"{entry[0]}\t{entry[1]}" for entry in data])) if self.config and self.show_config: result.append(f"{self.config}") return "\n\n" + "\n".join(result) def alias_generator(name: str) -> str: """Generate field aliases in promise schema.""" # Underscore fields are not allowed in model, so use alias if name == ARGS_FIELD_ALIAS: return ARGS_FIELD # Auto-alias fields that shadow base model attributes if name in RESERVED_FIELDS: return RESERVED_FIELDS[name] return name def copy_model_field(field: ModelField, type_: Any) -> ModelField: """Copy a model field and assign a new type, e.g. to accept an Any type even though the original value is typed differently. """ return ModelField( name=field.name, type_=type_, class_validators=field.class_validators, model_config=field.model_config, default=field.default, default_factory=field.default_factory, required=field.required, ) class EmptySchema(BaseModel): class Config: extra = "allow" arbitrary_types_allowed = True class _PromiseSchemaConfig: extra = "forbid" arbitrary_types_allowed = True alias_generator = alias_generator @dataclass class Promise: registry: str name: str args: List[str] kwargs: Dict[str, Any] class catalogue_registry(object): @classmethod def create(cls, registry_name: str, entry_points: bool = False) -> None: """Create a new custom registry.""" if hasattr(cls, registry_name): raise ValueError(f"Registry '{registry_name}' already exists") reg: Decorator = catalogue.registry.create( "catalogue", registry_name, entry_points=entry_points ) setattr(cls, registry_name, reg) @classmethod def has(cls, registry_name: str, func_name: str) -> bool: """Check whether a function is available in a registry.""" if not hasattr(cls, registry_name): return False reg = getattr(cls, registry_name) return func_name in reg @classmethod def get(cls, registry_name: str, func_name: str) -> Callable: """Get a registered function from a given registry.""" if not hasattr(cls, registry_name): raise ValueError(f"Unknown registry: '{registry_name}'") reg = getattr(cls, registry_name) func = reg.get(func_name) if func is None: raise ValueError(f"Could not find '{func_name}' in '{registry_name}'") return func @classmethod def resolve( cls, config: Union[Config, Dict[str, Dict[str, Any]]], *, schema: Type[BaseModel] = EmptySchema, overrides: Dict[str, Any] = {}, validate: bool = True, ) -> Dict[str, Any]: resolved, _ = cls._make( config, schema=schema, overrides=overrides, validate=validate, resolve=True ) return resolved @classmethod def fill( cls, config: Union[Config, Dict[str, Dict[str, Any]]], *, schema: Type[BaseModel] = EmptySchema, overrides: Dict[str, Any] = {}, validate: bool = True, ): _, filled = cls._make( config, schema=schema, overrides=overrides, validate=validate, resolve=False ) return filled @classmethod def _make( cls, config: Union[Config, Dict[str, Dict[str, Any]]], *, schema: Type[BaseModel] = EmptySchema, overrides: Dict[str, Any] = {}, resolve: bool = True, validate: bool = True, ) -> Tuple[Dict[str, Any], Config]: """Unpack a config dictionary and create two versions of the config: a resolved version with objects from the registry created recursively, and a filled version with all references to registry functions left intact, but filled with all values and defaults based on the type annotations. If validate=True, the config will be validated against the type annotations of the registered functions referenced in the config (if available) and/or the schema (if available). """ # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}} # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0} if cls.is_promise(config): err_msg = "The top-level config object can't be a reference to a registered function." raise ConfigValidationError(config=config, errors=[{"msg": err_msg}]) # If a Config was loaded with interpolate=False, we assume it needs to # be interpolated first, otherwise we take it at face value is_interpolated = not isinstance(config, Config) or config.is_interpolated section_order = config.section_order if isinstance(config, Config) else None orig_config = config if not is_interpolated: config = Config(orig_config).interpolate() filled, _, resolved = cls._fill( config, schema, validate=validate, overrides=overrides, resolve=resolve ) filled = Config(filled, section_order=section_order) # Check that overrides didn't include invalid properties not in config if validate: cls._validate_overrides(filled, overrides) # Merge the original config back to preserve variables if we started # with a config that wasn't interpolated. Here, we prefer variables to # allow auto-filling a non-interpolated config without destroying # variable references. if not is_interpolated: filled = filled.merge( Config(orig_config, is_interpolated=False), remove_extra=True ) return dict(resolved), filled @classmethod def _fill( cls, config: Union[Config, Dict[str, Dict[str, Any]]], schema: Type[BaseModel] = EmptySchema, *, validate: bool = True, resolve: bool = True, parent: str = "", overrides: Dict[str, Dict[str, Any]] = {}, ) -> Tuple[ Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any] ]: """Build three representations of the config: 1. All promises are preserved (just like config user would provide). 2. Promises are replaced by their return values. This is the validation copy and will be parsed by pydantic. It lets us include hacks to work around problems (e.g. handling of generators). 3. Final copy with promises replaced by their return values. """ filled: Dict[str, Any] = {} validation: Dict[str, Any] = {} final: Dict[str, Any] = {} for key, value in config.items(): # If the field name is reserved, we use its alias for validation v_key = RESERVED_FIELDS.get(key, key) key_parent = f"{parent}.{key}".strip(".") if key_parent in overrides: value = overrides[key_parent] config[key] = value if cls.is_promise(value): if key in schema.__fields__ and not resolve: # If we're not resolving the config, make sure that the field # expecting the promise is typed Any so it doesn't fail # validation if it doesn't receive the function return value field = schema.__fields__[key] schema.__fields__[key] = copy_model_field(field, Any) promise_schema = cls.make_promise_schema(value, resolve=resolve) filled[key], validation[v_key], final[key] = cls._fill( value, promise_schema, validate=validate, resolve=resolve, parent=key_parent, overrides=overrides, ) reg_name, func_name = cls.get_constructor(final[key]) args, kwargs = cls.parse_args(final[key]) if resolve: # Call the function and populate the field value. We can't # just create an instance of the type here, since this # wouldn't work for generics / more complex custom types getter = cls.get(reg_name, func_name) # We don't want to try/except this and raise our own error # here, because we want the traceback if the function fails. getter_result = getter(*args, **kwargs) else: # We're not resolving and calling the function, so replace # the getter_result with a Promise class getter_result = Promise( registry=reg_name, name=func_name, args=args, kwargs=kwargs ) validation[v_key] = getter_result final[key] = getter_result if isinstance(validation[v_key], GeneratorType): # If value is a generator we can't validate type without # consuming it (which doesn't work if it's infinite – see # schedule for examples). So we skip it. validation[v_key] = [] elif hasattr(value, "items"): field_type = EmptySchema if key in schema.__fields__: field = schema.__fields__[key] field_type = field.type_ if not isinstance(field.type_, ModelMetaclass): # If we don't have a pydantic schema and just a type field_type = EmptySchema filled[key], validation[v_key], final[key] = cls._fill( value, field_type, validate=validate, resolve=resolve, parent=key_parent, overrides=overrides, ) if key == ARGS_FIELD and isinstance(validation[v_key], dict): # If the value of variable positional args is a dict (e.g. # created via config blocks), only use its values validation[v_key] = list(validation[v_key].values()) final[key] = list(final[key].values()) else: filled[key] = value # Prevent pydantic from consuming generator if part of a union validation[v_key] = ( value if not isinstance(value, GeneratorType) else [] ) final[key] = value # Now that we've filled in all of the promises, update with defaults # from schema, and validate if validation is enabled exclude = [] if validate: try: result = schema.parse_obj(validation) except ValidationError as e: raise ConfigValidationError( config=config, errors=e.errors(), parent=parent ) from None else: # Same as parse_obj, but without validation result = schema.construct(**validation) # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything if schema.Config.extra in (Extra.forbid, Extra.ignore): fields = schema.__fields__.keys() exclude = [k for k in result.__fields_set__ if k not in fields] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) validation.update(result.dict(exclude=exclude_validation)) filled, final = cls._update_from_parsed(validation, filled, final) if exclude: filled = {k: v for k, v in filled.items() if k not in exclude} validation = {k: v for k, v in validation.items() if k not in exclude} final = {k: v for k, v in final.items() if k not in exclude} return filled, validation, final @classmethod def _update_from_parsed( cls, validation: Dict[str, Any], filled: Dict[str, Any], final: Dict[str, Any] ): """Update the final result with the parsed config like converted values recursively. """ for key, value in validation.items(): if key in RESERVED_FIELDS.values(): continue # skip aliases for reserved fields if key not in filled: filled[key] = value if key not in final: final[key] = value if isinstance(value, dict): filled[key], final[key] = cls._update_from_parsed( value, filled[key], final[key] ) # Update final config with parsed value if they're not equal (in # value and in type) but not if it's a generator because we had to # replace that to validate it correctly elif key == ARGS_FIELD: continue # don't substitute if list of positional args elif isinstance(value, numpy.ndarray): # check numpy first, just in case final[key] = value elif ( value != final[key] or not isinstance(type(value), type(final[key])) ) and not isinstance(final[key], GeneratorType): final[key] = value return filled, final @classmethod def _validate_overrides(cls, filled: Config, overrides: Dict[str, Any]): """Validate overrides against a filled config to make sure there are no references to properties that don't exist and weren't used.""" error_msg = "Invalid override: config value doesn't exist" errors = [] for override_key in overrides.keys(): if not cls._is_in_config(override_key, filled): errors.append({"msg": error_msg, "loc": [override_key]}) if errors: raise ConfigValidationError(config=filled, errors=errors) @classmethod def _is_in_config(cls, prop: str, config: Union[Dict[str, Any], Config]): """Check whether a nested config property like "section.subsection.key" is in a given config.""" tree = prop.split(".") obj = dict(config) while tree: key = tree.pop(0) if isinstance(obj, dict) and key in obj: obj = obj[key] else: return False return True @classmethod def is_promise(cls, obj: Any) -> bool: """Check whether an object is a "promise", i.e. contains a reference to a registered function (via a key starting with `"@"`. """ if not hasattr(obj, "keys"): return False id_keys = [k for k in obj.keys() if k.startswith("@")] if len(id_keys): return True return False @classmethod def get_constructor(cls, obj: Dict[str, Any]) -> Tuple[str, str]: id_keys = [k for k in obj.keys() if k.startswith("@")] if len(id_keys) != 1: err_msg = f"A block can only contain one function registry reference. Got: {id_keys}" raise ConfigValidationError(config=obj, errors=[{"msg": err_msg}]) else: key = id_keys[0] value = obj[key] return (key[1:], value) @classmethod def parse_args(cls, obj: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]: args = [] kwargs = {} for key, value in obj.items(): if not key.startswith("@"): if key == ARGS_FIELD: args = value elif key in RESERVED_FIELDS.values(): continue else: kwargs[key] = value return args, kwargs @classmethod def make_promise_schema( cls, obj: Dict[str, Any], *, resolve: bool = True ) -> Type[BaseModel]: """Create a schema for a promise dict (referencing a registry function) by inspecting the function signature. """ reg_name, func_name = cls.get_constructor(obj) if not resolve and not cls.has(reg_name, func_name): return EmptySchema func = cls.get(reg_name, func_name) # Read the argument annotations and defaults from the function signature id_keys = [k for k in obj.keys() if k.startswith("@")] sig_args: Dict[str, Any] = {id_keys[0]: (str, ...)} for param in inspect.signature(func).parameters.values(): # If no annotation is specified assume it's anything annotation = param.annotation if param.annotation != param.empty else Any # If no default value is specified assume that it's required default = param.default if param.default != param.empty else ... # Handle spread arguments and use their annotation as Sequence[whatever] if param.kind == param.VAR_POSITIONAL: spread_annot = Sequence[annotation] # type: ignore sig_args[ARGS_FIELD_ALIAS] = (spread_annot, default) else: name = RESERVED_FIELDS.get(param.name, param.name) sig_args[name] = (annotation, default) sig_args["__config__"] = _PromiseSchemaConfig return create_model("ArgModel", **sig_args) __all__ = ["Config", "catalogue_registry", "ConfigValidationError"] catalogue-2.1.0/catalogue/config/util.py000066400000000000000000000030161426104053400202150ustar00rootroot00000000000000import contextlib import functools import shutil import sys import tempfile from pathlib import Path from typing import TypeVar, Callable, Any, Iterator if sys.version_info < (3, 8): # Ignoring type for mypy to avoid "Incompatible import" error (https://github.com/python/mypy/issues/4427). from typing_extensions import Protocol # type: ignore else: from typing import Protocol _DIn = TypeVar("_DIn") class Decorator(Protocol): """Protocol to mark a function as returning its child with identical signature.""" def __call__(self, name: str) -> Callable[[_DIn], _DIn]: ... # This is how functools.partials seems to do it, too, to retain the return type PartialT = TypeVar("PartialT") def partial( func: Callable[..., PartialT], *args: Any, **kwargs: Any ) -> Callable[..., PartialT]: """Wrapper around functools.partial that retains docstrings and can include other workarounds if needed. """ partial_func = functools.partial(func, *args, **kwargs) partial_func.__doc__ = func.__doc__ return partial_func class Generator(Iterator): """Custom generator type. Used to annotate function arguments that accept generators so they can be validated by pydantic (which doesn't support iterators/iterables otherwise). """ @classmethod def __get_validators__(cls): yield cls.validate @classmethod def validate(cls, v): if not hasattr(v, "__iter__") and not hasattr(v, "__next__"): raise TypeError("not a valid iterator") return v catalogue-2.1.0/catalogue/registry.py000066400000000000000000000210061426104053400176420ustar00rootroot00000000000000from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union import inspect try: # Python 3.8 import importlib.metadata as importlib_metadata except ImportError: from . import _importlib_metadata as importlib_metadata # type: ignore # Only ever call this once for performance reasons AVAILABLE_ENTRY_POINTS = importlib_metadata.entry_points() # type: ignore # This is where functions will be registered REGISTRY: Dict[Tuple[str, ...], Any] = {} InFunc = TypeVar("InFunc") def create(*namespace: str, entry_points: bool = False) -> "Registry": """Create a new registry. *namespace (str): The namespace, e.g. "spacy" or "spacy", "architectures". entry_points (bool): Accept registered functions from entry points. RETURNS (Registry): The Registry object. """ if check_exists(*namespace): raise RegistryError(f"Namespace already exists: {namespace}") return Registry(namespace, entry_points=entry_points) class Registry(object): def __init__(self, namespace: Sequence[str], entry_points: bool = False) -> None: """Initialize a new registry. namespace (Sequence[str]): The namespace. entry_points (bool): Whether to also check for entry points. """ self.namespace = namespace self.entry_point_namespace = "_".join(namespace) self.entry_points = entry_points def __contains__(self, name: str) -> bool: """Check whether a name is in the registry. name (str): The name to check. RETURNS (bool): Whether the name is in the registry. """ namespace = tuple(list(self.namespace) + [name]) has_entry_point = self.entry_points and self.get_entry_point(name) return has_entry_point or namespace in REGISTRY def __call__( self, name: str, func: Optional[Any] = None ) -> Callable[[InFunc], InFunc]: """Register a function for a given namespace. Same as Registry.register. name (str): The name to register under the namespace. func (Any): Optional function to register (if not used as decorator). RETURNS (Callable): The decorator. """ return self.register(name, func=func) def register( self, name: str, *, func: Optional[Any] = None ) -> Callable[[InFunc], InFunc]: """Register a function for a given namespace. name (str): The name to register under the namespace. func (Any): Optional function to register (if not used as decorator). RETURNS (Callable): The decorator. """ def do_registration(func): _set(list(self.namespace) + [name], func) return func if func is not None: return do_registration(func) return do_registration def get(self, name: str) -> Any: """Get the registered function for a given name. name (str): The name. RETURNS (Any): The registered function. """ if self.entry_points: from_entry_point = self.get_entry_point(name) if from_entry_point: return from_entry_point namespace = list(self.namespace) + [name] if not check_exists(*namespace): current_namespace = " -> ".join(self.namespace) available = ", ".join(sorted(self.get_all().keys())) or "none" raise RegistryError( f"Cant't find '{name}' in registry {current_namespace}. Available names: {available}" ) return _get(namespace) def get_all(self) -> Dict[str, Any]: """Get a all functions for a given namespace. namespace (Tuple[str]): The namespace to get. RETURNS (Dict[str, Any]): The functions, keyed by name. """ global REGISTRY result = {} if self.entry_points: result.update(self.get_entry_points()) for keys, value in REGISTRY.copy().items(): if len(self.namespace) == len(keys) - 1 and all( self.namespace[i] == keys[i] for i in range(len(self.namespace)) ): result[keys[-1]] = value return result def get_entry_points(self) -> Dict[str, Any]: """Get registered entry points from other packages for this namespace. RETURNS (Dict[str, Any]): Entry points, keyed by name. """ result = {} for entry_point in AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, []): result[entry_point.name] = entry_point.load() return result def get_entry_point(self, name: str, default: Optional[Any] = None) -> Any: """Check if registered entry point is available for a given name in the namespace and load it. Otherwise, return the default value. name (str): Name of entry point to load. default (Any): The default value to return. RETURNS (Any): The loaded entry point or the default value. """ for entry_point in AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, []): if entry_point.name == name: return entry_point.load() return default def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]: """Find the information about a registered function, including the module and path to the file it's defined in, the line number and the docstring, if available. name (str): Name of the registered function. RETURNS (Dict[str, Optional[Union[str, int]]]): The function info. """ func = self.get(name) module = inspect.getmodule(func) # These calls will fail for Cython modules so we need to work around them line_no: Optional[int] = None file_name: Optional[str] = None try: _, line_no = inspect.getsourcelines(func) file_name = inspect.getfile(func) except (TypeError, ValueError): pass docstring = inspect.getdoc(func) return { "module": module.__name__ if module else None, "file": file_name, "line_no": line_no, "docstring": inspect.cleandoc(docstring) if docstring else None, } def check_exists(*namespace: str) -> bool: """Check if a namespace exists. *namespace (str): The namespace. RETURNS (bool): Whether the namespace exists. """ return namespace in REGISTRY def _get(namespace: Sequence[str]) -> Any: """Get the value for a given namespace. namespace (Sequence[str]): The namespace. RETURNS (Any): The value for the namespace. """ global REGISTRY if not all(isinstance(name, str) for name in namespace): raise ValueError( f"Invalid namespace. Expected tuple of strings, but got: {namespace}" ) namespace = tuple(namespace) if namespace not in REGISTRY: raise RegistryError(f"Can't get namespace {namespace} (not in registry)") return REGISTRY[namespace] def _get_all(namespace: Sequence[str]) -> Dict[Tuple[str, ...], Any]: """Get all matches for a given namespace, e.g. ("a", "b", "c") and ("a", "b") for namespace ("a", "b"). namespace (Sequence[str]): The namespace. RETURNS (Dict[Tuple[str], Any]): All entries for the namespace, keyed by their full namespaces. """ global REGISTRY result = {} for keys, value in REGISTRY.copy().items(): if len(namespace) <= len(keys) and all( namespace[i] == keys[i] for i in range(len(namespace)) ): result[keys] = value return result def _set(namespace: Sequence[str], func: Any) -> None: """Set a value for a given namespace. namespace (Sequence[str]): The namespace. func (Callable): The value to set. """ global REGISTRY REGISTRY[tuple(namespace)] = func def _remove(namespace: Sequence[str]) -> Any: """Remove a value for a given namespace. namespace (Sequence[str]): The namespace. RETURNS (Any): The removed value. """ global REGISTRY namespace = tuple(namespace) if namespace not in REGISTRY: raise RegistryError(f"Can't get namespace {namespace} (not in registry)") removed = REGISTRY[namespace] del REGISTRY[namespace] return removed def _empty_registry() -> None: """ Empties REGISTRY completely. """ global REGISTRY REGISTRY = {} class RegistryError(ValueError): pass __all__ = [ "AVAILABLE_ENTRY_POINTS", "REGISTRY", "create", "Registry", "check_exists", "_get", "_get_all", "_set", "_remove", "_empty_registry", "RegistryError", "importlib_metadata" ] catalogue-2.1.0/catalogue/tests/000077500000000000000000000000001426104053400165635ustar00rootroot00000000000000catalogue-2.1.0/catalogue/tests/__init__.py000066400000000000000000000000001426104053400206620ustar00rootroot00000000000000catalogue-2.1.0/catalogue/tests/conftest.py000066400000000000000000000007501426104053400207640ustar00rootroot00000000000000import pytest def pytest_addoption(parser): parser.addoption("--slow", action="store_true", help="include slow tests") @pytest.fixture() def pathy_fixture(): pytest.importorskip("pathy") import tempfile import shutil from pathy import use_fs, Pathy temp_folder = tempfile.mkdtemp(prefix="thinc-pathy") use_fs(temp_folder) root = Pathy("gs://test-bucket") root.mkdir(exist_ok=True) yield root use_fs(False) shutil.rmtree(temp_folder) catalogue-2.1.0/catalogue/tests/test_catalogue.py000066400000000000000000000121701426104053400221410ustar00rootroot00000000000000from typing import Dict, Tuple, Any import pytest import sys from pathlib import Path import catalogue @pytest.fixture(autouse=True) def cleanup(): # Don't delete entries needed for config tests. for key in set(catalogue.REGISTRY.keys()) - set(filter_registry("config").keys()): catalogue.REGISTRY.pop(key) yield def filter_registry(keep: str) -> Dict[Tuple[str, ...], Any]: """ Filters registry objects for tests. test_mode (str): One of ("catalogue", "config"). Only entries in the registry belonging to the corresponding tests will be returned. RETURNS (Dict[Tuple[str], Any]): entries in registry without those added for config tests. """ assert keep in ("catalogue", "config") return { key: val for key, val in catalogue.REGISTRY.items() if ("config_tests" in key) is (keep == "config") } def test_get_set(): catalogue._set(("a", "b", "c"), "test") assert len(filter_registry("catalogue")) == 1 assert ("a", "b", "c") in catalogue.REGISTRY assert catalogue.check_exists("a", "b", "c") assert catalogue.REGISTRY[("a", "b", "c")] == "test" assert catalogue._get(("a", "b", "c")) == "test" with pytest.raises(catalogue.RegistryError): catalogue._get(("a", "b", "d")) with pytest.raises(catalogue.RegistryError): catalogue._get(("a", "b", "c", "d")) catalogue._set(("x", "y", "z1"), "test1") catalogue._set(("x", "y", "z2"), "test2") assert catalogue._remove(("a", "b", "c")) == "test" catalogue._set(("x", "y2"), "test3") with pytest.raises(catalogue.RegistryError): catalogue._remove(("x", "y")) assert catalogue._remove(("x", "y", "z2")) == "test2" def test_registry_get_set(): test_registry = catalogue.create("test") with pytest.raises(catalogue.RegistryError): test_registry.get("foo") test_registry.register("foo", func=lambda x: x) assert "foo" in test_registry def test_registry_call(): test_registry = catalogue.create("test") test_registry("foo", func=lambda x: x) assert "foo" in test_registry def test_get_all(): catalogue._set(("a", "b", "c"), "test") catalogue._set(("a", "b", "d"), "test") catalogue._set(("a", "b"), "test") catalogue._set(("b", "a"), "test") all_items = catalogue._get_all(("a", "b")) assert len(all_items) == 3 assert ("a", "b", "c") in all_items assert ("a", "b", "d") in all_items assert ("a", "b") in all_items all_items = catalogue._get_all(("a", "b", "c")) assert len(all_items) == 1 assert ("a", "b", "c") in all_items assert len(catalogue._get_all(("a", "b", "c", "d"))) == 0 def test_create_single_namespace(): assert filter_registry("catalogue") == {} test_registry = catalogue.create("test") @test_registry.register("a") def a(): pass def b(): pass test_registry.register("b", func=b) items = test_registry.get_all() assert len(items) == 2 assert items["a"] == a assert items["b"] == b assert catalogue.check_exists("test", "a") assert catalogue.check_exists("test", "b") assert catalogue._get(("test", "a")) == a assert catalogue._get(("test", "b")) == b with pytest.raises(TypeError): # The decorator only accepts one argument @test_registry.register("x", "y") def x(): pass def test_create_multi_namespace(): test_registry = catalogue.create("x", "y") @test_registry.register("z") def z(): pass items = test_registry.get_all() assert len(items) == 1 assert items["z"] == z assert catalogue.check_exists("x", "y", "z") assert catalogue._get(("x", "y", "z")) == z @pytest.mark.skipif(sys.version_info >= (3, 10), reason="Test is not yet updated for 3.10 importlib_metadata API") def test_entry_points(): # Create a new EntryPoint object by pretending we have a setup.cfg and # use one of catalogue's util functions as the advertised function ep_string = "[options.entry_points]test_foo\n bar = catalogue.registry:check_exists" ep = catalogue.importlib_metadata.EntryPoint._from_text(ep_string) catalogue.AVAILABLE_ENTRY_POINTS["test_foo"] = ep assert filter_registry("catalogue") == {} test_registry = catalogue.create("test", "foo", entry_points=True) entry_points = test_registry.get_entry_points() assert "bar" in entry_points assert entry_points["bar"] == catalogue.check_exists assert test_registry.get_entry_point("bar") == catalogue.check_exists assert filter_registry("catalogue") == {} assert test_registry.get("bar") == catalogue.check_exists assert test_registry.get_all() == {"bar": catalogue.check_exists} assert "bar" in test_registry def test_registry_find(): test_registry = catalogue.create("test_registry_find") name = "a" @test_registry.register(name) def a(): """This is a registered function.""" pass info = test_registry.find(name) assert info["module"] == "catalogue.tests.test_catalogue" assert info["file"] == str(Path(__file__)) assert info["docstring"] == "This is a registered function." assert info["line_no"] catalogue-2.1.0/catalogue/tests/test_config.py000066400000000000000000001453651426104053400214570ustar00rootroot00000000000000import inspect import pytest from typing import Dict, Optional, Iterable, Callable, Any, Union from types import GeneratorType import pickle try: import numpy has_numpy = True except ImportError: has_numpy = False from pydantic import BaseModel, StrictFloat, PositiveInt, constr from pydantic.types import StrictBool from catalogue import ConfigValidationError, Config from catalogue.config.util import Generator, partial from catalogue.tests.util import Cat, my_registry, make_tempdir EXAMPLE_CONFIG = """ [optimizer] @optimizers = "Adam.v1" beta1 = 0.9 beta2 = 0.999 use_averages = true [optimizer.learn_rate] @schedules = "warmup_linear.v1" initial_rate = 0.1 warmup_steps = 10000 total_steps = 100000 [pipeline] [pipeline.classifier] name = "classifier" factory = "classifier" [pipeline.classifier.model] @layers = "ClassifierModel.v1" hidden_depth = 1 hidden_width = 64 token_vector_width = 128 [pipeline.classifier.model.embedding] @layers = "Embedding.v1" width = ${pipeline.classifier.model:token_vector_width} """ OPTIMIZER_CFG = """ [optimizer] @optimizers = "Adam.v1" beta1 = 0.9 beta2 = 0.999 use_averages = true [optimizer.learn_rate] @schedules = "warmup_linear.v1" initial_rate = 0.1 warmup_steps = 10000 total_steps = 100000 """ class HelloIntsSchema(BaseModel): hello: int world: int class Config: extra = "forbid" class DefaultsSchema(BaseModel): required: int optional: str = "default value" class Config: extra = "forbid" class ComplexSchema(BaseModel): outer_req: int outer_opt: str = "default value" level2_req: HelloIntsSchema level2_opt: DefaultsSchema = DefaultsSchema(required=1) good_catsie = {"@cats": "catsie.v1", "evil": False, "cute": True} ok_catsie = {"@cats": "catsie.v1", "evil": False, "cute": False} bad_catsie = {"@cats": "catsie.v1", "evil": True, "cute": True} worst_catsie = {"@cats": "catsie.v1", "evil": True, "cute": False} def test_validate_simple_config(): simple_config = {"hello": 1, "world": 2} f, _, v = my_registry._fill(simple_config, HelloIntsSchema) assert f == simple_config assert v == simple_config def test_invalidate_simple_config(): invalid_config = {"hello": 1, "world": "hi!"} with pytest.raises(ConfigValidationError) as exc_info: my_registry._fill(invalid_config, HelloIntsSchema) error = exc_info.value assert len(error.errors) == 1 assert "type_error.integer" in error.error_types def test_invalidate_extra_args(): invalid_config = {"hello": 1, "world": 2, "extra": 3} with pytest.raises(ConfigValidationError): my_registry._fill(invalid_config, HelloIntsSchema) def test_fill_defaults_simple_config(): valid_config = {"required": 1} filled, _, v = my_registry._fill(valid_config, DefaultsSchema) assert filled["required"] == 1 assert filled["optional"] == "default value" invalid_config = {"optional": "some value"} with pytest.raises(ConfigValidationError): my_registry._fill(invalid_config, DefaultsSchema) def test_fill_recursive_config(): valid_config = {"outer_req": 1, "level2_req": {"hello": 4, "world": 7}} filled, _, validation = my_registry._fill(valid_config, ComplexSchema) assert filled["outer_req"] == 1 assert filled["outer_opt"] == "default value" assert filled["level2_req"]["hello"] == 4 assert filled["level2_req"]["world"] == 7 assert filled["level2_opt"]["required"] == 1 assert filled["level2_opt"]["optional"] == "default value" def test_is_promise(): assert my_registry.is_promise(good_catsie) assert not my_registry.is_promise({"hello": "world"}) assert not my_registry.is_promise(1) invalid = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"} assert my_registry.is_promise(invalid) def test_get_constructor(): assert my_registry.get_constructor(good_catsie) == ("cats", "catsie.v1") def test_parse_args(): args, kwargs = my_registry.parse_args(bad_catsie) assert args == [] assert kwargs == {"evil": True, "cute": True} def test_make_promise_schema(): schema = my_registry.make_promise_schema(good_catsie) assert "evil" in schema.__fields__ assert "cute" in schema.__fields__ def test_validate_promise(): config = {"required": 1, "optional": good_catsie} filled, _, validated = my_registry._fill(config, DefaultsSchema) assert filled == config assert validated == {"required": 1, "optional": "meow"} def test_fill_validate_promise(): config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}} filled, _, validated = my_registry._fill(config, DefaultsSchema) assert filled["optional"]["cute"] is True def test_fill_invalidate_promise(): config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}} with pytest.raises(ConfigValidationError): my_registry._fill(config, HelloIntsSchema) config["optional"]["whiskers"] = True with pytest.raises(ConfigValidationError): my_registry._fill(config, DefaultsSchema) def test_create_registry(): with pytest.raises(ValueError): my_registry.create("cats") my_registry.create("dogs") assert hasattr(my_registry, "dogs") assert len(my_registry.dogs.get_all()) == 0 my_registry.dogs.register("good_boy.v1", func=lambda x: x) assert len(my_registry.dogs.get_all()) == 1 with pytest.raises(ValueError): my_registry.create("dogs") def test_registry_methods(): with pytest.raises(ValueError): my_registry.get("dfkoofkds", "catsie.v1") my_registry.cats.register("catsie.v123")(None) with pytest.raises(ValueError): my_registry.get("cats", "catsie.v123") def test_resolve_no_schema(): config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} result = my_registry.resolve({"cfg": config})["cfg"] assert result["one"] == 1 assert result["two"] == {"three": "scratch!"} with pytest.raises(ConfigValidationError): config = {"two": {"three": {"@cats": "catsie.v1", "evil": "true"}}} my_registry.resolve(config) def test_resolve_schema(): class TestBaseSubSchema(BaseModel): three: str class TestBaseSchema(BaseModel): one: PositiveInt two: TestBaseSubSchema class Config: extra = "forbid" class TestSchema(BaseModel): cfg: TestBaseSchema config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} my_registry.resolve({"cfg": config}, schema=TestSchema) config = {"one": -1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}} with pytest.raises(ConfigValidationError): # "one" is not a positive int my_registry.resolve({"cfg": config}, schema=TestSchema) config = {"one": 1, "two": {"four": {"@cats": "catsie.v1", "evil": True}}} with pytest.raises(ConfigValidationError): # "three" is required in subschema my_registry.resolve({"cfg": config}, schema=TestSchema) def test_resolve_schema_coerced(): class TestBaseSchema(BaseModel): test1: str test2: bool test3: float class TestSchema(BaseModel): cfg: TestBaseSchema config = {"test1": 123, "test2": 1, "test3": 5} filled = my_registry.fill({"cfg": config}, schema=TestSchema) result = my_registry.resolve({"cfg": config}, schema=TestSchema) assert result["cfg"] == {"test1": "123", "test2": True, "test3": 5.0} # This only affects the resolved config, not the filled config assert filled["cfg"] == config def test_read_config(): byte_string = EXAMPLE_CONFIG.encode("utf8") cfg = Config().from_bytes(byte_string) assert cfg["optimizer"]["beta1"] == 0.9 assert cfg["optimizer"]["learn_rate"]["initial_rate"] == 0.1 assert cfg["pipeline"]["classifier"]["factory"] == "classifier" assert cfg["pipeline"]["classifier"]["model"]["embedding"]["width"] == 128 def test_optimizer_config(): cfg = Config().from_str(OPTIMIZER_CFG) optimizer = my_registry.resolve(cfg, validate=True)["optimizer"] assert optimizer.beta1 == 0.9 def test_config_to_str(): cfg = Config().from_str(OPTIMIZER_CFG) assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() cfg = Config({"optimizer": {"foo": "bar"}}).from_str(OPTIMIZER_CFG) assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() def test_config_to_str_creates_intermediate_blocks(): cfg = Config({"optimizer": {"foo": {"bar": 1}}}) assert ( cfg.to_str().strip() == """ [optimizer] [optimizer.foo] bar = 1 """.strip() ) def test_config_roundtrip_bytes(): cfg = Config().from_str(OPTIMIZER_CFG) cfg_bytes = cfg.to_bytes() new_cfg = Config().from_bytes(cfg_bytes) assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() def test_config_roundtrip_disk(): cfg = Config().from_str(OPTIMIZER_CFG) with make_tempdir() as path: cfg_path = path / "config.cfg" cfg.to_disk(cfg_path) new_cfg = Config().from_disk(cfg_path) assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture): cfg = Config().from_str(OPTIMIZER_CFG) cfg_path = pathy_fixture / "config.cfg" cfg.to_disk(cfg_path) new_cfg = Config().from_disk(cfg_path) assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() def test_config_to_str_invalid_defaults(): """Test that an error is raised if a config contains top-level keys without a section that would otherwise be interpreted as [DEFAULT] (which causes the values to be included in *all* other sections). """ cfg = {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}} with pytest.raises(ConfigValidationError): Config(cfg).to_str() config_str = "[DEFAULT]\none = 1" with pytest.raises(ConfigValidationError): Config().from_str(config_str) def test_validation_custom_types(): def complex_args( rate: StrictFloat, steps: PositiveInt = 10, # type: ignore log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR", ): return None my_registry.create("complex") my_registry.complex("complex.v1")(complex_args) cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"} my_registry.resolve({"config": cfg}) cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": -1, "log_level": "INFO"} with pytest.raises(ConfigValidationError): # steps is not a positive int my_registry.resolve({"config": cfg}) cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "none"} with pytest.raises(ConfigValidationError): # log_level is not a string matching the regex my_registry.resolve({"config": cfg}) cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"} with pytest.raises(ConfigValidationError): # top-level object is promise my_registry.resolve(cfg) with pytest.raises(ConfigValidationError): # top-level object is promise my_registry.fill(cfg) cfg = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"} with pytest.raises(ConfigValidationError): # two constructors my_registry.resolve({"config": cfg}) def test_validation_no_validate(): config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": "false"}}} result = my_registry.resolve({"cfg": config}, validate=False) filled = my_registry.fill({"cfg": config}, validate=False) assert result["cfg"]["one"] == 1 assert result["cfg"]["two"] == {"three": "scratch!"} assert filled["cfg"]["two"]["three"]["evil"] == "false" assert filled["cfg"]["two"]["three"]["cute"] is True def test_validation_fill_defaults(): config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}} result = my_registry.fill(config, validate=False) assert len(result["cfg"]["two"]) == 3 with pytest.raises(ConfigValidationError): # Required arg "evil" is not defined my_registry.fill(config) config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v2", "evil": False}}} # Fill in with new defaults result = my_registry.fill(config) assert len(result["cfg"]["two"]) == 4 assert result["cfg"]["two"]["evil"] is False assert result["cfg"]["two"]["cute"] is True assert result["cfg"]["two"]["cute_level"] == 1 def test_make_config_positional_args(): @my_registry.cats("catsie.v567") def catsie_567(*args: Optional[str], foo: str = "bar"): assert args[0] == "^_^" assert args[1] == "^(*.*)^" assert foo == "baz" return args[0] args = ["^_^", "^(*.*)^"] cfg = {"config": {"@cats": "catsie.v567", "foo": "baz", "*": args}} assert my_registry.resolve(cfg)["config"] == "^_^" def test_make_config_positional_args_complex(): @my_registry.cats("catsie.v890") def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]): assert args[0] == 123 return args[0] cfg = {"config": {"@cats": "catsie.v890", "*": [123, True, 1, False]}} assert my_registry.resolve(cfg)["config"] == 123 cfg = {"config": {"@cats": "catsie.v890", "*": [123, "True"]}} with pytest.raises(ConfigValidationError): # "True" is not a valid boolean or positive int my_registry.resolve(cfg) def test_positional_args_to_from_string(): cfg = """[a]\nb = 1\n* = ["foo","bar"]""" assert Config().from_str(cfg).to_str() == cfg cfg = """[a]\nb = 1\n\n[a.*.bar]\ntest = 2\n\n[a.*.foo]\ntest = 1""" assert Config().from_str(cfg).to_str() == cfg @my_registry.cats("catsie.v666") def catsie_666(*args, meow=False): return args cfg = """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]""" filled = my_registry.fill(Config().from_str(cfg)).to_str() assert filled == """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]\nmeow = false""" resolved = my_registry.resolve(Config().from_str(cfg)) assert resolved == {"a": ("foo", "bar")} cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\nx = 1""" filled = my_registry.fill(Config().from_str(cfg)).to_str() assert filled == """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\nx = 1""" resolved = my_registry.resolve(Config().from_str(cfg)) assert resolved == {"a": ({"x": 1},)} @my_registry.cats("catsie.v777") def catsie_777(y: int = 1): return "meow" * y cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777\"""" filled = my_registry.fill(Config().from_str(cfg)).to_str() expected = """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 1""" assert filled == expected cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 3""" result = my_registry.resolve(Config().from_str(cfg)) assert result == {"a": ("meowmeowmeow",)} def test_validation_generators_iterable(): @my_registry.optimizers("test_optimizer.v1") def test_optimizer_v1(rate: float) -> None: return None @my_registry.schedules("test_schedule.v1") def test_schedule_v1(some_value: float = 1.0) -> Iterable[float]: while True: yield some_value config = {"optimizer": {"@optimizers": "test_optimizer.v1", "rate": 0.1}} my_registry.resolve(config) def test_validation_unset_type_hints(): """Test that unset type hints are handled correctly (and treated as Any).""" @my_registry.optimizers("test_optimizer.v2") def test_optimizer_v2(rate, steps: int = 10) -> None: return None config = {"test": {"@optimizers": "test_optimizer.v2", "rate": 0.1, "steps": 20}} my_registry.resolve(config) def test_validation_bad_function(): @my_registry.optimizers("bad.v1") def bad() -> None: raise ValueError("This is an error in the function") return None @my_registry.optimizers("good.v1") def good() -> None: return None # Bad function config = {"test": {"@optimizers": "bad.v1"}} with pytest.raises(ValueError): my_registry.resolve(config) # Bad function call config = {"test": {"@optimizers": "good.v1", "invalid_arg": 1}} with pytest.raises(ConfigValidationError): my_registry.resolve(config) def test_objects_from_config(): config = { "optimizer": { "@optimizers": "my_cool_optimizer.v1", "beta1": 0.2, "learn_rate": { "@schedules": "my_cool_repetitive_schedule.v1", "base_rate": 0.001, "repeat": 4, }, } } optimizer = my_registry.resolve(config)["optimizer"] assert optimizer.beta1 == 0.2 assert optimizer.learn_rate == [0.001] * 4 @pytest.mark.skipif(not has_numpy, reason="needs numpy") def test_partials_from_config(): """Test that functions registered with partial applications are handled correctly (e.g. initializers).""" name = "uniform_init.v1" cfg = {"test": {"@initializers": name, "lo": -0.2}} func = my_registry.resolve(cfg)["test"] assert hasattr(func, "__call__") # The partial will still have lo as an arg, just with default assert len(inspect.signature(func).parameters) == 3 # Make sure returned partial function has correct value set assert inspect.signature(func).parameters["lo"].default == -0.2 # Actually call the function and verify assert numpy.asarray(func((2, 3))).shape == (2, 3) # Make sure validation still works bad_cfg = {"test": {"@initializers": name, "lo": [0.5]}} with pytest.raises(ConfigValidationError): my_registry.resolve(bad_cfg) bad_cfg = {"test": {"@initializers": name, "lo": -0.2, "other": 10}} with pytest.raises(ConfigValidationError): my_registry.resolve(bad_cfg) def test_partials_from_config_nested(): """Test that partial functions are passed correctly to other registered functions that consume them (e.g. initializers -> layers).""" def test_initializer(a: int, b: int = 1) -> int: return a * b @my_registry.initializers("test_initializer.v1") def configure_test_initializer(b: int = 1) -> Callable[[int], int]: return partial(test_initializer, b=b) @my_registry.layers("test_layer.v1") def test_layer(init: Callable[[int], int], c: int = 1) -> Callable[[int], int]: return lambda x: x + init(c) cfg = { "@layers": "test_layer.v1", "c": 5, "init": {"@initializers": "test_initializer.v1", "b": 10}, } func = my_registry.resolve({"test": cfg})["test"] assert func(1) == 51 assert func(100) == 150 def test_validate_generator(): """Test that generator replacement for validation in config doesn't actually replace the returned value.""" @my_registry.schedules("test_schedule.v2") def test_schedule(): while True: yield 10 cfg = {"@schedules": "test_schedule.v2"} result = my_registry.resolve({"test": cfg})["test"] assert isinstance(result, GeneratorType) @my_registry.optimizers("test_optimizer.v2") def test_optimizer2(rate: Generator) -> Generator: return rate cfg = { "@optimizers": "test_optimizer.v2", "rate": {"@schedules": "test_schedule.v2"}, } result = my_registry.resolve({"test": cfg})["test"] assert isinstance(result, GeneratorType) @my_registry.optimizers("test_optimizer.v3") def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: return schedules["rate"] cfg = { "@optimizers": "test_optimizer.v3", "schedules": {"rate": {"@schedules": "test_schedule.v2"}}, } result = my_registry.resolve({"test": cfg})["test"] assert isinstance(result, GeneratorType) @my_registry.optimizers("test_optimizer.v4") def test_optimizer4(*schedules: Generator) -> Generator: return schedules[0] def test_handle_generic_type(): """Test that validation can handle checks against arbitrary generic types in function argument annotations.""" cfg = {"@cats": "generic_cat.v1", "cat": {"@cats": "int_cat.v1", "value_in": 3}} cat = my_registry.resolve({"test": cfg})["test"] assert isinstance(cat, Cat) assert cat.value_in == 3 assert cat.value_out is None assert cat.name == "generic_cat" @pytest.mark.parametrize( "cfg", [ "[a]\nb = 1\nc = 2\n\n[a.c]\nd = 3", "[a]\nb = 1\n\n[a.c]\nd = 2\n\n[a.c.d]\ne = 3", ], ) def test_handle_error_duplicate_keys(cfg): """This would cause very cryptic error when interpreting config. (TypeError: 'X' object does not support item assignment) """ with pytest.raises(ConfigValidationError): Config().from_str(cfg) @pytest.mark.parametrize( "cfg,is_valid", [("[a]\nb = 1\n\n[a.c]\nd = 3", True), ("[a]\nb = 1\n\n[A.c]\nd = 2", False)], ) def test_cant_expand_undefined_block(cfg, is_valid): """Test that you can't expand a block that hasn't been created yet. This comes up when you typo a name, and if we allow expansion of undefined blocks, it's very hard to create good errors for those typos. """ if is_valid: Config().from_str(cfg) else: with pytest.raises(ConfigValidationError): Config().from_str(cfg) def test_fill_config_overrides(): config = { "cfg": { "one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}}, } } overrides = {"cfg.two.three.evil": False} result = my_registry.fill(config, overrides=overrides, validate=True) assert result["cfg"]["two"]["three"]["evil"] is False # Test that promises can be overwritten as well overrides = {"cfg.two.three": 3} result = my_registry.fill(config, overrides=overrides, validate=True) assert result["cfg"]["two"]["three"] == 3 # Test that value can be overwritten with promises and that the result is # interpreted and filled correctly overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}} result = my_registry.fill(config, overrides=overrides) assert result["cfg"]["two"] is None assert result["cfg"]["one"]["@cats"] == "catsie.v1" assert result["cfg"]["one"]["evil"] is False assert result["cfg"]["one"]["cute"] is True # Overwriting with wrong types should cause validation error with pytest.raises(ConfigValidationError): overrides = {"cfg.two.three.evil": 20} my_registry.fill(config, overrides=overrides, validate=True) # Overwriting with incomplete promises should cause validation error with pytest.raises(ConfigValidationError): overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}} my_registry.fill(config, overrides=overrides) # Overrides that don't match config should raise error with pytest.raises(ConfigValidationError): overrides = {"cfg.two.three.evil": False, "two.four": True} my_registry.fill(config, overrides=overrides, validate=True) with pytest.raises(ConfigValidationError): overrides = {"cfg.five": False} my_registry.fill(config, overrides=overrides, validate=True) def test_resolve_overrides(): config = { "cfg": { "one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}}, } } overrides = {"cfg.two.three.evil": False} result = my_registry.resolve(config, overrides=overrides, validate=True) assert result["cfg"]["two"]["three"] == "meow" # Test that promises can be overwritten as well overrides = {"cfg.two.three": 3} result = my_registry.resolve(config, overrides=overrides, validate=True) assert result["cfg"]["two"]["three"] == 3 # Test that value can be overwritten with promises overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}} result = my_registry.resolve(config, overrides=overrides) assert result["cfg"]["one"] == "meow" assert result["cfg"]["two"] is None # Overwriting with wrong types should cause validation error with pytest.raises(ConfigValidationError): overrides = {"cfg.two.three.evil": 20} my_registry.resolve(config, overrides=overrides, validate=True) # Overwriting with incomplete promises should cause validation error with pytest.raises(ConfigValidationError): overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}} my_registry.resolve(config, overrides=overrides) # Overrides that don't match config should raise error with pytest.raises(ConfigValidationError): overrides = {"cfg.two.three.evil": False, "cfg.two.four": True} my_registry.resolve(config, overrides=overrides, validate=True) with pytest.raises(ConfigValidationError): overrides = {"cfg.five": False} my_registry.resolve(config, overrides=overrides, validate=True) @pytest.mark.parametrize( "prop,expected", [("a.b.c", True), ("a.b", True), ("a", True), ("a.e", True), ("a.b.c.d", False)], ) def test_is_in_config(prop, expected): config = {"a": {"b": {"c": 5, "d": 6}, "e": [1, 2]}} assert my_registry._is_in_config(prop, config) is expected def test_resolve_prefilled_values(): class Language(object): def __init__(self): ... @my_registry.optimizers("prefilled.v1") def prefilled(nlp: Language, value: int = 10): return (nlp, value) # Passing an instance of Language here via the config is bad, since it # won't serialize to a string, but we still test for it config = {"test": {"@optimizers": "prefilled.v1", "nlp": Language(), "value": 50}} resolved = my_registry.resolve(config, validate=True) result = resolved["test"] assert isinstance(result[0], Language) assert result[1] == 50 def test_fill_config_dict_return_type(): """Test that a registered function returning a dict is handled correctly.""" @my_registry.cats.register("catsie_with_dict.v1") def catsie_with_dict(evil: StrictBool) -> Dict[str, bool]: return {"not_evil": not evil} config = {"test": {"@cats": "catsie_with_dict.v1", "evil": False}, "foo": 10} result = my_registry.fill({"cfg": config}, validate=True)["cfg"]["test"] assert result["evil"] is False assert "not_evil" not in result result = my_registry.resolve({"cfg": config}, validate=True)["cfg"]["test"] assert result["not_evil"] is True @pytest.mark.skipif(not has_numpy, reason="needs numpy") def test_deepcopy_config(): config = Config({"a": 1, "b": {"c": 2, "d": 3}}) copied = config.copy() # Same values but not same object assert config == copied assert config is not copied # Check for error if value can't be pickled/deepcopied config = Config({"a": 1, "b": numpy}) with pytest.raises(ValueError): config.copy() def test_config_to_str_simple_promises(): """Test that references to function registries without arguments are serialized inline as dict.""" config_str = """[section]\nsubsection = {"@registry":"value"}""" config = Config().from_str(config_str) assert config["section"]["subsection"]["@registry"] == "value" assert config.to_str() == config_str def test_config_from_str_invalid_section(): config_str = """[a]\nb = null\n\n[a.b]\nc = 1""" with pytest.raises(ConfigValidationError): Config().from_str(config_str) config_str = """[a]\nb = null\n\n[a.b.c]\nd = 1""" with pytest.raises(ConfigValidationError): Config().from_str(config_str) def test_config_to_str_order(): """Test that Config.to_str orders the sections.""" config = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": {"g": {"h": {"i": 4, "j": 5}}}} expected = ( "[a]\ne = 3\n\n[a.b]\nc = 1\nd = 2\n\n[f]\n\n[f.g]\n\n[f.g.h]\ni = 4\nj = 5" ) config = Config(config) assert config.to_str() == expected @pytest.mark.parametrize("d", [".", ":"]) def test_config_interpolation(d): """Test that config values are interpolated correctly. The parametrized value is the final divider (${a.b} vs. ${a:b}). Both should now work and be valid. The double {{ }} in the config strings are required to prevent the references from being interpreted as an actual f-string variable. """ c_str = """[a]\nfoo = "hello"\n\n[b]\nbar = ${foo}""" with pytest.raises(ConfigValidationError): Config().from_str(c_str) c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}""" assert Config().from_str(c_str)["b"]["bar"] == "hello" c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}!""" assert Config().from_str(c_str)["b"]["bar"] == "hello!" c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = "${{a{d}foo}}!\"""" assert Config().from_str(c_str)["b"]["bar"] == "hello!" c_str = f"""[a]\nfoo = 15\n\n[b]\nbar = ${{a{d}foo}}!""" assert Config().from_str(c_str)["b"]["bar"] == "15!" c_str = f"""[a]\nfoo = ["x", "y"]\n\n[b]\nbar = ${{a{d}foo}}""" assert Config().from_str(c_str)["b"]["bar"] == ["x", "y"] # Interpolation within the same section c_str = f"""[a]\nfoo = "x"\nbar = ${{a{d}foo}}\nbaz = "${{a{d}foo}}y\"""" assert Config().from_str(c_str)["a"]["bar"] == "x" assert Config().from_str(c_str)["a"]["baz"] == "xy" def test_config_interpolation_lists(): # Test that lists are preserved correctly c_str = """[a]\nb = 1\n\n[c]\nd = ["hello ${a.b}", "world"]""" config = Config().from_str(c_str, interpolate=False) assert config["c"]["d"] == ["hello ${a.b}", "world"] config = config.interpolate() assert config["c"]["d"] == ["hello 1", "world"] c_str = """[a]\nb = 1\n\n[c]\nd = [${a.b}, "hello ${a.b}", "world"]""" config = Config().from_str(c_str) assert config["c"]["d"] == [1, "hello 1", "world"] config = Config().from_str(c_str, interpolate=False) # NOTE: This currently doesn't work, because we can't know how to JSON-load # the uninterpolated list [${a.b}]. # assert config["c"]["d"] == ["${a.b}", "hello ${a.b}", "world"] # config = config.interpolate() # assert config["c"]["d"] == [1, "hello 1", "world"] c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", ${a}]""" config = Config().from_str(c_str) assert config["c"]["d"] == ["hello", {"b": 1}] c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", "hello ${a}"]""" with pytest.raises(ConfigValidationError): Config().from_str(c_str) config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": ["hello ${a.b}"], "y": 2}]""" config = Config().from_str(config_str) assert config["c"]["d"] == ["hello", {"x": ["hello 1"], "y": 2}] config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": [${a.b}], "y": 2}]""" with pytest.raises(ConfigValidationError): Config().from_str(c_str) @pytest.mark.parametrize("d", [".", ":"]) def test_config_interpolation_sections(d): """Test that config sections are interpolated correctly. The parametrized value is the final divider (${a.b} vs. ${a:b}). Both should now work and be valid. The double {{ }} in the config strings are required to prevent the references from being interpreted as an actual f-string variable. """ # Simple block references c_str = """[a]\nfoo = "hello"\nbar = "world"\n\n[b]\nc = ${a}""" config = Config().from_str(c_str) assert config["b"]["c"] == config["a"] # References with non-string values c_str = f"""[a]\nfoo = "hello"\n\n[a.x]\ny = ${{a{d}b}}\n\n[a.b]\nc = 1\nd = [10]""" config = Config().from_str(c_str) assert config["a"]["x"]["y"] == config["a"]["b"] # Multiple references in the same string c_str = f"""[a]\nx = "string"\ny = 10\n\n[b]\nz = "${{a{d}x}}/${{a{d}y}}\"""" config = Config().from_str(c_str) assert config["b"]["z"] == "string/10" # Non-string references in string (converted to string) c_str = f"""[a]\nx = ["hello", "world"]\n\n[b]\ny = "result: ${{a{d}x}}\"""" config = Config().from_str(c_str) assert config["b"]["y"] == 'result: ["hello", "world"]' # References to sections referencing sections c_str = """[a]\nfoo = "x"\n\n[b]\nbar = ${a}\n\n[c]\nbaz = ${b}""" config = Config().from_str(c_str) assert config["b"]["bar"] == config["a"] assert config["c"]["baz"] == config["b"] # References to section values referencing other sections c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b{d}bar}}""" config = Config().from_str(c_str) assert config["c"]["baz"] == config["b"]["bar"] # References to sections with subsections c_str = """[a]\nfoo = "x"\n\n[a.b]\nbar = 100\n\n[c]\nbaz = ${a}""" config = Config().from_str(c_str) assert config["c"]["baz"] == config["a"] # Infinite recursion c_str = """[a]\nfoo ="x"\n\n[a.b]\nbar = ${a}""" config = Config().from_str(c_str) assert config["a"]["b"]["bar"] == config["a"] c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b.bar{d}foo}}""" # We can't reference not-yet interpolated subsections with pytest.raises(ConfigValidationError): Config().from_str(c_str) # Generally invalid references c_str = f"""[a]\nfoo = ${{b{d}bar}}""" with pytest.raises(ConfigValidationError): Config().from_str(c_str) # We can't reference sections or promises within strings c_str = """[a]\n\n[a.b]\nfoo = "x: ${c}"\n\n[c]\nbar = 1\nbaz = 2""" with pytest.raises(ConfigValidationError): Config().from_str(c_str) def test_config_from_str_overrides(): config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = 3\n\n[f]\ng = {"x": "y"}""" # Basic value substitution overrides = {"a.b": 10, "a.c.d": 20} config = Config().from_str(config_str, overrides=overrides) assert config["a"]["b"] == 10 assert config["a"]["c"]["d"] == 20 assert config["a"]["c"]["e"] == 3 # Valid values that previously weren't in config config = Config().from_str(config_str, overrides={"a.c.f": 100}) assert config["a"]["c"]["d"] == 2 assert config["a"]["c"]["e"] == 3 assert config["a"]["c"]["f"] == 100 # Invalid keys and sections with pytest.raises(ConfigValidationError): Config().from_str(config_str, overrides={"f": 10}) # This currently isn't expected to work, because the dict in f.g is not # interpreted as a section while the config is still just the configparser with pytest.raises(ConfigValidationError): Config().from_str(config_str, overrides={"f.g.x": "z"}) # With variables (values) config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = ${a:b}""" config = Config().from_str(config_str, overrides={"a.b": 10}) assert config["a"]["b"] == 10 assert config["a"]["c"]["e"] == 10 # With variables (sections) config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\n[e]\nf = ${a.c}""" config = Config().from_str(config_str, overrides={"a.c.d": 20}) assert config["a"]["c"]["d"] == 20 assert config["e"]["f"] == {"d": 20} def test_config_reserved_aliases(): """Test that the auto-generated pydantic schemas auto-alias reserved attributes like "validate" that would otherwise cause NameError.""" @my_registry.cats("catsie.with_alias") def catsie_with_alias(validate: StrictBool = False): return validate cfg = {"@cats": "catsie.with_alias", "validate": True} resolved = my_registry.resolve({"test": cfg}) filled = my_registry.fill({"test": cfg}) assert resolved["test"] is True assert filled["test"] == cfg cfg = {"@cats": "catsie.with_alias", "validate": 20} with pytest.raises(ConfigValidationError): my_registry.resolve({"test": cfg}) @pytest.mark.skipif(not has_numpy, reason="needs numpy") @pytest.mark.parametrize("d", [".", ":"]) def test_config_no_interpolation(d): """Test that interpolation is correctly preserved. The parametrized value is the final divider (${a.b} vs. ${a:b}). Both should now work and be valid. The double {{ }} in the config strings are required to prevent the references from being interpreted as an actual f-string variable. """ c_str = f"""[a]\nb = 1\n\n[c]\nd = ${{a{d}b}}\ne = \"hello${{a{d}b}}"\nf = ${{a}}""" config = Config().from_str(c_str, interpolate=False) assert not config.is_interpolated assert config["c"]["d"] == f"${{a{d}b}}" assert config["c"]["e"] == f'"hello${{a{d}b}}"' assert config["c"]["f"] == "${a}" config2 = Config().from_str(config.to_str(), interpolate=True) assert config2.is_interpolated assert config2["c"]["d"] == 1 assert config2["c"]["e"] == "hello1" assert config2["c"]["f"] == {"b": 1} config3 = config.interpolate() assert config3.is_interpolated assert config3["c"]["d"] == 1 assert config3["c"]["e"] == "hello1" assert config3["c"]["f"] == {"b": 1} # Bad non-serializable value cfg = {"x": {"y": numpy.asarray([[1, 2], [4, 5]], dtype="f"), "z": f"${{x{d}y}}"}} with pytest.raises(ConfigValidationError): Config(cfg).interpolate() def test_config_no_interpolation_registry(): config_str = """[a]\nbad = true\n[b]\n@cats = "catsie.v1"\nevil = ${a:bad}\n\n[c]\n d = ${b}""" config = Config().from_str(config_str, interpolate=False) assert not config.is_interpolated assert config["b"]["evil"] == "${a:bad}" assert config["c"]["d"] == "${b}" filled = my_registry.fill(config) resolved = my_registry.resolve(config) assert resolved["b"] == "scratch!" assert resolved["c"]["d"] == "scratch!" assert filled["b"]["evil"] == "${a:bad}" assert filled["b"]["cute"] is True assert filled["c"]["d"] == "${b}" interpolated = filled.interpolate() assert interpolated.is_interpolated assert interpolated["b"]["evil"] is True assert interpolated["c"]["d"] == interpolated["b"] config = Config().from_str(config_str, interpolate=True) assert config.is_interpolated filled = my_registry.fill(config) resolved = my_registry.resolve(config) assert resolved["b"] == "scratch!" assert resolved["c"]["d"] == "scratch!" assert filled["b"]["evil"] is True assert filled["c"]["d"] == filled["b"] # Resolving a non-interpolated filled config config = Config().from_str(config_str, interpolate=False) assert not config.is_interpolated filled = my_registry.fill(config) assert not filled.is_interpolated assert filled["c"]["d"] == "${b}" resolved = my_registry.resolve(filled) assert resolved["c"]["d"] == "scratch!" def test_config_deep_merge(): config = {"a": "hello", "b": {"c": "d"}} defaults = {"a": "world", "b": {"c": "e", "f": "g"}} merged = Config(defaults).merge(config) assert len(merged) == 2 assert merged["a"] == "hello" assert merged["b"] == {"c": "d", "f": "g"} config = {"a": "hello", "b": {"@test": "x", "foo": 1}} defaults = {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100} merged = Config(defaults).merge(config) assert len(merged) == 3 assert merged["a"] == "hello" assert merged["b"] == {"@test": "x", "foo": 1, "bar": 2} assert merged["c"] == 100 config = {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100} defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}} merged = Config(defaults).merge(config) assert len(merged) == 3 assert merged["a"] == "hello" assert merged["b"] == {"@test": "x", "foo": 1} assert merged["c"] == 100 # Test that leaving out the factory just adds to existing config = {"a": "hello", "b": {"foo": 1}, "c": 100} defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}} merged = Config(defaults).merge(config) assert len(merged) == 3 assert merged["a"] == "hello" assert merged["b"] == {"@test": "y", "foo": 1, "bar": 2} assert merged["c"] == 100 # Test that switching to a different factory prevents the default from being added config = {"a": "hello", "b": {"@foo": 1}, "c": 100} defaults = {"a": "world", "b": {"@bar": "y"}} merged = Config(defaults).merge(config) assert len(merged) == 3 assert merged["a"] == "hello" assert merged["b"] == {"@foo": 1} assert merged["c"] == 100 config = {"a": "hello", "b": {"@foo": 1}, "c": 100} defaults = {"a": "world", "b": "y"} merged = Config(defaults).merge(config) assert len(merged) == 3 assert merged["a"] == "hello" assert merged["b"] == {"@foo": 1} assert merged["c"] == 100 def test_config_deep_merge_variables(): config_str = """[a]\nb= 1\nc = 2\n\n[d]\ne = ${a:b}""" defaults_str = """[a]\nx = 100\n\n[d]\ny = 500""" config = Config().from_str(config_str, interpolate=False) defaults = Config().from_str(defaults_str) merged = defaults.merge(config) assert merged["a"] == {"b": 1, "c": 2, "x": 100} assert merged["d"] == {"e": "${a:b}", "y": 500} assert merged.interpolate()["d"] == {"e": 1, "y": 500} # With variable in defaults: overwritten by new value config = Config().from_str("""[a]\nb= 1\nc = 2""") defaults = Config().from_str("""[a]\nb = 100\nc = ${a:b}""", interpolate=False) merged = defaults.merge(config) assert merged["a"]["c"] == 2 @pytest.mark.skipif(not has_numpy, reason="needs numpy") def test_config_to_str_roundtrip(): cfg = {"cfg": {"foo": False}} config_str = Config(cfg).to_str() assert config_str == "[cfg]\nfoo = false" config = Config().from_str(config_str) assert dict(config) == cfg cfg = {"cfg": {"foo": "false"}} config_str = Config(cfg).to_str() assert config_str == '[cfg]\nfoo = "false"' config = Config().from_str(config_str) assert dict(config) == cfg # Bad non-serializable value cfg = {"cfg": {"x": numpy.asarray([[1, 2, 3, 4], [4, 5, 3, 4]], dtype="f")}} config = Config(cfg) with pytest.raises(ConfigValidationError): config.to_str() # Roundtrip with variables: preserve variables correctly (quoted/unquoted) config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = "${a:b}\"""" config = Config().from_str(config_str, interpolate=False) assert config.to_str() == config_str def test_config_is_interpolated(): """Test that a config object correctly reports whether it's interpolated.""" config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = ${a}""" config = Config().from_str(config_str, interpolate=False) assert not config.is_interpolated config = config.merge(Config({"x": {"y": "z"}})) assert not config.is_interpolated config = Config(config) assert not config.is_interpolated config = config.interpolate() assert config.is_interpolated config = config.merge(Config().from_str(config_str, interpolate=False)) assert not config.is_interpolated @pytest.mark.parametrize( "section_order,expected_str,expected_keys", [ # fmt: off ([], "[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[h]\ni = 5\n\n[j]\nk = 6", ["a", "h", "j"]), (["j", "h", "a"], "[j]\nk = 6\n\n[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4", ["j", "h", "a"]), (["h"], "[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[j]\nk = 6", ["h", "a", "j"]) # fmt: on ], ) def test_config_serialize_custom_sort(section_order, expected_str, expected_keys): cfg = { "j": {"k": 6}, "a": {"b": 1, "d": {"e": 3}, "c": 2, "f": {"g": 4}}, "h": {"i": 5}, } cfg_str = Config(cfg).to_str() assert Config(cfg, section_order=section_order).to_str() == expected_str keys = list(Config(section_order=section_order).from_str(cfg_str).keys()) assert keys == expected_keys keys = list(Config(cfg, section_order=section_order).keys()) assert keys == expected_keys def test_config_custom_sort_preserve(): """Test that sort order is preserved when merging and copying configs, or when configs are filled and resolved.""" cfg = {"x": {}, "y": {}, "z": {}} section_order = ["y", "z", "x"] expected = "[y]\n\n[z]\n\n[x]" config = Config(cfg, section_order=section_order) assert config.to_str() == expected config2 = config.copy() assert config2.to_str() == expected config3 = config.merge({"a": {}}) assert config3.to_str() == f"{expected}\n\n[a]" config4 = Config(config) assert config4.to_str() == expected config_str = """[a]\nb = 1\n[c]\n@cats = "catsie.v1"\nevil = true\n\n[t]\n x = 2""" section_order = ["c", "a", "t"] config5 = Config(section_order=section_order).from_str(config_str) assert list(config5.keys()) == section_order filled = my_registry.fill(config5) assert filled.section_order == section_order def test_config_pickle(): config = Config({"foo": "bar"}, section_order=["foo", "bar", "baz"]) data = pickle.dumps(config) config_new = pickle.loads(data) assert config_new == {"foo": "bar"} assert config_new.section_order == ["foo", "bar", "baz"] def test_config_fill_extra_fields(): """Test that filling a config from a schema removes extra fields.""" class TestSchemaContent(BaseModel): a: str b: int class Config: extra = "forbid" class TestSchema(BaseModel): cfg: TestSchemaContent config = Config({"cfg": {"a": "1", "b": 2, "c": True}}) with pytest.raises(ConfigValidationError): my_registry.fill(config, schema=TestSchema) filled = my_registry.fill(config, schema=TestSchema, validate=False)["cfg"] assert filled == {"a": "1", "b": 2} config2 = config.interpolate() filled = my_registry.fill(config2, schema=TestSchema, validate=False)["cfg"] assert filled == {"a": "1", "b": 2} config3 = Config({"cfg": {"a": "1", "b": 2, "c": True}}, is_interpolated=False) filled = my_registry.fill(config3, schema=TestSchema, validate=False)["cfg"] assert filled == {"a": "1", "b": 2} class TestSchemaContent2(BaseModel): a: str b: int class Config: extra = "allow" class TestSchema2(BaseModel): cfg: TestSchemaContent2 filled = my_registry.fill(config, schema=TestSchema2, validate=False)["cfg"] assert filled == {"a": "1", "b": 2, "c": True} def test_config_validation_error_custom(): class Schema(BaseModel): hello: int world: int config = {"hello": 1, "world": "hi!"} with pytest.raises(ConfigValidationError) as exc_info: my_registry._fill(config, Schema) e1 = exc_info.value assert e1.title == "Config validation error" assert e1.desc is None assert not e1.parent assert e1.show_config is True assert len(e1.errors) == 1 assert e1.errors[0]["loc"] == ("world",) assert e1.errors[0]["msg"] == "value is not a valid integer" assert e1.errors[0]["type"] == "type_error.integer" assert e1.error_types == set(["type_error.integer"]) # Create a new error with overrides title = "Custom error" desc = "Some error description here" e2 = ConfigValidationError.from_error(e1, title=title, desc=desc, show_config=False) assert e2.errors == e1.errors assert e2.error_types == e1.error_types assert e2.title == title assert e2.desc == desc assert e2.show_config is False assert e1.text != e2.text def test_config_parsing_error(): config_str = "[a]\nb c" with pytest.raises(ConfigValidationError): Config().from_str(config_str) def test_config_fill_without_resolve(): class BaseSchema(BaseModel): catsie: int config = {"catsie": {"@cats": "catsie.v1", "evil": False}} filled = my_registry.fill(config) resolved = my_registry.resolve(config) assert resolved["catsie"] == "meow" assert filled["catsie"]["cute"] is True with pytest.raises(ConfigValidationError): my_registry.resolve(config, schema=BaseSchema) filled2 = my_registry.fill(config, schema=BaseSchema) assert filled2["catsie"]["cute"] is True resolved = my_registry.resolve(filled2) assert resolved["catsie"] == "meow" # With unavailable function class BaseSchema2(BaseModel): catsie: Any other: int = 12 config = {"catsie": {"@cats": "dog", "evil": False}} filled3 = my_registry.fill(config, schema=BaseSchema2) assert filled3["catsie"] == config["catsie"] assert filled3["other"] == 12 def test_config_dataclasses(): cat = Cat("testcat", value_in=1, value_out=2) config = {"cfg": {"@cats": "catsie.v3", "arg": cat}} result = my_registry.resolve(config)["cfg"] assert isinstance(result, Cat) assert result.name == cat.name assert result.value_in == cat.value_in assert result.value_out == cat.value_out @pytest.mark.parametrize( "greeting,value,expected", [ # simple substitution should go fine [342, "${vars.a}", int], ["342", "${vars.a}", str], ["everyone", "${vars.a}", str], ], ) def test_config_interpolates(greeting, value, expected): str_cfg = f""" [project] my_par = {value} [vars] a = "something" """ overrides = {"vars.a": greeting} cfg = Config().from_str(str_cfg, overrides=overrides) assert type(cfg["project"]["my_par"]) == expected @pytest.mark.parametrize( "greeting,value,expected", [ # fmt: off # simple substitution should go fine ["hello 342", "${vars.a}", "hello 342"], ["hello everyone", "${vars.a}", "hello everyone"], ["hello tout le monde", "${vars.a}", "hello tout le monde"], ["hello 42", "${vars.a}", "hello 42"], # substituting an element in a list ["hello 342", "[1, ${vars.a}, 3]", "hello 342"], ["hello everyone", "[1, ${vars.a}, 3]", "hello everyone"], ["hello tout le monde", "[1, ${vars.a}, 3]", "hello tout le monde"], ["hello 42", "[1, ${vars.a}, 3]", "hello 42"], # substituting part of a string [342, "hello ${vars.a}", "hello 342"], ["everyone", "hello ${vars.a}", "hello everyone"], ["tout le monde", "hello ${vars.a}", "hello tout le monde"], pytest.param("42", "hello ${vars.a}", "hello 42", marks=pytest.mark.xfail), # substituting part of a implicit string inside a list [342, "[1, hello ${vars.a}, 3]", "hello 342"], ["everyone", "[1, hello ${vars.a}, 3]", "hello everyone"], ["tout le monde", "[1, hello ${vars.a}, 3]", "hello tout le monde"], pytest.param("42", "[1, hello ${vars.a}, 3]", "hello 42", marks=pytest.mark.xfail), # substituting part of a explicit string inside a list [342, "[1, 'hello ${vars.a}', '3']", "hello 342"], ["everyone", "[1, 'hello ${vars.a}', '3']", "hello everyone"], ["tout le monde", "[1, 'hello ${vars.a}', '3']", "hello tout le monde"], pytest.param("42", "[1, 'hello ${vars.a}', '3']", "hello 42", marks=pytest.mark.xfail), # more complicated example [342, "[{'name':'x','script':['hello ${vars.a}']}]", "hello 342"], ["everyone", "[{'name':'x','script':['hello ${vars.a}']}]", "hello everyone"], ["tout le monde", "[{'name':'x','script':['hello ${vars.a}']}]", "hello tout le monde"], pytest.param("42", "[{'name':'x','script':['hello ${vars.a}']}]", "hello 42", marks=pytest.mark.xfail), # fmt: on ], ) def test_config_overrides(greeting, value, expected): str_cfg = f""" [project] commands = {value} [vars] a = "world" """ overrides = {"vars.a": greeting} assert "${vars.a}" in str_cfg cfg = Config().from_str(str_cfg, overrides=overrides) assert expected in str(cfg) catalogue-2.1.0/catalogue/tests/util.py000066400000000000000000000077151426104053400201240ustar00rootroot00000000000000""" Registered functions used for config tests. """ import contextlib import dataclasses import shutil import tempfile from pathlib import Path from typing import Iterable, List, Union, Generator, Callable, Tuple, Generic, TypeVar, Optional, Any import numpy from pydantic.types import StrictBool import catalogue from catalogue.config import config from catalogue.config.util import partial FloatOrSeq = Union[float, List[float], Generator] InT = TypeVar("InT") OutT = TypeVar("OutT") @dataclasses.dataclass class Cat(Generic[InT, OutT]): name: str value_in: InT value_out: OutT class my_registry(config.catalogue_registry): cats = catalogue.registry.create("config_tests", "cats", entry_points=False) optimizers = catalogue.registry.create("config_tests", "optimizers", entry_points=False) schedules = catalogue.registry.create("config_tests", "schedules", entry_points=False) initializers = catalogue.registry.create("config_tests", "initializers", entry_points=False) layers = catalogue.registry.create("config_tests", "layers", entry_points=False) @my_registry.cats.register("catsie.v1") def catsie_v1(evil: StrictBool, cute: bool = True) -> str: if evil: return "scratch!" else: return "meow" @my_registry.cats.register("catsie.v2") def catsie_v2(evil: StrictBool, cute: bool = True, cute_level: int = 1) -> str: if evil: return "scratch!" else: if cute_level > 2: return "meow <3" return "meow" @my_registry.cats("catsie.v3") def catsie(arg: Cat) -> Cat: return arg @my_registry.optimizers("Adam.v1") def Adam( learn_rate: FloatOrSeq = 0.001, *, beta1: FloatOrSeq = .001, beta2: FloatOrSeq = .001, use_averages: bool = True, ): """ Mocks optimizer generation. Note that the returned object is not actually an optimizer. This function is merely used to illustrate how to use the function registry, e.g. with thinc. """ @dataclasses.dataclass class Optimizer: learn_rate: FloatOrSeq beta1: FloatOrSeq beta2: FloatOrSeq use_averages: bool return Optimizer(learn_rate=learn_rate, beta1=beta1, beta2=beta2, use_averages=use_averages) @my_registry.schedules("warmup_linear.v1") def warmup_linear( initial_rate: float, warmup_steps: int, total_steps: int ) -> Iterable[float]: """Generate a series, starting from an initial rate, and then with a warmup period, and then a linear decline. Used for learning rates. """ step = 0 while True: if step < warmup_steps: factor = step / max(1, warmup_steps) else: factor = max( 0.0, (total_steps - step) / max(1.0, total_steps - warmup_steps) ) yield factor * initial_rate step += 1 def uniform_init( shape: Tuple[int, ...], *, lo: float = -0.1, hi: float = 0.1 ) -> List[float]: return numpy.random.uniform(lo, hi, shape).tolist() @my_registry.initializers("uniform_init.v1") def configure_uniform_init( *, lo: float = -0.1, hi: float = 0.1 ) -> Callable[[List[float]], List[float]]: return partial(uniform_init, lo=lo, hi=hi) @my_registry.cats("generic_cat.v1") def generic_cat(cat: Cat[int, int]) -> Cat[int, int]: cat.name = "generic_cat" return cat @my_registry.cats("int_cat.v1") def int_cat( value_in: Optional[int] = None, value_out: Optional[int] = None ) -> Cat[Optional[int], Optional[int]]: """ Instantiates cat with integer values. """ return Cat(name="int_cat", value_in=value_in, value_out=value_out) @my_registry.optimizers.register("my_cool_optimizer.v1") def make_my_optimizer(learn_rate: List[float], beta1: float): return Adam(learn_rate, beta1=beta1) @my_registry.schedules("my_cool_repetitive_schedule.v1") def decaying(base_rate: float, repeat: int) -> List[float]: return repeat * [base_rate] @contextlib.contextmanager def make_tempdir(): d = Path(tempfile.mkdtemp()) yield d shutil.rmtree(str(d)) catalogue-2.1.0/requirements.txt000066400000000000000000000002301426104053400167340ustar00rootroot00000000000000zipp>=0.5; python_version < "3.8" typing-extensions>=3.6.4; python_version < "3.8" pytest>=4.6.5 mypy pydantic types-dataclasses typing_extensions srslycatalogue-2.1.0/setup.cfg000066400000000000000000000024601426104053400153000ustar00rootroot00000000000000[metadata] version = 2.1.0 description = Lightweight function registries for your library url = https://github.com/explosion/catalogue author = Explosion author_email = contact@explosion.ai license = MIT long_description = file: README.md long_description_content_type = text/markdown classifiers = Development Status :: 5 - Production/Stable Environment :: Console Intended Audience :: Developers Intended Audience :: Science/Research License :: OSI Approved :: MIT License Operating System :: POSIX :: Linux Operating System :: MacOS :: MacOS X Operating System :: Microsoft :: Windows Programming Language :: Python :: 3 Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 Topic :: Scientific/Engineering [options] zip_safe = true include_package_data = true python_requires = >=3.6 install_requires = zipp>=0.5; python_version < "3.8" typing-extensions>=3.6.4; python_version < "3.8" [sdist] formats = gztar [flake8] ignore = E203, E266, E501, E731, W503 max-line-length = 80 select = B,C,E,F,W,T4,B9 exclude = .env, .git, __pycache__, [mypy] ignore_missing_imports = True no_implicit_optional = True catalogue-2.1.0/setup.py000066400000000000000000000002311426104053400151630ustar00rootroot00000000000000#!/usr/bin/env python if __name__ == "__main__": from setuptools import setup, find_packages setup(name="catalogue", packages=find_packages())