pax_global_header 0000666 0000000 0000000 00000000064 14261040534 0014511 g ustar 00root root 0000000 0000000 52 comment=e5fe944127cab62472a9657f0a5bb232e0303fc8
catalogue-2.1.0/ 0000775 0000000 0000000 00000000000 14261040534 0013455 5 ustar 00root root 0000000 0000000 catalogue-2.1.0/.gitignore 0000664 0000000 0000000 00000001506 14261040534 0015447 0 ustar 00root root 0000000 0000000 tmp/
.pytest_cache
.vscode
.mypy_cache
.prettierrc
.python-version
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
.env/
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
# Sphinx documentation
docs/_build/
# PyBuilder
target/
#Ipython Notebook
.ipynb_checkpoints
catalogue-2.1.0/LICENSE 0000664 0000000 0000000 00000002061 14261040534 0014461 0 ustar 00root root 0000000 0000000 MIT License
Copyright (c) 2019 ExplosionAI GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
catalogue-2.1.0/MANIFEST.in 0000664 0000000 0000000 00000000020 14261040534 0015203 0 ustar 00root root 0000000 0000000 include LICENSE
catalogue-2.1.0/README.md 0000664 0000000 0000000 00000071736 14261040534 0014752 0 ustar 00root root 0000000 0000000
# catalogue: Lightweight function registries and configurations for your library
`catalogue` is a small library that
- makes it easy to **add function (or object) registries** to your code
- offers a **configuration system** letting you conveniently describe arbitrary trees of objects.
_Function registries_ are helpful when you have objects that need to be both easily serializable and fully
customizable. Instead of passing a function into your object, you pass in an
identifier name, which the object can use to lookup the function from the
registry. This makes the object easy to serialize, because the name is a simple
string. If you instead saved the function, you'd have to use Pickle for
serialization, which has many drawbacks.
_Configuration_ is a huge challenge for machine-learning code because you may want to expose almost any
detail of any function as a hyperparameter. The setting you want to expose might be arbitrarily far
down in your call stack, so it might need to pass all the way through the CLI or REST API,
through any number of intermediate functions, affecting the interface of everything along the way.
And then once those settings are added, they become hard to remove later. Default values also
become hard to change without breaking backwards compatibility.
To solve this problem, `catalogue` offers a config system that lets you easily describe arbitrary trees of objects.
The objects can be created via function calls you register using a simple decorator syntax. You can even version the
functions you create, allowing you to make improvements without breaking backwards compatibility. The most similar
config system we’re aware of is [Gin](https://github.com/google/gin-config), which uses a similar syntax, and also allows you to link the configuration system
to functions in your code using a decorator. `catalogue`'s config system is simpler and emphasizes a different workflow via a
subset of Gin’s functionality.
[](https://dev.azure.com/explosion-ai/public/_build?definitionId=14)
[](https://github.com/explosion/catalogue/releases)
[](https://pypi.org/project/catalogue/)
[](https://anaconda.org/conda-forge/catalogue)
[](https://github.com/ambv/black)
## ⏳ Installation
```bash
pip install catalogue
```
```bash
conda install -c conda-forge catalogue
```
> ⚠️ **Important note:** `catalogue` v2.0+ is only compatible with Python 3.6+.
> For Python 2.7+ compatibility, use `catalogue` v1.x.
## 👩💻 Usage
### Function registry
Let's imagine you're developing a Python package that needs to load data
somewhere. You've already implemented some loader functions for the most common
data types, but you want to allow the user to easily add their own. Using
`catalogue.create` you can create a new registry under the namespace
`your_package` → `loaders`.
```python
# YOUR PACKAGE
import catalogue
loaders = catalogue.create("your_package", "loaders")
```
This gives you a `loaders.register` decorator that your users can import and
decorate their custom loader functions with.
```python
# USER CODE
from your_package import loaders
@loaders.register("custom_loader")
def custom_loader(data):
# Load something here...
return data
```
The decorated function will be registered automatically and in your package,
you'll be able to access all loaders by calling `loaders.get_all`.
```python
# YOUR PACKAGE
def load_data(data, loader_id):
print("All loaders:", loaders.get_all()) # {"custom_loader": }
loader = loaders.get(loader_id)
return loader(data)
```
The user can now refer to their custom loader using only its string name
(`"custom_loader"`) and your application will know what to do and will use their
custom function.
```python
# USER CODE
from your_package import load_data
load_data(data, loader_id="custom_loader")
```
### Configurations
The configuration system parses a `.cfg` file like
```ini
[training]
patience = 10
dropout = 0.2
use_vectors = false
[training.logging]
level = "INFO"
[nlp]
# This uses the value of training.use_vectors
use_vectors = ${training.use_vectors}
lang = "en"
```
and resolves it to a `Dict`:
```json
{
"training": {
"patience": 10,
"dropout": 0.2,
"use_vectors": false,
"logging": {
"level": "INFO"
}
},
"nlp": {
"use_vectors": false,
"lang": "en"
}
}
```
The config is divided into sections, with the section name in square brackets – for
example, `[training]`. Within the sections, config values can be assigned to keys using `=`. Values can also be referenced
from other sections using the dot notation and placeholders indicated by the dollar sign and curly braces. For example,
`${training.use_vectors}` will receive the value of use_vectors in the training block. This is useful for settings that
are shared across components.
The config format has three main differences from Python’s built-in configparser:
1. JSON-formatted values. `catalogue` passes all values through `json.loads` to interpret them. You can use atomic
values like strings, floats, integers or booleans, or you can use complex objects such as lists or maps.
2. Structured sections. `catalogue` uses a dot notation to build nested sections. If you have a section named
`[section.subsection]`, `catalogue` will parse that into a nested structure, placing subsection within section.
3. References to registry functions. If a key starts with `@`, `catalogue` will interpret its value as the name of a
function registry, load the function registered for that name and pass in the rest of the block as arguments. If type
hints are available on the function, the argument values (and return value of the function) will be validated against
them. This lets you express complex configurations, like a training pipeline where `batch_size` is populated by a
function that yields floats.
There’s no pre-defined scheme you have to follow; how you set up the top-level sections is up to you. At the end of
it, you’ll receive a dictionary with the values that you can use in your script – whether it’s complete initialized
functions, or just basic settings.
For instance, let’s say you want to define a new optimizer. You'd define its arguments in `config.cfg` like so:
```ini
[optimizer]
@optimizers = "my_cool_optimizer.v1"
learn_rate = 0.001
gamma = 1e-8
```
To load and parse this configuration:
```python
import dataclasses
from typing import Union, Iterable
from catalogue import catalogue_registry, Config
# Create a new registry.
catalogue_registry.create("optimizers")
# Define a dummy optimizer class.
@dataclasses.dataclass
class MyCoolOptimizer:
learn_rate: float
gamma: float
@catalogue_registry.optimizers.register("my_cool_optimizer.v1")
def make_my_optimizer(learn_rate: Union[float, Iterable[float]], gamma: float):
return MyCoolOptimizer(learn_rate, gamma)
# Load the config file from disk, resolve it and fetch the instantiated optimizer object.
config = Config().from_disk("./config.cfg")
resolved = catalogue_registry.resolve(config)
optimizer = resolved["optimizer"] # MyCoolOptimizer(learn_rate=0.001, gamma=1e-08)
```
Under the hood, `catalogue` will look up the `"my_cool_optimizer.v1"` function in the "optimizers" registry and then
call it with the arguments `learn_rate` and `gamma`. If the function has type annotations, it will also validate the
input. For instance, if `learn_rate` is annotated as a float and the config defines a string, `catalogue` will raise an
error.
The Thinc documentation offers further information on the configuration system:
- [recursive blocks](https://thinc.ai/docs/usage-config#registry-recursive)
- [defining variable positional arguments](https://thinc.ai/docs/usage-config#registries-args)
- [using interpolation](https://thinc.ai/docs/usage-config#config-interpolation)
- [using custom registries](https://thinc.ai/docs/usage-config#registries-custom)
- [advanced type annotations with Pydantic](https://thinc.ai/docs/usage-config#advanced-types)
- [using base schemas](https://thinc.ai/docs/usage-config#advanced-types-base-schema)
- [filling a configuration with defaults](https://thinc.ai/docs/usage-config#advanced-types-fill-defaults)
## ❓ FAQ
#### But can't the user just pass in the `custom_loader` function directly?
Sure, that's the more classic callback approach. Instead of a string ID,
`load_data` could also take a function, in which case you wouldn't need a
package like this. `catalogue` helps you when you need to produce a serializable
record of which functions were passed in. For instance, you might want to write
a log message, or save a config to load back your object later. With
`catalogue`, your functions can be parameterized by strings, so logging and
serialization remains easy – while still giving you full extensibility.
#### How do I make sure all of the registration decorators have run?
Decorators normally run when modules are imported. Relying on this side-effect
can sometimes lead to confusion, especially if there's no other reason the
module would be imported. One solution is to use
[entry points](https://packaging.python.org/specifications/entry-points/).
For instance, in [spaCy](https://spacy.io) we're starting to use function
registries to make the pipeline components much more customizable. Let's say one
user, Jo, develops a better tagging model using new machine learning research.
End-users of Jo's package should be able to write
`spacy.load("jo_tagging_model")`. They shouldn't need to remember to write
`import jos_tagged_model` first, just to run the function registries as a
side-effect. With entry points, the registration happens at install time – so
you don't need to rely on the import side-effects.
## 🎛 API
### Registry
#### function `catalogue.create`
Create a new registry for a given namespace. Returns a setter function that can
be used as a decorator or called with a name and `func` keyword argument. If
`entry_points=True` is set, the registry will check for
[Python entry points](https://packaging.python.org/tutorials/packaging-projects/#entry-points)
advertised for the given namespace, e.g. the entry point group
`spacy_architectures` for the namespace `"spacy", "architectures"`, in
`Registry.get` and `Registry.get_all`. This allows other packages to
auto-register functions.
| Argument | Type | Description |
| -------------- | ---------- | ---------------------------------------------------------------------------------------------- |
| `*namespace` | `str` | The namespace, e.g. `"spacy"` or `"spacy", "architectures"`. |
| `entry_points` | `bool` | Whether to check for entry points of the given namespace and pre-populate the global registry. |
| **RETURNS** | `Registry` | The `Registry` object with methods to register and retrieve functions. |
```python
architectures = catalogue.create("spacy", "architectures")
# Use as decorator
@architectures.register("custom_architecture")
def custom_architecture():
pass
# Use as regular function
architectures.register("custom_architecture", func=custom_architecture)
```
#### class `Registry`
The registry object that can be used to register and retrieve functions. It's
usually created internally when you call `catalogue.create`.
##### method `Registry.__init__`
Initialize a new registry. If `entry_points=True` is set, the registry will
check for
[Python entry points](https://packaging.python.org/tutorials/packaging-projects/#entry-points)
advertised for the given namespace, e.g. the entry point group
`spacy_architectures` for the namespace `"spacy", "architectures"`, in
`Registry.get` and `Registry.get_all`.
| Argument | Type | Description |
| -------------- | ------------ | -------------------------------------------------------------------------------- |
| `namespace` | `Tuple[str]` | The namespace, e.g. `"spacy"` or `"spacy", "architectures"`. |
| `entry_points` | `bool` | Whether to check for entry points of the given namespace in `get` and `get_all`. |
| **RETURNS** | `Registry` | The newly created object. |
```python
# User-facing API
architectures = catalogue.create("spacy", "architectures")
# Internal API
architectures = Registry(("spacy", "architectures"))
```
##### method `Registry.__contains__`
Check whether a name is in the registry.
| Argument | Type | Description |
| ----------- | ------ | ------------------------------------ |
| `name` | `str` | The name to check. |
| **RETURNS** | `bool` | Whether the name is in the registry. |
```python
architectures = catalogue.create("spacy", "architectures")
@architectures.register("custom_architecture")
def custom_architecture():
pass
assert "custom_architecture" in architectures
```
##### method `Registry.__call__`
Register a function in the registry's namespace. Can be used as a decorator or
called as a function with the `func` keyword argument supplying the function to
register. Delegates to `Registry.register`.
##### method `Registry.register`
Register a function in the registry's namespace. Can be used as a decorator or
called as a function with the `func` keyword argument supplying the function to
register.
| Argument | Type | Description |
| ----------- | ---------- | --------------------------------------------------------- |
| `name` | `str` | The name to register under the namespace. |
| `func` | `Any` | Optional function to register (if not used as decorator). |
| **RETURNS** | `Callable` | The decorator that takes one argument, the name. |
```python
architectures = catalogue.create("spacy", "architectures")
# Use as decorator
@architectures.register("custom_architecture")
def custom_architecture():
pass
# Use as regular function
architectures.register("custom_architecture", func=custom_architecture)
```
##### method `Registry.get`
Get a function registered in the namespace.
| Argument | Type | Description |
| ----------- | ----- | ------------------------ |
| `name` | `str` | The name. |
| **RETURNS** | `Any` | The registered function. |
```python
custom_architecture = architectures.get("custom_architecture")
```
##### method `Registry.get_all`
Get all functions in the registry's namespace.
| Argument | Type | Description |
| ----------- | ---------------- | ---------------------------------------- |
| **RETURNS** | `Dict[str, Any]` | The registered functions, keyed by name. |
```python
all_architectures = architectures.get_all()
# {"custom_architecture": }
```
##### method `Registry.get_entry_points`
Get registered entry points from other packages for this namespace. The name of
the entry point group is the namespace joined by `_`.
| Argument | Type | Description |
| ----------- | ---------------- | --------------------------------------- |
| **RETURNS** | `Dict[str, Any]` | The loaded entry points, keyed by name. |
```python
architectures = catalogue.create("spacy", "architectures", entry_points=True)
# Will get all entry points of the group "spacy_architectures"
all_entry_points = architectures.get_entry_points()
```
##### method `Registry.get_entry_point`
Check if registered entry point is available for a given name in the namespace
and load it. Otherwise, return the default value.
| Argument | Type | Description |
| ----------- | ----- | ------------------------------------------------ |
| `name` | `str` | Name of entry point to load. |
| `default` | `Any` | The default value to return. Defaults to `None`. |
| **RETURNS** | `Any` | The loaded entry point or the default value. |
```python
architectures = catalogue.create("spacy", "architectures", entry_points=True)
# Will get entry point "custom_architecture" of the group "spacy_architectures"
custom_architecture = architectures.get_entry_point("custom_architecture")
```
##### method `Registry.find`
Find the information about a registered function, including the
module and path to the file it's defined in, the line number and the
docstring, if available.
| Argument | Type | Description |
| ----------- | ---------------------------- | ----------------------------------- |
| `name` | `str` | Name of the registered function. |
| **RETURNS** | `Dict[str, Union[str, int]]` | The information about the function. |
```python
import catalogue
architectures = catalogue.create("spacy", "architectures", entry_points=True)
@architectures("my_architecture")
def my_architecture():
"""This is an architecture"""
pass
info = architectures.find("my_architecture")
# {'module': 'your_package.architectures',
# 'file': '/path/to/your_package/architectures.py',
# 'line_no': 5,
# 'docstring': 'This is an architecture'}
```
#### function `catalogue.check_exists`
Check if a namespace exists.
| Argument | Type | Description |
| ------------ | ------ | ------------------------------------------------------------ |
| `*namespace` | `str` | The namespace, e.g. `"spacy"` or `"spacy", "architectures"`. |
| **RETURNS** | `bool` | Whether the namespace exists. |
### Config
#### class `Config`
This class holds the model and training [configuration](https://thinc.ai/docs/usage-config) and can load and save the
INI-style configuration format from/to a string, file or bytes. The `Config` class is a subclass of `dict` and uses
Python’s `ConfigParser` under the hood.
##### method `Config.__init__`
Initialize a new `Config` object with optional data.
```python
from catalogue import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
```
| Argument | Type | Description |
| ----------------- | ----------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `data` | `Optional[Union[Dict[str, Any], Config]]` | Optional data to initialize the config with. |
| `section_order` | `Optional[List[str]]` | Top-level section names, in order, used to sort the saved and loaded config. All other sections will be sorted alphabetically. |
| `is_interpolated` | `Optional[bool]` | Whether the config is interpolated or whether it contains variables. Read from the `data` if it’s an instance of `Config` and otherwise defaults to `True`. |
##### method `Config.from_str`
Load the config from a string.
```python
from catalogue import Config
config_str = """
[training]
patience = 10
dropout = 0.2
"""
config = Config().from_str(config_str)
print(config["training"]) # {'patience': 10, 'dropout': 0.2}}
```
| Argument | Type | Description |
| ------------- | ---------------- | -------------------------------------------------------------------------------------------------------------------- |
| `text` | `str` | The string config to load. |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| `overrides` | `Dict[str, Any]` | Overrides for values and sections. Keys are provided in dot notation, e.g. `"training.dropout"` mapped to the value. |
| **RETURNS** | `Config` | The loaded config. |
##### method `Config.to_str`
Load the config from a string.
```python
from catalogue import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
print(config.to_str()) # '[training]\npatience = 10\n\ndropout = 0.2'
```
| Argument | Type | Description |
| ------------- | ------ | --------------------------------------------------------------------------- |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| **RETURNS** | `str` | The string config. |
##### method `Config.to_bytes`
Serialize the config to a byte string.
```python
from catalogue import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
config_bytes = config.to_bytes()
print(config_bytes) # b'[training]\npatience = 10\n\ndropout = 0.2'
```
| Argument | Type | Description |
| ------------- | ---------------- | -------------------------------------------------------------------------------------------------------------------- |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| `overrides` | `Dict[str, Any]` | Overrides for values and sections. Keys are provided in dot notation, e.g. `"training.dropout"` mapped to the value. |
| **RETURNS** | `str` | The serialized config. |
##### method `Config.from_bytes`
Load the config from a byte string.
```python
from catalogue import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
config_bytes = config.to_bytes()
new_config = Config().from_bytes(config_bytes)
```
| Argument | Type | Description |
| ------------- | -------- | --------------------------------------------------------------------------- |
| `bytes_data` | `bool` | The data to load. |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| **RETURNS** | `Config` | The loaded config. |
##### method `Config.to_disk`
Serialize the config to a file.
```python
from catalogue import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
config.to_disk("./config.cfg")
```
| Argument | Type | Description |
| ------------- | ------------------ | --------------------------------------------------------------------------- |
| `path` | `Union[Path, str]` | The file path. |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
##### method `Config.from_disk`
Load the config from a file.
```python
from catalogue import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
config.to_disk("./config.cfg")
new_config = Config().from_disk("./config.cfg")
```
| Argument | Type | Description |
| ------------- | ------------------ | -------------------------------------------------------------------------------------------------------------------- |
| `path` | `Union[Path, str]` | The file path. |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| `overrides` | `Dict[str, Any]` | Overrides for values and sections. Keys are provided in dot notation, e.g. `"training.dropout"` mapped to the value. |
| **RETURNS** | `Config` | The loaded config. |
##### method `Config.copy`
Deep-copy the config.
| Argument | Type | Description |
| ----------- | -------- | ------------------ |
| **RETURNS** | `Config` | The copied config. |
##### method `Config.interpolate`
Interpolate variables like `${section.value}` or `${section.subsection}` and return a copy of the config with interpolated
values. Can be used if a config is loaded with `interpolate=False`, e.g. via `Config.from_str`.
```python
from catalogue import Config
config_str = """
[hyper_params]
dropout = 0.2
[training]
dropout = ${hyper_params.dropout}
"""
config = Config().from_str(config_str, interpolate=False)
print(config["training"]) # {'dropout': '${hyper_params.dropout}'}}
config = config.interpolate()
print(config["training"]) # {'dropout': 0.2}}
```
| Argument | Type | Description |
| ----------- | -------- | ---------------------------------------------- |
| **RETURNS** | `Config` | A copy of the config with interpolated values. |
##### method `Config.merge`
Deep-merge two config objects, using the current config as the default. Only merges sections and dictionaries and not
other values like lists. Values that are provided in the updates are overwritten in the base config, and any new values
or sections are added. If a config value is a variable like `${section.key}` (e.g. if the config was loaded with
`interpolate=False)`, **the variable is preferred**, even if the updates provide a different value. This ensures that variable
references aren’t destroyed by a merge.
> :warning: Note that blocks that refer to registered functions using the `@` syntax are only merged if they are
> referring to the same functions. Otherwise, merging could easily produce invalid configs, since different functions
> can take different arguments. If a block refers to a different function, it’s overwritten.
```python
from catalogue import Config
base_config_str = """
[training]
patience = 10
dropout = 0.2
"""
update_config_str = """
[training]
dropout = 0.1
max_epochs = 2000
"""
base_config = Config().from_str(base_config_str)
update_config = Config().from_str(update_config_str)
merged = Config(base_config).merge(update_config)
print(merged["training"]) # {'patience': 10, 'dropout': 1.0, 'max_epochs': 2000}
```
| Argument | Type | Description |
| ----------- | ------------------------------- | --------------------------------------------------- |
| `overrides` | `Union[Dict[str, Any], Config]` | The updates to merge into the config. |
| **RETURNS** | `Config` | A new config instance containing the merged config. |
#### Config Attributes
| Argument | Type | Description |
| ----------------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `is_interpolated` | `bool` | Whether the config values have been interpolated. Defaults to `True` and is set to `False` if a config is loaded with `interpolate=False`, e.g. using `Config.from_str`. |
catalogue-2.1.0/azure-pipelines.yml 0000664 0000000 0000000 00000004314 14261040534 0017316 0 ustar 00root root 0000000 0000000 trigger:
batch: true
branches:
include:
- '*'
jobs:
- job: 'Test'
strategy:
matrix:
Python36Linux:
imageName: 'ubuntu-latest'
python.version: '3.6'
Python36Windows:
imageName: 'windows-2019'
python.version: '3.6'
Python36Mac:
imageName: 'macos-10.15'
python.version: '3.6'
Python37Linux:
imageName: 'ubuntu-latest'
python.version: '3.7'
Python37Windows:
imageName: 'windows-latest'
python.version: '3.7'
Python37Mac:
imageName: 'macos-latest'
python.version: '3.7'
Python38Linux:
imageName: 'ubuntu-latest'
python.version: '3.8'
Python38Windows:
imageName: 'windows-latest'
python.version: '3.8'
Python38Mac:
imageName: 'macos-latest'
python.version: '3.8'
Python39Linux:
imageName: 'ubuntu-latest'
python.version: '3.9'
Python39Windows:
imageName: 'windows-latest'
python.version: '3.9'
Python39Mac:
imageName: 'macos-latest'
python.version: '3.9'
Python310Linux:
imageName: 'ubuntu-latest'
python.version: '3.10'
Python310Windows:
imageName: 'windows-latest'
python.version: '3.10'
Python310Mac:
imageName: 'macos-latest'
python.version: '3.10'
maxParallel: 4
pool:
vmImage: $(imageName)
steps:
- task: UsePythonVersion@0
inputs:
versionSpec: '$(python.version)'
architecture: 'x64'
- script: |
pip install -U -r requirements.txt
pip install numpy pathy
python setup.py sdist
displayName: 'Build sdist'
- script: python -m mypy catalogue
displayName: 'Run mypy'
- task: DeleteFiles@1
inputs:
contents: 'catalogue'
displayName: 'Delete source directory'
- bash: |
SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
pip install dist/$SDIST
displayName: 'Install from sdist'
- script: python -m pytest --pyargs catalogue
displayName: 'Run tests'
- bash: |
pip install hypothesis
python -c "import catalogue; import hypothesis"
displayName: 'Test for conflicts'
catalogue-2.1.0/bin/ 0000775 0000000 0000000 00000000000 14261040534 0014225 5 ustar 00root root 0000000 0000000 catalogue-2.1.0/bin/push-tags.sh 0000775 0000000 0000000 00000000537 14261040534 0016504 0 ustar 00root root 0000000 0000000 #!/usr/bin/env bash
set -e
# Insist repository is clean
git diff-index --quiet HEAD
git checkout $1
git pull origin $1
git push origin $1
version=$(grep "version = " setup.cfg)
version=${version/version = }
version=${version/\'/}
version=${version/\'/}
version=${version/\"/}
version=${version/\"/}
git tag "v$version"
git push origin "v$version" catalogue-2.1.0/catalogue/ 0000775 0000000 0000000 00000000000 14261040534 0015421 5 ustar 00root root 0000000 0000000 catalogue-2.1.0/catalogue/__init__.py 0000664 0000000 0000000 00000000100 14261040534 0017521 0 ustar 00root root 0000000 0000000 from catalogue.registry import *
from catalogue.config import *
catalogue-2.1.0/catalogue/_importlib_metadata/ 0000775 0000000 0000000 00000000000 14261040534 0021421 5 ustar 00root root 0000000 0000000 catalogue-2.1.0/catalogue/_importlib_metadata/LICENSE 0000664 0000000 0000000 00000001073 14261040534 0022427 0 ustar 00root root 0000000 0000000 Copyright 2017-2019 Jason R. Coombs, Barry Warsaw
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
catalogue-2.1.0/catalogue/_importlib_metadata/__init__.py 0000664 0000000 0000000 00000047354 14261040534 0023547 0 ustar 00root root 0000000 0000000 import os
import re
import abc
import csv
import sys
import zipp
import email
import pathlib
import operator
import functools
import itertools
import posixpath
import collections
from ._compat import (
NullFinder,
PyPy_repr,
install,
Protocol,
)
from configparser import ConfigParser
from contextlib import suppress
from importlib import import_module
from importlib.abc import MetaPathFinder
from itertools import starmap
from typing import Any, List, Mapping, TypeVar, Union
__all__ = [
'Distribution',
'DistributionFinder',
'PackageNotFoundError',
'distribution',
'distributions',
'entry_points',
'files',
'metadata',
'requires',
'version',
]
class PackageNotFoundError(ModuleNotFoundError):
"""The package was not found."""
def __str__(self):
tmpl = "No package metadata was found for {self.name}"
return tmpl.format(**locals())
@property
def name(self):
(name,) = self.args
return name
class EntryPoint(
PyPy_repr, collections.namedtuple('EntryPointBase', 'name value group')
):
"""An entry point as defined by Python packaging conventions.
See `the packaging docs on entry points
`_
for more information.
"""
pattern = re.compile(
r'(?P[\w.]+)\s*'
r'(:\s*(?P[\w.]+))?\s*'
r'(?P\[.*\])?\s*$'
)
"""
A regular expression describing the syntax for an entry point,
which might look like:
- module
- package.module
- package.module:attribute
- package.module:object.attribute
- package.module:attr [extra1, extra2]
Other combinations are possible as well.
The expression is lenient about whitespace around the ':',
following the attr, and following any extras.
"""
def load(self):
"""Load the entry point from its definition. If only a module
is indicated by the value, return that module. Otherwise,
return the named object.
"""
match = self.pattern.match(self.value)
module = import_module(match.group('module'))
attrs = filter(None, (match.group('attr') or '').split('.'))
return functools.reduce(getattr, attrs, module)
@property
def module(self):
match = self.pattern.match(self.value)
return match.group('module')
@property
def attr(self):
match = self.pattern.match(self.value)
return match.group('attr')
@property
def extras(self):
match = self.pattern.match(self.value)
return list(re.finditer(r'\w+', match.group('extras') or ''))
@classmethod
def _from_config(cls, config):
return [
cls(name, value, group)
for group in config.sections()
for name, value in config.items(group)
]
@classmethod
def _from_text(cls, text):
config = ConfigParser(delimiters='=')
# case sensitive: https://stackoverflow.com/q/1611799/812183
config.optionxform = str
config.read_string(text)
return EntryPoint._from_config(config)
def __iter__(self):
"""
Supply iter so one may construct dicts of EntryPoints easily.
"""
return iter((self.name, self))
def __reduce__(self):
return (
self.__class__,
(self.name, self.value, self.group),
)
class PackagePath(pathlib.PurePosixPath):
"""A reference to a path in a package"""
def read_text(self, encoding='utf-8'):
with self.locate().open(encoding=encoding) as stream:
return stream.read()
def read_binary(self):
with self.locate().open('rb') as stream:
return stream.read()
def locate(self):
"""Return a path-like object for this path"""
return self.dist.locate_file(self)
class FileHash:
def __init__(self, spec):
self.mode, _, self.value = spec.partition('=')
def __repr__(self):
return ''.format(self.mode, self.value)
_T = TypeVar("_T")
class PackageMetadata(Protocol):
def __len__(self) -> int:
... # pragma: no cover
def __contains__(self, item: str) -> bool:
... # pragma: no cover
def __getitem__(self, key: str) -> str:
... # pragma: no cover
def get_all(self, name: str, failobj: _T = ...) -> Union[List[Any], _T]:
"""
Return all values associated with a possibly multi-valued key.
"""
class Distribution:
"""A Python distribution package."""
@abc.abstractmethod
def read_text(self, filename):
"""Attempt to load metadata file given by the name.
:param filename: The name of the file in the distribution info.
:return: The text if found, otherwise None.
"""
@abc.abstractmethod
def locate_file(self, path):
"""
Given a path to a file in this distribution, return a path
to it.
"""
@classmethod
def from_name(cls, name):
"""Return the Distribution for the given package name.
:param name: The name of the distribution package to search for.
:return: The Distribution instance (or subclass thereof) for the named
package, if found.
:raises PackageNotFoundError: When the named package's distribution
metadata cannot be found.
"""
for resolver in cls._discover_resolvers():
dists = resolver(DistributionFinder.Context(name=name))
dist = next(iter(dists), None)
if dist is not None:
return dist
else:
raise PackageNotFoundError(name)
@classmethod
def discover(cls, **kwargs):
"""Return an iterable of Distribution objects for all packages.
Pass a ``context`` or pass keyword arguments for constructing
a context.
:context: A ``DistributionFinder.Context`` object.
:return: Iterable of Distribution objects for all packages.
"""
context = kwargs.pop('context', None)
if context and kwargs:
raise ValueError("cannot accept context and kwargs")
context = context or DistributionFinder.Context(**kwargs)
return itertools.chain.from_iterable(
resolver(context) for resolver in cls._discover_resolvers()
)
@staticmethod
def at(path):
"""Return a Distribution for the indicated metadata path
:param path: a string or path-like object
:return: a concrete Distribution instance for the path
"""
return PathDistribution(pathlib.Path(path))
@staticmethod
def _discover_resolvers():
"""Search the meta_path for resolvers."""
declared = (
getattr(finder, '_catalogue_find_distributions', None) for finder in sys.meta_path
)
return filter(None, declared)
@classmethod
def _local(cls, root='.'):
from pep517 import build, meta
system = build.compat_system(root)
builder = functools.partial(
meta.build,
source_dir=root,
system=system,
)
return PathDistribution(zipp.Path(meta.build_as_zip(builder)))
@property
def metadata(self) -> PackageMetadata:
"""Return the parsed metadata for this Distribution.
The returned object will have keys that name the various bits of
metadata. See PEP 566 for details.
"""
text = (
self.read_text('METADATA')
or self.read_text('PKG-INFO')
# This last clause is here to support old egg-info files. Its
# effect is to just end up using the PathDistribution's self._path
# (which points to the egg-info file) attribute unchanged.
or self.read_text('')
)
return email.message_from_string(text)
@property
def version(self):
"""Return the 'Version' metadata for the distribution package."""
return self.metadata['Version']
@property
def entry_points(self):
return EntryPoint._from_text(self.read_text('entry_points.txt'))
@property
def files(self):
"""Files in this distribution.
:return: List of PackagePath for this distribution or None
Result is `None` if the metadata file that enumerates files
(i.e. RECORD for dist-info or SOURCES.txt for egg-info) is
missing.
Result may be empty if the metadata exists but is empty.
"""
file_lines = self._read_files_distinfo() or self._read_files_egginfo()
def make_file(name, hash=None, size_str=None):
result = PackagePath(name)
result.hash = FileHash(hash) if hash else None
result.size = int(size_str) if size_str else None
result.dist = self
return result
return file_lines and list(starmap(make_file, csv.reader(file_lines)))
def _read_files_distinfo(self):
"""
Read the lines of RECORD
"""
text = self.read_text('RECORD')
return text and text.splitlines()
def _read_files_egginfo(self):
"""
SOURCES.txt might contain literal commas, so wrap each line
in quotes.
"""
text = self.read_text('SOURCES.txt')
return text and map('"{}"'.format, text.splitlines())
@property
def requires(self):
"""Generated requirements specified for this Distribution"""
reqs = self._read_dist_info_reqs() or self._read_egg_info_reqs()
return reqs and list(reqs)
def _read_dist_info_reqs(self):
return self.metadata.get_all('Requires-Dist')
def _read_egg_info_reqs(self):
source = self.read_text('requires.txt')
return source and self._deps_from_requires_text(source)
@classmethod
def _deps_from_requires_text(cls, source):
section_pairs = cls._read_sections(source.splitlines())
sections = {
section: list(map(operator.itemgetter('line'), results))
for section, results in itertools.groupby(
section_pairs, operator.itemgetter('section')
)
}
return cls._convert_egg_info_reqs_to_simple_reqs(sections)
@staticmethod
def _read_sections(lines):
section = None
for line in filter(None, lines):
section_match = re.match(r'\[(.*)\]$', line)
if section_match:
section = section_match.group(1)
continue
yield locals()
@staticmethod
def _convert_egg_info_reqs_to_simple_reqs(sections):
"""
Historically, setuptools would solicit and store 'extra'
requirements, including those with environment markers,
in separate sections. More modern tools expect each
dependency to be defined separately, with any relevant
extras and environment markers attached directly to that
requirement. This method converts the former to the
latter. See _test_deps_from_requires_text for an example.
"""
def make_condition(name):
return name and 'extra == "{name}"'.format(name=name)
def parse_condition(section):
section = section or ''
extra, sep, markers = section.partition(':')
if extra and markers:
markers = '({markers})'.format(markers=markers)
conditions = list(filter(None, [markers, make_condition(extra)]))
return '; ' + ' and '.join(conditions) if conditions else ''
for section, deps in sections.items():
for dep in deps:
yield dep + parse_condition(section)
class DistributionFinder(MetaPathFinder):
"""
A MetaPathFinder capable of discovering installed distributions.
"""
class Context:
"""
Keyword arguments presented by the caller to
``distributions()`` or ``Distribution.discover()``
to narrow the scope of a search for distributions
in all DistributionFinders.
Each DistributionFinder may expect any parameters
and should attempt to honor the canonical
parameters defined below when appropriate.
"""
name = None
"""
Specific name for which a distribution finder should match.
A name of ``None`` matches all distributions.
"""
def __init__(self, **kwargs):
vars(self).update(kwargs)
@property
def path(self):
"""
The path that a distribution finder should search.
Typically refers to Python package paths and defaults
to ``sys.path``.
"""
return vars(self).get('path', sys.path)
@abc.abstractmethod
def _catalogue_find_distributions(self, context=Context()):
"""
Find distributions.
Return an iterable of all Distribution instances capable of
loading the metadata for packages matching the ``context``,
a DistributionFinder.Context instance.
"""
class FastPath:
"""
Micro-optimized class for searching a path for
children.
"""
def __init__(self, root):
self.root = str(root)
self.base = os.path.basename(self.root).lower()
def joinpath(self, child):
return pathlib.Path(self.root, child)
def children(self):
with suppress(Exception):
return os.listdir(self.root or '')
with suppress(Exception):
return self.zip_children()
return []
def zip_children(self):
zip_path = zipp.Path(self.root)
names = zip_path.root.namelist()
self.joinpath = zip_path.joinpath
return dict.fromkeys(child.split(posixpath.sep, 1)[0] for child in names)
def search(self, name):
return (
self.joinpath(child)
for child in self.children()
if name.matches(child, self.base)
)
class Prepared:
"""
A prepared search for metadata on a possibly-named package.
"""
normalized = None
suffixes = '.dist-info', '.egg-info'
exact_matches = [''][:0]
def __init__(self, name):
self.name = name
if name is None:
return
self.normalized = self.normalize(name)
self.exact_matches = [self.normalized + suffix for suffix in self.suffixes]
@staticmethod
def normalize(name):
"""
PEP 503 normalization plus dashes as underscores.
"""
return re.sub(r"[-_.]+", "-", name).lower().replace('-', '_')
@staticmethod
def legacy_normalize(name):
"""
Normalize the package name as found in the convention in
older packaging tools versions and specs.
"""
return name.lower().replace('-', '_')
def matches(self, cand, base):
low = cand.lower()
pre, ext = os.path.splitext(low)
name, sep, rest = pre.partition('-')
return (
low in self.exact_matches
or ext in self.suffixes
and (not self.normalized or name.replace('.', '_') == self.normalized)
# legacy case:
or self.is_egg(base)
and low == 'egg-info'
)
def is_egg(self, base):
normalized = self.legacy_normalize(self.name or '')
prefix = normalized + '-' if normalized else ''
versionless_egg_name = normalized + '.egg' if self.name else ''
return (
base == versionless_egg_name
or base.startswith(prefix)
and base.endswith('.egg')
)
@install
class MetadataPathFinder(NullFinder, DistributionFinder):
"""A degenerate finder for distribution packages on the file system.
This finder supplies only a find_distributions() method for versions
of Python that do not have a PathFinder find_distributions().
"""
def _catalogue_find_distributions(self, context=DistributionFinder.Context()):
"""
Find distributions.
Return an iterable of all Distribution instances capable of
loading the metadata for packages matching ``context.name``
(or all names if ``None`` indicated) along the paths in the list
of directories ``context.path``.
"""
found = self._search_paths(context.name, context.path)
return map(PathDistribution, found)
@classmethod
def _search_paths(cls, name, paths):
"""Find metadata directories in paths heuristically."""
return itertools.chain.from_iterable(
path.search(Prepared(name)) for path in map(FastPath, paths)
)
class PathDistribution(Distribution):
def __init__(self, path):
"""Construct a distribution from a path to the metadata directory.
:param path: A pathlib.Path or similar object supporting
.joinpath(), __div__, .parent, and .read_text().
"""
self._path = path
def read_text(self, filename):
with suppress(
FileNotFoundError,
IsADirectoryError,
KeyError,
NotADirectoryError,
PermissionError,
):
return self._path.joinpath(filename).read_text(encoding='utf-8')
read_text.__doc__ = Distribution.read_text.__doc__
def locate_file(self, path):
return self._path.parent / path
def distribution(distribution_name):
"""Get the ``Distribution`` instance for the named package.
:param distribution_name: The name of the distribution package as a string.
:return: A ``Distribution`` instance (or subclass thereof).
"""
return Distribution.from_name(distribution_name)
def distributions(**kwargs):
"""Get all ``Distribution`` instances in the current environment.
:return: An iterable of ``Distribution`` instances.
"""
return Distribution.discover(**kwargs)
def metadata(distribution_name) -> PackageMetadata:
"""Get the metadata for the named package.
:param distribution_name: The name of the distribution package to query.
:return: A PackageMetadata containing the parsed metadata.
"""
return Distribution.from_name(distribution_name).metadata
def version(distribution_name):
"""Get the version string for the named package.
:param distribution_name: The name of the distribution package to query.
:return: The version string for the package as defined in the package's
"Version" metadata key.
"""
return distribution(distribution_name).version
def entry_points():
"""Return EntryPoint objects for all installed packages.
:return: EntryPoint objects for all installed packages.
"""
eps = itertools.chain.from_iterable(dist.entry_points for dist in distributions())
by_group = operator.attrgetter('group')
ordered = sorted(eps, key=by_group)
grouped = itertools.groupby(ordered, by_group)
return {group: tuple(eps) for group, eps in grouped}
def files(distribution_name):
"""Return a list of files for the named package.
:param distribution_name: The name of the distribution package to query.
:return: List of files composing the distribution.
"""
return distribution(distribution_name).files
def requires(distribution_name):
"""
Return a list of requirements for the named package.
:return: An iterator of requirements, suitable for
packaging.requirement.Requirement.
"""
return distribution(distribution_name).requires
def packages_distributions() -> Mapping[str, List[str]]:
"""
Return a mapping of top-level packages to their
distributions.
>>> pkgs = packages_distributions()
>>> all(isinstance(dist, collections.abc.Sequence) for dist in pkgs.values())
True
"""
pkg_to_dist = collections.defaultdict(list)
for dist in distributions():
for pkg in (dist.read_text('top_level.txt') or '').split():
pkg_to_dist[pkg].append(dist.metadata['Name'])
return dict(pkg_to_dist)
catalogue-2.1.0/catalogue/_importlib_metadata/_compat.py 0000664 0000000 0000000 00000004546 14261040534 0023426 0 ustar 00root root 0000000 0000000 import sys
__all__ = ['install', 'NullFinder', 'PyPy_repr', 'Protocol']
try:
from typing import Protocol
except ImportError: # pragma: no cover
"""
pytest-mypy complains here because:
error: Incompatible import of "Protocol" (imported name has type
"typing_extensions._SpecialForm", local name has type "typing._SpecialForm")
"""
from typing_extensions import Protocol # type: ignore
def install(cls):
"""
Class decorator for installation on sys.meta_path.
Adds the backport DistributionFinder to sys.meta_path and
attempts to disable the finder functionality of the stdlib
DistributionFinder.
"""
sys.meta_path.append(cls())
disable_stdlib_finder()
return cls
def disable_stdlib_finder():
"""
Give the backport primacy for discovering path-based distributions
by monkey-patching the stdlib O_O.
See #91 for more background for rationale on this sketchy
behavior.
"""
def matches(finder):
return getattr(
finder, '__module__', None
) == '_frozen_importlib_external' and hasattr(finder, '_catalogue_find_distributions')
for finder in filter(matches, sys.meta_path): # pragma: nocover
del finder._catalogue_find_distributions
class NullFinder:
"""
A "Finder" (aka "MetaClassFinder") that never finds any modules,
but may find distributions.
"""
@staticmethod
def find_spec(*args, **kwargs):
return None
# In Python 2, the import system requires finders
# to have a find_module() method, but this usage
# is deprecated in Python 3 in favor of find_spec().
# For the purposes of this finder (i.e. being present
# on sys.meta_path but having no other import
# system functionality), the two methods are identical.
find_module = find_spec
class PyPy_repr:
"""
Override repr for EntryPoint objects on PyPy to avoid __iter__ access.
Ref #97, #102.
"""
affected = hasattr(sys, 'pypy_version_info')
def __compat_repr__(self): # pragma: nocover
def make_param(name):
value = getattr(self, name)
return '{name}={value!r}'.format(**locals())
params = ', '.join(map(make_param, self._fields))
return 'EntryPoint({params})'.format(**locals())
if affected: # pragma: nocover
__repr__ = __compat_repr__
del affected
catalogue-2.1.0/catalogue/config/ 0000775 0000000 0000000 00000000000 14261040534 0016666 5 ustar 00root root 0000000 0000000 catalogue-2.1.0/catalogue/config/__init__.py 0000664 0000000 0000000 00000000026 14261040534 0020775 0 ustar 00root root 0000000 0000000 from .config import *
catalogue-2.1.0/catalogue/config/config.py 0000664 0000000 0000000 00000132727 14261040534 0020521 0 ustar 00root root 0000000 0000000 from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping
from typing import Iterable, Sequence, cast
from types import GeneratorType
from dataclasses import dataclass
from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH
from configparser import InterpolationMissingOptionError, InterpolationSyntaxError
from configparser import NoSectionError, NoOptionError, InterpolationDepthError
from configparser import ParsingError
from pathlib import Path
from pydantic import BaseModel, create_model, ValidationError, Extra
from pydantic.main import ModelMetaclass
from pydantic.fields import ModelField
import srsly
import catalogue.registry
import inspect
import io
import numpy
import copy
import re
from .util import Decorator
# Field used for positional arguments, e.g. [section.*.xyz]. The alias is
# required for the schema (shouldn't clash with user-defined arg names)
ARGS_FIELD = "*"
ARGS_FIELD_ALIAS = "VARIABLE_POSITIONAL_ARGS"
# Aliases for fields that would otherwise shadow pydantic attributes. Can be any
# string, so we're using name + space so it looks the same in error messages etc.
RESERVED_FIELDS = {"validate": "validate\u0020"}
# Internal prefix used to mark section references for custom interpolation
SECTION_PREFIX = "__SECTION__:"
# Values that shouldn't be loaded during interpolation because it'd cause
# even explicit string values to be incorrectly parsed as bools/None etc.
JSON_EXCEPTIONS = ("true", "false", "null")
# Regex to detect whether a value contains a variable
VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}")
class CustomInterpolation(ExtendedInterpolation):
def before_read(self, parser, section, option, value):
# If we're dealing with a quoted string as the interpolation value,
# make sure we load and unquote it so we don't end up with '"value"'
try:
json_value = srsly.json_loads(value)
if isinstance(json_value, str) and json_value not in JSON_EXCEPTIONS:
value = json_value
except Exception:
pass
return super().before_read(parser, section, option, value)
def before_get(self, parser, section, option, value, defaults):
# Mostly copy-pasted from the built-in configparser implementation.
L = []
self.interpolate(parser, option, L, value, section, defaults, 1)
return "".join(L)
def interpolate(self, parser, option, accum, rest, section, map, depth):
# Mostly copy-pasted from the built-in configparser implementation.
# We need to overwrite this method so we can add special handling for
# block references :( All values produced here should be strings –
# we need to wait until the whole config is interpreted anyways so
# filling in incomplete values here is pointless. All we need is the
# section reference so we can fetch it later.
rawval = parser.get(section, option, raw=True, fallback=rest)
if depth > MAX_INTERPOLATION_DEPTH:
raise InterpolationDepthError(option, section, rawval)
while rest:
p = rest.find("$")
if p < 0:
accum.append(rest)
return
if p > 0:
accum.append(rest[:p])
rest = rest[p:]
# p is no longer used
c = rest[1:2]
if c == "$":
accum.append("$")
rest = rest[2:]
elif c == "{":
# We want to treat both ${a:b} and ${a.b} the same
m = self._KEYCRE.match(rest)
if m is None:
err = f"bad interpolation variable reference {rest}"
raise InterpolationSyntaxError(option, section, err)
orig_var = m.group(1)
path = orig_var.replace(":", ".").rsplit(".", 1)
rest = rest[m.end() :]
sect = section
opt = option
try:
if len(path) == 1:
opt = parser.optionxform(path[0])
if opt in map:
v = map[opt]
else:
# We have block reference, store it as a special key
section_name = parser[parser.optionxform(path[0])]._name
v = self._get_section_name(section_name)
elif len(path) == 2:
sect = path[0]
opt = parser.optionxform(path[1])
fallback = "__FALLBACK__"
v = parser.get(sect, opt, raw=True, fallback=fallback)
# If a variable doesn't exist, try again and treat the
# reference as a section
if v == fallback:
v = self._get_section_name(parser[f"{sect}.{opt}"]._name)
else:
err = f"More than one ':' found: {rest}"
raise InterpolationSyntaxError(option, section, err)
except (KeyError, NoSectionError, NoOptionError):
raise InterpolationMissingOptionError(
option, section, rawval, orig_var
) from None
if "$" in v:
new_map = dict(parser.items(sect, raw=True))
self.interpolate(parser, opt, accum, v, sect, new_map, depth + 1)
else:
accum.append(v)
else:
err = "'$' must be followed by '$' or '{', " "found: %r" % (rest,)
raise InterpolationSyntaxError(option, section, err)
def _get_section_name(self, name: str) -> str:
"""Generate the name of a section. Note that we use a quoted string here
so we can use section references within lists and load the list as
JSON. Since section references can't be used within strings, we don't
need the quoted vs. unquoted distinction like we do for variables.
Examples (assuming section = {"foo": 1}):
- value: ${section.foo} -> value: 1
- value: "hello ${section.foo}" -> value: "hello 1"
- value: ${section} -> value: {"foo": 1}
- value: "${section}" -> value: {"foo": 1}
- value: "hello ${section}" -> invalid
"""
return f'"{SECTION_PREFIX}{name}"'
def get_configparser(interpolate: bool = True):
config = ConfigParser(interpolation=CustomInterpolation() if interpolate else None)
# Preserve case of keys: https://stackoverflow.com/a/1611877/6400719
config.optionxform = str # type: ignore
return config
class Config(dict):
"""This class holds the model and training configuration and can load and
save the TOML-style configuration format from/to a string, file or bytes.
The Config class is a subclass of dict and uses Python's ConfigParser
under the hood.
"""
is_interpolated: bool
def __init__(
self,
data: Optional[Union[Dict[str, Any], "ConfigParser", "Config"]] = None,
*,
is_interpolated: Optional[bool] = None,
section_order: Optional[List[str]] = None,
) -> None:
"""Initialize a new Config object with optional data."""
dict.__init__(self)
if data is None:
data = {}
if not isinstance(data, (dict, Config, ConfigParser)):
raise ValueError(
f"Can't initialize Config with data. Expected dict, Config or "
f"ConfigParser but got: {type(data)}"
)
# Whether the config has been interpolated. We can use this to check
# whether we need to interpolate again when it's resolved. We assume
# that a config is interpolated by default.
if is_interpolated is not None:
self.is_interpolated = is_interpolated
elif isinstance(data, Config):
self.is_interpolated = data.is_interpolated
else:
self.is_interpolated = True
if section_order is not None:
self.section_order = section_order
elif isinstance(data, Config):
self.section_order = data.section_order
else:
self.section_order = []
# Update with data
self.update(self._sort(data))
def interpolate(self) -> "Config":
"""Interpolate a config. Returns a copy of the object."""
# This is currently the most effective way because we need our custom
# to_str logic to run in order to re-serialize the values so we can
# interpolate them again. ConfigParser.read_dict will just call str()
# on all values, which isn't enough.
return Config().from_str(self.to_str())
def interpret_config(self, config: "ConfigParser") -> None:
"""Interpret a config, parse nested sections and parse the values
as JSON. Mostly used internally and modifies the config in place.
"""
self._validate_sections(config)
# Sort sections by depth, so that we can iterate breadth-first. This
# allows us to check that we're not expanding an undefined block.
get_depth = lambda item: len(item[0].split("."))
for section, values in sorted(config.items(), key=get_depth):
if section == "DEFAULT":
# Skip [DEFAULT] section so it doesn't cause validation error
continue
parts = section.split(".")
node = self
for part in parts[:-1]:
if part == "*":
node = node.setdefault(part, {})
elif part not in node:
err_title = f"Error parsing config section. Perhaps a section name is wrong?"
err = [{"loc": parts, "msg": f"Section '{part}' is not defined"}]
raise ConfigValidationError(
config=self, errors=err, title=err_title
)
else:
node = node[part]
if not isinstance(node, dict):
# Happens if both value *and* subsection were defined for a key
err = [{"loc": parts, "msg": "found conflicting values"}]
err_cfg = f"{self}\n{({part: dict(values)})}"
raise ConfigValidationError(config=err_cfg, errors=err)
# Set the default section
node = node.setdefault(parts[-1], {})
if not isinstance(node, dict):
# Happens if both value *and* subsection were defined for a key
err = [{"loc": parts, "msg": "found conflicting values"}]
err_cfg = f"{self}\n{({part: dict(values)})}"
raise ConfigValidationError(config=err_cfg, errors=err)
try:
keys_values = list(values.items())
except InterpolationMissingOptionError as e:
raise ConfigValidationError(desc=f"{e}") from None
for key, value in keys_values:
config_v = config.get(section, key)
node[key] = self._interpret_value(config_v)
self.replace_section_refs(self)
def replace_section_refs(
self, config: Union[Dict[str, Any], "Config"], parent: str = ""
) -> None:
"""Replace references to section blocks in the final config."""
for key, value in config.items():
key_parent = f"{parent}.{key}".strip(".")
if isinstance(value, dict):
self.replace_section_refs(value, parent=key_parent)
elif isinstance(value, list):
config[key] = [
self._get_section_ref(v, parent=[parent, key]) for v in value
]
else:
config[key] = self._get_section_ref(value, parent=[parent, key])
def _interpret_value(self, value: Any) -> Any:
"""Interpret a single config value."""
result = try_load_json(value)
# If value is a string and it contains a variable, use original value
# (not interpreted string, which could lead to double quotes:
# ${x.y} -> "${x.y}" -> "'${x.y}'"). Make sure to check it's a string,
# so we're not keeping lists as strings.
# NOTE: This currently can't handle uninterpolated values like [${x.y}]!
if isinstance(result, str) and VARIABLE_RE.search(value):
result = value
if isinstance(result, list):
return [self._interpret_value(v) for v in result]
return result
def _get_section_ref(self, value: Any, *, parent: List[str] = []) -> Any:
"""Get a single section reference."""
if isinstance(value, str) and value.startswith(f'"{SECTION_PREFIX}'):
value = try_load_json(value)
if isinstance(value, str) and value.startswith(SECTION_PREFIX):
parts = value.replace(SECTION_PREFIX, "").split(".")
result = self
for item in parts:
try:
result = result[item]
except (KeyError, TypeError): # This should never happen
err_title = "Error parsing reference to config section"
err_msg = f"Section '{'.'.join(parts)}' is not defined"
err = [{"loc": parts, "msg": err_msg}]
raise ConfigValidationError(
config=self, errors=err, title=err_title
) from None
return result
elif isinstance(value, str) and SECTION_PREFIX in value:
# String value references a section (either a dict or return
# value of promise). We can't allow this, since variables are
# always interpolated *before* configs are resolved.
err_desc = (
"Can't reference whole sections or return values of function "
"blocks inside a string or list\n\nYou can change your variable to "
"reference a value instead. Keep in mind that it's not "
"possible to interpolate the return value of a registered "
"function, since variables are interpolated when the config "
"is loaded, and registered functions are resolved afterwards."
)
err = [{"loc": parent, "msg": "uses section variable in string or list"}]
raise ConfigValidationError(errors=err, desc=err_desc)
return value
def copy(self) -> "Config":
"""Deepcopy the config."""
try:
config = copy.deepcopy(self)
except Exception as e:
raise ValueError(f"Couldn't deep-copy config: {e}") from e
return Config(
config,
is_interpolated=self.is_interpolated,
section_order=self.section_order,
)
def merge(
self, updates: Union[Dict[str, Any], "Config"], remove_extra: bool = False
) -> "Config":
"""Deep merge the config with updates, using current as defaults."""
defaults = self.copy()
updates = Config(updates).copy()
merged = deep_merge_configs(updates, defaults, remove_extra=remove_extra)
return Config(
merged,
is_interpolated=defaults.is_interpolated and updates.is_interpolated,
section_order=defaults.section_order,
)
def _sort(
self, data: Union["Config", "ConfigParser", Dict[str, Any]]
) -> Dict[str, Any]:
"""Sort sections using the currently defined sort order. Sort
sections by index on section order, if available, then alphabetic, and
account for subsections, which should always follow their parent.
"""
sort_map = {section: i for i, section in enumerate(self.section_order)}
sort_key = lambda x: (
sort_map.get(x[0].split(".")[0], len(sort_map)),
_mask_positional_args(x[0]),
)
return dict(sorted(data.items(), key=sort_key))
def _set_overrides(self, config: "ConfigParser", overrides: Dict[str, Any]) -> None:
"""Set overrides in the ConfigParser before config is interpreted."""
err_title = "Error parsing config overrides"
for key, value in overrides.items():
err_msg = "not a section value that can be overwritten"
err = [{"loc": key.split("."), "msg": err_msg}]
if "." not in key:
raise ConfigValidationError(errors=err, title=err_title)
section, option = key.rsplit(".", 1)
# Check for section and accept if option not in config[section]
if section not in config:
raise ConfigValidationError(errors=err, title=err_title)
config.set(section, option, try_dump_json(value, overrides))
def _validate_sections(self, config: "ConfigParser") -> None:
# If the config defines top-level properties that are not sections (e.g.
# if config was constructed from dict), those values would be added as
# [DEFAULTS] and included in *every other section*. This is usually not
# what we want and it can lead to very confusing results.
default_section = config.defaults()
if default_section:
err_title = "Found config values without a top-level section"
err_msg = "not part of a section"
err = [{"loc": [k], "msg": err_msg} for k in default_section]
raise ConfigValidationError(errors=err, title=err_title)
def from_str(
self, text: str, *, interpolate: bool = True, overrides: Dict[str, Any] = {}
) -> "Config":
"""Load the config from a string."""
config = get_configparser(interpolate=interpolate)
if overrides:
config = get_configparser(interpolate=False)
try:
config.read_string(text)
except ParsingError as e:
desc = f"Make sure the sections and values are formatted correctly.\n\n{e}"
raise ConfigValidationError(desc=desc) from None
config._sections = self._sort(config._sections)
self._set_overrides(config, overrides)
self.clear()
self.interpret_config(config)
if overrides and interpolate:
# do the interpolation. Avoids recursion because the new call from_str call will have overrides as empty
self = self.interpolate()
self.is_interpolated = interpolate
return self
def to_str(self, *, interpolate: bool = True) -> str:
"""Write the config to a string."""
flattened = get_configparser(interpolate=interpolate)
queue: List[Tuple[tuple, "Config"]] = [(tuple(), self)]
for path, node in queue:
section_name = ".".join(path)
is_kwarg = path and path[-1] != "*"
if is_kwarg and not flattened.has_section(section_name):
# Always create sections for non-'*' sections, not only if
# they have leaf entries, as we don't want to expand
# blocks that are undefined
flattened.add_section(section_name)
for key, value in node.items():
if hasattr(value, "items"):
# Reference to a function with no arguments, serialize
# inline as a dict and don't create new section
if catalogue_registry.is_promise(value) and len(value) == 1 and is_kwarg:
flattened.set(section_name, key, try_dump_json(value, node))
else:
queue.append((path + (key,), value))
else:
flattened.set(section_name, key, try_dump_json(value, node))
# Order so subsection follow parent (not all sections, then all subs etc.)
flattened._sections = self._sort(flattened._sections)
self._validate_sections(flattened)
string_io = io.StringIO()
flattened.write(string_io)
return string_io.getvalue().strip()
def to_bytes(self, *, interpolate: bool = True) -> bytes:
"""Serialize the config to a byte string."""
return self.to_str(interpolate=interpolate).encode("utf8")
def from_bytes(
self,
bytes_data: bytes,
*,
interpolate: bool = True,
overrides: Dict[str, Any] = {},
) -> "Config":
"""Load the config from a byte string."""
return self.from_str(
bytes_data.decode("utf8"), interpolate=interpolate, overrides=overrides
)
def to_disk(self, path: Union[str, Path], *, interpolate: bool = True):
"""Serialize the config to a file."""
path = Path(path) if isinstance(path, str) else path
with path.open("w", encoding="utf8") as file_:
file_.write(self.to_str(interpolate=interpolate))
def from_disk(
self,
path: Union[str, Path],
*,
interpolate: bool = True,
overrides: Dict[str, Any] = {},
) -> "Config":
"""Load config from a file."""
path = Path(path) if isinstance(path, str) else path
with path.open("r", encoding="utf8") as file_:
text = file_.read()
return self.from_str(text, interpolate=interpolate, overrides=overrides)
def _mask_positional_args(name: str) -> List[Optional[str]]:
"""Create a section name representation that masks names
of positional arguments to retain their order in sorts."""
stable_name = cast(List[Optional[str]], name.split("."))
# Remove names of sections that are a positional argument.
for i in range(1, len(stable_name)):
if stable_name[i - 1] == "*":
stable_name[i] = None
return stable_name
def try_load_json(value: str) -> Any:
"""Load a JSON string if possible, otherwise default to original value."""
try:
return srsly.json_loads(value)
except Exception:
return value
def try_dump_json(value: Any, data: Union[Dict[str, dict], Config, str] = "") -> str:
"""Dump a config value as JSON and output user-friendly error if it fails."""
# Special case if we have a variable: it's already a string so don't dump
# to preserve ${x:y} vs. "${x:y}"
if isinstance(value, str) and VARIABLE_RE.search(value):
return value
if isinstance(value, str) and value.replace(".", "", 1).isdigit():
# Work around values that are strings but numbers
value = f'"{value}"'
try:
return srsly.json_dumps(value)
except Exception as e:
err_msg = (
f"Couldn't serialize config value of type {type(value)}: {e}. Make "
f"sure all values in your config are JSON-serializable. If you want "
f"to include Python objects, use a registered function that returns "
f"the object instead."
)
raise ConfigValidationError(config=data, desc=err_msg) from e
def deep_merge_configs(
config: Union[Dict[str, Any], Config],
defaults: Union[Dict[str, Any], Config],
*,
remove_extra: bool = False,
) -> Union[Dict[str, Any], Config]:
"""Deep merge two configs."""
if remove_extra:
# Filter out values in the original config that are not in defaults
keys = list(config.keys())
for key in keys:
if key not in defaults:
del config[key]
for key, value in defaults.items():
if isinstance(value, dict):
node = config.setdefault(key, {})
if not isinstance(node, dict):
continue
value_promises = [k for k in value if k.startswith("@")]
value_promise = value_promises[0] if value_promises else None
node_promises = [k for k in node if k.startswith("@")] if node else []
node_promise = node_promises[0] if node_promises else None
# We only update the block from defaults if it refers to the same
# registered function
if (
value_promise
and node_promise
and (
value_promise in node
and node[value_promise] != value[value_promise]
)
):
continue
if node_promise and (
node_promise not in value or node[node_promise] != value[node_promise]
):
continue
defaults = deep_merge_configs(node, value, remove_extra=remove_extra)
elif key not in config:
config[key] = value
return config
class ConfigValidationError(ValueError):
def __init__(
self,
*,
config: Optional[Union[Config, Dict[str, Dict[str, Any]], str]] = None,
errors: Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]] = tuple(),
title: Optional[str] = "Config validation error",
desc: Optional[str] = None,
parent: Optional[str] = None,
show_config: bool = True,
) -> None:
"""Custom error for validating configs.
config (Union[Config, Dict[str, Dict[str, Any]], str]): The
config the validation error refers to.
errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]):
A list of errors as dicts with keys "loc" (list of strings
describing the path of the value), "msg" (validation message
to show) and optional "type" (mostly internals).
Same format as produced by pydantic's validation error (e.errors()).
title (str): The error title.
desc (str): Optional error description, displayed below the title.
parent (str): Optional parent to use as prefix for all error locations.
For example, parent "element" will result in "element -> a -> b".
show_config (bool): Whether to print the whole config with the error.
ATTRIBUTES:
config (Union[Config, Dict[str, Dict[str, Any]], str]): The config.
errors (Iterable[Dict[str, Any]]): The errors.
error_types (Set[str]): All "type" values defined in the errors, if
available. This is most relevant for the pydantic errors that define
types like "type_error.integer". This attribute makes it easy to
check if a config validation error includes errors of a certain
type, e.g. to log additional information or custom help messages.
title (str): The title.
desc (str): The description.
parent (str): The parent.
show_config (bool): Whether to show the config.
text (str): The formatted error text.
"""
self.config = config
self.errors = errors
self.title = title
self.desc = desc
self.parent = parent
self.show_config = show_config
self.error_types = set()
for error in self.errors:
err_type = error.get("type")
if err_type:
self.error_types.add(err_type)
self.text = self._format()
ValueError.__init__(self, self.text)
@classmethod
def from_error(
cls,
err: "ConfigValidationError",
title: Optional[str] = None,
desc: Optional[str] = None,
parent: Optional[str] = None,
show_config: Optional[bool] = None,
) -> "ConfigValidationError":
"""Create a new ConfigValidationError based on an existing error, e.g.
to re-raise it with different settings. If no overrides are provided,
the values from the original error are used.
err (ConfigValidationError): The original error.
title (str): Overwrite error title.
desc (str): Overwrite error description.
parent (str): Overwrite error parent.
show_config (bool): Overwrite whether to show config.
RETURNS (ConfigValidationError): The new error.
"""
return cls(
config=err.config,
errors=err.errors,
title=title if title is not None else err.title,
desc=desc if desc is not None else err.desc,
parent=parent if parent is not None else err.parent,
show_config=show_config if show_config is not None else err.show_config,
)
def _format(self) -> str:
"""Format the error message."""
loc_divider = "->"
data = []
for error in self.errors:
err_loc = f" {loc_divider} ".join([str(p) for p in error.get("loc", [])])
if self.parent:
err_loc = f"{self.parent} {loc_divider} {err_loc}"
data.append((err_loc, error.get("msg")))
result = []
if self.title:
result.append(self.title)
if self.desc:
result.append(self.desc)
if data:
result.append("\n".join([f"{entry[0]}\t{entry[1]}" for entry in data]))
if self.config and self.show_config:
result.append(f"{self.config}")
return "\n\n" + "\n".join(result)
def alias_generator(name: str) -> str:
"""Generate field aliases in promise schema."""
# Underscore fields are not allowed in model, so use alias
if name == ARGS_FIELD_ALIAS:
return ARGS_FIELD
# Auto-alias fields that shadow base model attributes
if name in RESERVED_FIELDS:
return RESERVED_FIELDS[name]
return name
def copy_model_field(field: ModelField, type_: Any) -> ModelField:
"""Copy a model field and assign a new type, e.g. to accept an Any type
even though the original value is typed differently.
"""
return ModelField(
name=field.name,
type_=type_,
class_validators=field.class_validators,
model_config=field.model_config,
default=field.default,
default_factory=field.default_factory,
required=field.required,
)
class EmptySchema(BaseModel):
class Config:
extra = "allow"
arbitrary_types_allowed = True
class _PromiseSchemaConfig:
extra = "forbid"
arbitrary_types_allowed = True
alias_generator = alias_generator
@dataclass
class Promise:
registry: str
name: str
args: List[str]
kwargs: Dict[str, Any]
class catalogue_registry(object):
@classmethod
def create(cls, registry_name: str, entry_points: bool = False) -> None:
"""Create a new custom registry."""
if hasattr(cls, registry_name):
raise ValueError(f"Registry '{registry_name}' already exists")
reg: Decorator = catalogue.registry.create(
"catalogue", registry_name, entry_points=entry_points
)
setattr(cls, registry_name, reg)
@classmethod
def has(cls, registry_name: str, func_name: str) -> bool:
"""Check whether a function is available in a registry."""
if not hasattr(cls, registry_name):
return False
reg = getattr(cls, registry_name)
return func_name in reg
@classmethod
def get(cls, registry_name: str, func_name: str) -> Callable:
"""Get a registered function from a given registry."""
if not hasattr(cls, registry_name):
raise ValueError(f"Unknown registry: '{registry_name}'")
reg = getattr(cls, registry_name)
func = reg.get(func_name)
if func is None:
raise ValueError(f"Could not find '{func_name}' in '{registry_name}'")
return func
@classmethod
def resolve(
cls,
config: Union[Config, Dict[str, Dict[str, Any]]],
*,
schema: Type[BaseModel] = EmptySchema,
overrides: Dict[str, Any] = {},
validate: bool = True,
) -> Dict[str, Any]:
resolved, _ = cls._make(
config, schema=schema, overrides=overrides, validate=validate, resolve=True
)
return resolved
@classmethod
def fill(
cls,
config: Union[Config, Dict[str, Dict[str, Any]]],
*,
schema: Type[BaseModel] = EmptySchema,
overrides: Dict[str, Any] = {},
validate: bool = True,
):
_, filled = cls._make(
config, schema=schema, overrides=overrides, validate=validate, resolve=False
)
return filled
@classmethod
def _make(
cls,
config: Union[Config, Dict[str, Dict[str, Any]]],
*,
schema: Type[BaseModel] = EmptySchema,
overrides: Dict[str, Any] = {},
resolve: bool = True,
validate: bool = True,
) -> Tuple[Dict[str, Any], Config]:
"""Unpack a config dictionary and create two versions of the config:
a resolved version with objects from the registry created recursively,
and a filled version with all references to registry functions left
intact, but filled with all values and defaults based on the type
annotations. If validate=True, the config will be validated against the
type annotations of the registered functions referenced in the config
(if available) and/or the schema (if available).
"""
# Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}}
# Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0}
if cls.is_promise(config):
err_msg = "The top-level config object can't be a reference to a registered function."
raise ConfigValidationError(config=config, errors=[{"msg": err_msg}])
# If a Config was loaded with interpolate=False, we assume it needs to
# be interpolated first, otherwise we take it at face value
is_interpolated = not isinstance(config, Config) or config.is_interpolated
section_order = config.section_order if isinstance(config, Config) else None
orig_config = config
if not is_interpolated:
config = Config(orig_config).interpolate()
filled, _, resolved = cls._fill(
config, schema, validate=validate, overrides=overrides, resolve=resolve
)
filled = Config(filled, section_order=section_order)
# Check that overrides didn't include invalid properties not in config
if validate:
cls._validate_overrides(filled, overrides)
# Merge the original config back to preserve variables if we started
# with a config that wasn't interpolated. Here, we prefer variables to
# allow auto-filling a non-interpolated config without destroying
# variable references.
if not is_interpolated:
filled = filled.merge(
Config(orig_config, is_interpolated=False), remove_extra=True
)
return dict(resolved), filled
@classmethod
def _fill(
cls,
config: Union[Config, Dict[str, Dict[str, Any]]],
schema: Type[BaseModel] = EmptySchema,
*,
validate: bool = True,
resolve: bool = True,
parent: str = "",
overrides: Dict[str, Dict[str, Any]] = {},
) -> Tuple[
Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any]
]:
"""Build three representations of the config:
1. All promises are preserved (just like config user would provide).
2. Promises are replaced by their return values. This is the validation
copy and will be parsed by pydantic. It lets us include hacks to
work around problems (e.g. handling of generators).
3. Final copy with promises replaced by their return values.
"""
filled: Dict[str, Any] = {}
validation: Dict[str, Any] = {}
final: Dict[str, Any] = {}
for key, value in config.items():
# If the field name is reserved, we use its alias for validation
v_key = RESERVED_FIELDS.get(key, key)
key_parent = f"{parent}.{key}".strip(".")
if key_parent in overrides:
value = overrides[key_parent]
config[key] = value
if cls.is_promise(value):
if key in schema.__fields__ and not resolve:
# If we're not resolving the config, make sure that the field
# expecting the promise is typed Any so it doesn't fail
# validation if it doesn't receive the function return value
field = schema.__fields__[key]
schema.__fields__[key] = copy_model_field(field, Any)
promise_schema = cls.make_promise_schema(value, resolve=resolve)
filled[key], validation[v_key], final[key] = cls._fill(
value,
promise_schema,
validate=validate,
resolve=resolve,
parent=key_parent,
overrides=overrides,
)
reg_name, func_name = cls.get_constructor(final[key])
args, kwargs = cls.parse_args(final[key])
if resolve:
# Call the function and populate the field value. We can't
# just create an instance of the type here, since this
# wouldn't work for generics / more complex custom types
getter = cls.get(reg_name, func_name)
# We don't want to try/except this and raise our own error
# here, because we want the traceback if the function fails.
getter_result = getter(*args, **kwargs)
else:
# We're not resolving and calling the function, so replace
# the getter_result with a Promise class
getter_result = Promise(
registry=reg_name, name=func_name, args=args, kwargs=kwargs
)
validation[v_key] = getter_result
final[key] = getter_result
if isinstance(validation[v_key], GeneratorType):
# If value is a generator we can't validate type without
# consuming it (which doesn't work if it's infinite – see
# schedule for examples). So we skip it.
validation[v_key] = []
elif hasattr(value, "items"):
field_type = EmptySchema
if key in schema.__fields__:
field = schema.__fields__[key]
field_type = field.type_
if not isinstance(field.type_, ModelMetaclass):
# If we don't have a pydantic schema and just a type
field_type = EmptySchema
filled[key], validation[v_key], final[key] = cls._fill(
value,
field_type,
validate=validate,
resolve=resolve,
parent=key_parent,
overrides=overrides,
)
if key == ARGS_FIELD and isinstance(validation[v_key], dict):
# If the value of variable positional args is a dict (e.g.
# created via config blocks), only use its values
validation[v_key] = list(validation[v_key].values())
final[key] = list(final[key].values())
else:
filled[key] = value
# Prevent pydantic from consuming generator if part of a union
validation[v_key] = (
value if not isinstance(value, GeneratorType) else []
)
final[key] = value
# Now that we've filled in all of the promises, update with defaults
# from schema, and validate if validation is enabled
exclude = []
if validate:
try:
result = schema.parse_obj(validation)
except ValidationError as e:
raise ConfigValidationError(
config=config, errors=e.errors(), parent=parent
) from None
else:
# Same as parse_obj, but without validation
result = schema.construct(**validation)
# If our schema doesn't allow extra values, we need to filter them
# manually because .construct doesn't parse anything
if schema.Config.extra in (Extra.forbid, Extra.ignore):
fields = schema.__fields__.keys()
exclude = [k for k in result.__fields_set__ if k not in fields]
exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
validation.update(result.dict(exclude=exclude_validation))
filled, final = cls._update_from_parsed(validation, filled, final)
if exclude:
filled = {k: v for k, v in filled.items() if k not in exclude}
validation = {k: v for k, v in validation.items() if k not in exclude}
final = {k: v for k, v in final.items() if k not in exclude}
return filled, validation, final
@classmethod
def _update_from_parsed(
cls, validation: Dict[str, Any], filled: Dict[str, Any], final: Dict[str, Any]
):
"""Update the final result with the parsed config like converted
values recursively.
"""
for key, value in validation.items():
if key in RESERVED_FIELDS.values():
continue # skip aliases for reserved fields
if key not in filled:
filled[key] = value
if key not in final:
final[key] = value
if isinstance(value, dict):
filled[key], final[key] = cls._update_from_parsed(
value, filled[key], final[key]
)
# Update final config with parsed value if they're not equal (in
# value and in type) but not if it's a generator because we had to
# replace that to validate it correctly
elif key == ARGS_FIELD:
continue # don't substitute if list of positional args
elif isinstance(value, numpy.ndarray): # check numpy first, just in case
final[key] = value
elif (
value != final[key] or not isinstance(type(value), type(final[key]))
) and not isinstance(final[key], GeneratorType):
final[key] = value
return filled, final
@classmethod
def _validate_overrides(cls, filled: Config, overrides: Dict[str, Any]):
"""Validate overrides against a filled config to make sure there are
no references to properties that don't exist and weren't used."""
error_msg = "Invalid override: config value doesn't exist"
errors = []
for override_key in overrides.keys():
if not cls._is_in_config(override_key, filled):
errors.append({"msg": error_msg, "loc": [override_key]})
if errors:
raise ConfigValidationError(config=filled, errors=errors)
@classmethod
def _is_in_config(cls, prop: str, config: Union[Dict[str, Any], Config]):
"""Check whether a nested config property like "section.subsection.key"
is in a given config."""
tree = prop.split(".")
obj = dict(config)
while tree:
key = tree.pop(0)
if isinstance(obj, dict) and key in obj:
obj = obj[key]
else:
return False
return True
@classmethod
def is_promise(cls, obj: Any) -> bool:
"""Check whether an object is a "promise", i.e. contains a reference
to a registered function (via a key starting with `"@"`.
"""
if not hasattr(obj, "keys"):
return False
id_keys = [k for k in obj.keys() if k.startswith("@")]
if len(id_keys):
return True
return False
@classmethod
def get_constructor(cls, obj: Dict[str, Any]) -> Tuple[str, str]:
id_keys = [k for k in obj.keys() if k.startswith("@")]
if len(id_keys) != 1:
err_msg = f"A block can only contain one function registry reference. Got: {id_keys}"
raise ConfigValidationError(config=obj, errors=[{"msg": err_msg}])
else:
key = id_keys[0]
value = obj[key]
return (key[1:], value)
@classmethod
def parse_args(cls, obj: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]:
args = []
kwargs = {}
for key, value in obj.items():
if not key.startswith("@"):
if key == ARGS_FIELD:
args = value
elif key in RESERVED_FIELDS.values():
continue
else:
kwargs[key] = value
return args, kwargs
@classmethod
def make_promise_schema(
cls, obj: Dict[str, Any], *, resolve: bool = True
) -> Type[BaseModel]:
"""Create a schema for a promise dict (referencing a registry function)
by inspecting the function signature.
"""
reg_name, func_name = cls.get_constructor(obj)
if not resolve and not cls.has(reg_name, func_name):
return EmptySchema
func = cls.get(reg_name, func_name)
# Read the argument annotations and defaults from the function signature
id_keys = [k for k in obj.keys() if k.startswith("@")]
sig_args: Dict[str, Any] = {id_keys[0]: (str, ...)}
for param in inspect.signature(func).parameters.values():
# If no annotation is specified assume it's anything
annotation = param.annotation if param.annotation != param.empty else Any
# If no default value is specified assume that it's required
default = param.default if param.default != param.empty else ...
# Handle spread arguments and use their annotation as Sequence[whatever]
if param.kind == param.VAR_POSITIONAL:
spread_annot = Sequence[annotation] # type: ignore
sig_args[ARGS_FIELD_ALIAS] = (spread_annot, default)
else:
name = RESERVED_FIELDS.get(param.name, param.name)
sig_args[name] = (annotation, default)
sig_args["__config__"] = _PromiseSchemaConfig
return create_model("ArgModel", **sig_args)
__all__ = ["Config", "catalogue_registry", "ConfigValidationError"]
catalogue-2.1.0/catalogue/config/util.py 0000664 0000000 0000000 00000003016 14261040534 0020215 0 ustar 00root root 0000000 0000000 import contextlib
import functools
import shutil
import sys
import tempfile
from pathlib import Path
from typing import TypeVar, Callable, Any, Iterator
if sys.version_info < (3, 8):
# Ignoring type for mypy to avoid "Incompatible import" error (https://github.com/python/mypy/issues/4427).
from typing_extensions import Protocol # type: ignore
else:
from typing import Protocol
_DIn = TypeVar("_DIn")
class Decorator(Protocol):
"""Protocol to mark a function as returning its child with identical signature."""
def __call__(self, name: str) -> Callable[[_DIn], _DIn]: ...
# This is how functools.partials seems to do it, too, to retain the return type
PartialT = TypeVar("PartialT")
def partial(
func: Callable[..., PartialT], *args: Any, **kwargs: Any
) -> Callable[..., PartialT]:
"""Wrapper around functools.partial that retains docstrings and can include
other workarounds if needed.
"""
partial_func = functools.partial(func, *args, **kwargs)
partial_func.__doc__ = func.__doc__
return partial_func
class Generator(Iterator):
"""Custom generator type. Used to annotate function arguments that accept
generators so they can be validated by pydantic (which doesn't support
iterators/iterables otherwise).
"""
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
if not hasattr(v, "__iter__") and not hasattr(v, "__next__"):
raise TypeError("not a valid iterator")
return v
catalogue-2.1.0/catalogue/registry.py 0000664 0000000 0000000 00000021006 14261040534 0017642 0 ustar 00root root 0000000 0000000 from typing import Sequence, Any, Dict, Tuple, Callable, Optional, TypeVar, Union
import inspect
try: # Python 3.8
import importlib.metadata as importlib_metadata
except ImportError:
from . import _importlib_metadata as importlib_metadata # type: ignore
# Only ever call this once for performance reasons
AVAILABLE_ENTRY_POINTS = importlib_metadata.entry_points() # type: ignore
# This is where functions will be registered
REGISTRY: Dict[Tuple[str, ...], Any] = {}
InFunc = TypeVar("InFunc")
def create(*namespace: str, entry_points: bool = False) -> "Registry":
"""Create a new registry.
*namespace (str): The namespace, e.g. "spacy" or "spacy", "architectures".
entry_points (bool): Accept registered functions from entry points.
RETURNS (Registry): The Registry object.
"""
if check_exists(*namespace):
raise RegistryError(f"Namespace already exists: {namespace}")
return Registry(namespace, entry_points=entry_points)
class Registry(object):
def __init__(self, namespace: Sequence[str], entry_points: bool = False) -> None:
"""Initialize a new registry.
namespace (Sequence[str]): The namespace.
entry_points (bool): Whether to also check for entry points.
"""
self.namespace = namespace
self.entry_point_namespace = "_".join(namespace)
self.entry_points = entry_points
def __contains__(self, name: str) -> bool:
"""Check whether a name is in the registry.
name (str): The name to check.
RETURNS (bool): Whether the name is in the registry.
"""
namespace = tuple(list(self.namespace) + [name])
has_entry_point = self.entry_points and self.get_entry_point(name)
return has_entry_point or namespace in REGISTRY
def __call__(
self, name: str, func: Optional[Any] = None
) -> Callable[[InFunc], InFunc]:
"""Register a function for a given namespace. Same as Registry.register.
name (str): The name to register under the namespace.
func (Any): Optional function to register (if not used as decorator).
RETURNS (Callable): The decorator.
"""
return self.register(name, func=func)
def register(
self, name: str, *, func: Optional[Any] = None
) -> Callable[[InFunc], InFunc]:
"""Register a function for a given namespace.
name (str): The name to register under the namespace.
func (Any): Optional function to register (if not used as decorator).
RETURNS (Callable): The decorator.
"""
def do_registration(func):
_set(list(self.namespace) + [name], func)
return func
if func is not None:
return do_registration(func)
return do_registration
def get(self, name: str) -> Any:
"""Get the registered function for a given name.
name (str): The name.
RETURNS (Any): The registered function.
"""
if self.entry_points:
from_entry_point = self.get_entry_point(name)
if from_entry_point:
return from_entry_point
namespace = list(self.namespace) + [name]
if not check_exists(*namespace):
current_namespace = " -> ".join(self.namespace)
available = ", ".join(sorted(self.get_all().keys())) or "none"
raise RegistryError(
f"Cant't find '{name}' in registry {current_namespace}. Available names: {available}"
)
return _get(namespace)
def get_all(self) -> Dict[str, Any]:
"""Get a all functions for a given namespace.
namespace (Tuple[str]): The namespace to get.
RETURNS (Dict[str, Any]): The functions, keyed by name.
"""
global REGISTRY
result = {}
if self.entry_points:
result.update(self.get_entry_points())
for keys, value in REGISTRY.copy().items():
if len(self.namespace) == len(keys) - 1 and all(
self.namespace[i] == keys[i] for i in range(len(self.namespace))
):
result[keys[-1]] = value
return result
def get_entry_points(self) -> Dict[str, Any]:
"""Get registered entry points from other packages for this namespace.
RETURNS (Dict[str, Any]): Entry points, keyed by name.
"""
result = {}
for entry_point in AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, []):
result[entry_point.name] = entry_point.load()
return result
def get_entry_point(self, name: str, default: Optional[Any] = None) -> Any:
"""Check if registered entry point is available for a given name in the
namespace and load it. Otherwise, return the default value.
name (str): Name of entry point to load.
default (Any): The default value to return.
RETURNS (Any): The loaded entry point or the default value.
"""
for entry_point in AVAILABLE_ENTRY_POINTS.get(self.entry_point_namespace, []):
if entry_point.name == name:
return entry_point.load()
return default
def find(self, name: str) -> Dict[str, Optional[Union[str, int]]]:
"""Find the information about a registered function, including the
module and path to the file it's defined in, the line number and the
docstring, if available.
name (str): Name of the registered function.
RETURNS (Dict[str, Optional[Union[str, int]]]): The function info.
"""
func = self.get(name)
module = inspect.getmodule(func)
# These calls will fail for Cython modules so we need to work around them
line_no: Optional[int] = None
file_name: Optional[str] = None
try:
_, line_no = inspect.getsourcelines(func)
file_name = inspect.getfile(func)
except (TypeError, ValueError):
pass
docstring = inspect.getdoc(func)
return {
"module": module.__name__ if module else None,
"file": file_name,
"line_no": line_no,
"docstring": inspect.cleandoc(docstring) if docstring else None,
}
def check_exists(*namespace: str) -> bool:
"""Check if a namespace exists.
*namespace (str): The namespace.
RETURNS (bool): Whether the namespace exists.
"""
return namespace in REGISTRY
def _get(namespace: Sequence[str]) -> Any:
"""Get the value for a given namespace.
namespace (Sequence[str]): The namespace.
RETURNS (Any): The value for the namespace.
"""
global REGISTRY
if not all(isinstance(name, str) for name in namespace):
raise ValueError(
f"Invalid namespace. Expected tuple of strings, but got: {namespace}"
)
namespace = tuple(namespace)
if namespace not in REGISTRY:
raise RegistryError(f"Can't get namespace {namespace} (not in registry)")
return REGISTRY[namespace]
def _get_all(namespace: Sequence[str]) -> Dict[Tuple[str, ...], Any]:
"""Get all matches for a given namespace, e.g. ("a", "b", "c") and
("a", "b") for namespace ("a", "b").
namespace (Sequence[str]): The namespace.
RETURNS (Dict[Tuple[str], Any]): All entries for the namespace, keyed
by their full namespaces.
"""
global REGISTRY
result = {}
for keys, value in REGISTRY.copy().items():
if len(namespace) <= len(keys) and all(
namespace[i] == keys[i] for i in range(len(namespace))
):
result[keys] = value
return result
def _set(namespace: Sequence[str], func: Any) -> None:
"""Set a value for a given namespace.
namespace (Sequence[str]): The namespace.
func (Callable): The value to set.
"""
global REGISTRY
REGISTRY[tuple(namespace)] = func
def _remove(namespace: Sequence[str]) -> Any:
"""Remove a value for a given namespace.
namespace (Sequence[str]): The namespace.
RETURNS (Any): The removed value.
"""
global REGISTRY
namespace = tuple(namespace)
if namespace not in REGISTRY:
raise RegistryError(f"Can't get namespace {namespace} (not in registry)")
removed = REGISTRY[namespace]
del REGISTRY[namespace]
return removed
def _empty_registry() -> None:
""" Empties REGISTRY completely. """
global REGISTRY
REGISTRY = {}
class RegistryError(ValueError):
pass
__all__ = [
"AVAILABLE_ENTRY_POINTS",
"REGISTRY",
"create",
"Registry",
"check_exists",
"_get",
"_get_all",
"_set",
"_remove",
"_empty_registry",
"RegistryError",
"importlib_metadata"
]
catalogue-2.1.0/catalogue/tests/ 0000775 0000000 0000000 00000000000 14261040534 0016563 5 ustar 00root root 0000000 0000000 catalogue-2.1.0/catalogue/tests/__init__.py 0000664 0000000 0000000 00000000000 14261040534 0020662 0 ustar 00root root 0000000 0000000 catalogue-2.1.0/catalogue/tests/conftest.py 0000664 0000000 0000000 00000000750 14261040534 0020764 0 ustar 00root root 0000000 0000000 import pytest
def pytest_addoption(parser):
parser.addoption("--slow", action="store_true", help="include slow tests")
@pytest.fixture()
def pathy_fixture():
pytest.importorskip("pathy")
import tempfile
import shutil
from pathy import use_fs, Pathy
temp_folder = tempfile.mkdtemp(prefix="thinc-pathy")
use_fs(temp_folder)
root = Pathy("gs://test-bucket")
root.mkdir(exist_ok=True)
yield root
use_fs(False)
shutil.rmtree(temp_folder)
catalogue-2.1.0/catalogue/tests/test_catalogue.py 0000664 0000000 0000000 00000012170 14261040534 0022141 0 ustar 00root root 0000000 0000000 from typing import Dict, Tuple, Any
import pytest
import sys
from pathlib import Path
import catalogue
@pytest.fixture(autouse=True)
def cleanup():
# Don't delete entries needed for config tests.
for key in set(catalogue.REGISTRY.keys()) - set(filter_registry("config").keys()):
catalogue.REGISTRY.pop(key)
yield
def filter_registry(keep: str) -> Dict[Tuple[str, ...], Any]:
"""
Filters registry objects for tests.
test_mode (str): One of ("catalogue", "config"). Only entries in the registry belonging to the corresponding tests
will be returned.
RETURNS (Dict[Tuple[str], Any]): entries in registry without those added for config tests.
"""
assert keep in ("catalogue", "config")
return {
key: val for key, val in catalogue.REGISTRY.items()
if ("config_tests" in key) is (keep == "config")
}
def test_get_set():
catalogue._set(("a", "b", "c"), "test")
assert len(filter_registry("catalogue")) == 1
assert ("a", "b", "c") in catalogue.REGISTRY
assert catalogue.check_exists("a", "b", "c")
assert catalogue.REGISTRY[("a", "b", "c")] == "test"
assert catalogue._get(("a", "b", "c")) == "test"
with pytest.raises(catalogue.RegistryError):
catalogue._get(("a", "b", "d"))
with pytest.raises(catalogue.RegistryError):
catalogue._get(("a", "b", "c", "d"))
catalogue._set(("x", "y", "z1"), "test1")
catalogue._set(("x", "y", "z2"), "test2")
assert catalogue._remove(("a", "b", "c")) == "test"
catalogue._set(("x", "y2"), "test3")
with pytest.raises(catalogue.RegistryError):
catalogue._remove(("x", "y"))
assert catalogue._remove(("x", "y", "z2")) == "test2"
def test_registry_get_set():
test_registry = catalogue.create("test")
with pytest.raises(catalogue.RegistryError):
test_registry.get("foo")
test_registry.register("foo", func=lambda x: x)
assert "foo" in test_registry
def test_registry_call():
test_registry = catalogue.create("test")
test_registry("foo", func=lambda x: x)
assert "foo" in test_registry
def test_get_all():
catalogue._set(("a", "b", "c"), "test")
catalogue._set(("a", "b", "d"), "test")
catalogue._set(("a", "b"), "test")
catalogue._set(("b", "a"), "test")
all_items = catalogue._get_all(("a", "b"))
assert len(all_items) == 3
assert ("a", "b", "c") in all_items
assert ("a", "b", "d") in all_items
assert ("a", "b") in all_items
all_items = catalogue._get_all(("a", "b", "c"))
assert len(all_items) == 1
assert ("a", "b", "c") in all_items
assert len(catalogue._get_all(("a", "b", "c", "d"))) == 0
def test_create_single_namespace():
assert filter_registry("catalogue") == {}
test_registry = catalogue.create("test")
@test_registry.register("a")
def a():
pass
def b():
pass
test_registry.register("b", func=b)
items = test_registry.get_all()
assert len(items) == 2
assert items["a"] == a
assert items["b"] == b
assert catalogue.check_exists("test", "a")
assert catalogue.check_exists("test", "b")
assert catalogue._get(("test", "a")) == a
assert catalogue._get(("test", "b")) == b
with pytest.raises(TypeError):
# The decorator only accepts one argument
@test_registry.register("x", "y")
def x():
pass
def test_create_multi_namespace():
test_registry = catalogue.create("x", "y")
@test_registry.register("z")
def z():
pass
items = test_registry.get_all()
assert len(items) == 1
assert items["z"] == z
assert catalogue.check_exists("x", "y", "z")
assert catalogue._get(("x", "y", "z")) == z
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="Test is not yet updated for 3.10 importlib_metadata API")
def test_entry_points():
# Create a new EntryPoint object by pretending we have a setup.cfg and
# use one of catalogue's util functions as the advertised function
ep_string = "[options.entry_points]test_foo\n bar = catalogue.registry:check_exists"
ep = catalogue.importlib_metadata.EntryPoint._from_text(ep_string)
catalogue.AVAILABLE_ENTRY_POINTS["test_foo"] = ep
assert filter_registry("catalogue") == {}
test_registry = catalogue.create("test", "foo", entry_points=True)
entry_points = test_registry.get_entry_points()
assert "bar" in entry_points
assert entry_points["bar"] == catalogue.check_exists
assert test_registry.get_entry_point("bar") == catalogue.check_exists
assert filter_registry("catalogue") == {}
assert test_registry.get("bar") == catalogue.check_exists
assert test_registry.get_all() == {"bar": catalogue.check_exists}
assert "bar" in test_registry
def test_registry_find():
test_registry = catalogue.create("test_registry_find")
name = "a"
@test_registry.register(name)
def a():
"""This is a registered function."""
pass
info = test_registry.find(name)
assert info["module"] == "catalogue.tests.test_catalogue"
assert info["file"] == str(Path(__file__))
assert info["docstring"] == "This is a registered function."
assert info["line_no"]
catalogue-2.1.0/catalogue/tests/test_config.py 0000664 0000000 0000000 00000145365 14261040534 0021457 0 ustar 00root root 0000000 0000000 import inspect
import pytest
from typing import Dict, Optional, Iterable, Callable, Any, Union
from types import GeneratorType
import pickle
try:
import numpy
has_numpy = True
except ImportError:
has_numpy = False
from pydantic import BaseModel, StrictFloat, PositiveInt, constr
from pydantic.types import StrictBool
from catalogue import ConfigValidationError, Config
from catalogue.config.util import Generator, partial
from catalogue.tests.util import Cat, my_registry, make_tempdir
EXAMPLE_CONFIG = """
[optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
use_averages = true
[optimizer.learn_rate]
@schedules = "warmup_linear.v1"
initial_rate = 0.1
warmup_steps = 10000
total_steps = 100000
[pipeline]
[pipeline.classifier]
name = "classifier"
factory = "classifier"
[pipeline.classifier.model]
@layers = "ClassifierModel.v1"
hidden_depth = 1
hidden_width = 64
token_vector_width = 128
[pipeline.classifier.model.embedding]
@layers = "Embedding.v1"
width = ${pipeline.classifier.model:token_vector_width}
"""
OPTIMIZER_CFG = """
[optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
use_averages = true
[optimizer.learn_rate]
@schedules = "warmup_linear.v1"
initial_rate = 0.1
warmup_steps = 10000
total_steps = 100000
"""
class HelloIntsSchema(BaseModel):
hello: int
world: int
class Config:
extra = "forbid"
class DefaultsSchema(BaseModel):
required: int
optional: str = "default value"
class Config:
extra = "forbid"
class ComplexSchema(BaseModel):
outer_req: int
outer_opt: str = "default value"
level2_req: HelloIntsSchema
level2_opt: DefaultsSchema = DefaultsSchema(required=1)
good_catsie = {"@cats": "catsie.v1", "evil": False, "cute": True}
ok_catsie = {"@cats": "catsie.v1", "evil": False, "cute": False}
bad_catsie = {"@cats": "catsie.v1", "evil": True, "cute": True}
worst_catsie = {"@cats": "catsie.v1", "evil": True, "cute": False}
def test_validate_simple_config():
simple_config = {"hello": 1, "world": 2}
f, _, v = my_registry._fill(simple_config, HelloIntsSchema)
assert f == simple_config
assert v == simple_config
def test_invalidate_simple_config():
invalid_config = {"hello": 1, "world": "hi!"}
with pytest.raises(ConfigValidationError) as exc_info:
my_registry._fill(invalid_config, HelloIntsSchema)
error = exc_info.value
assert len(error.errors) == 1
assert "type_error.integer" in error.error_types
def test_invalidate_extra_args():
invalid_config = {"hello": 1, "world": 2, "extra": 3}
with pytest.raises(ConfigValidationError):
my_registry._fill(invalid_config, HelloIntsSchema)
def test_fill_defaults_simple_config():
valid_config = {"required": 1}
filled, _, v = my_registry._fill(valid_config, DefaultsSchema)
assert filled["required"] == 1
assert filled["optional"] == "default value"
invalid_config = {"optional": "some value"}
with pytest.raises(ConfigValidationError):
my_registry._fill(invalid_config, DefaultsSchema)
def test_fill_recursive_config():
valid_config = {"outer_req": 1, "level2_req": {"hello": 4, "world": 7}}
filled, _, validation = my_registry._fill(valid_config, ComplexSchema)
assert filled["outer_req"] == 1
assert filled["outer_opt"] == "default value"
assert filled["level2_req"]["hello"] == 4
assert filled["level2_req"]["world"] == 7
assert filled["level2_opt"]["required"] == 1
assert filled["level2_opt"]["optional"] == "default value"
def test_is_promise():
assert my_registry.is_promise(good_catsie)
assert not my_registry.is_promise({"hello": "world"})
assert not my_registry.is_promise(1)
invalid = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"}
assert my_registry.is_promise(invalid)
def test_get_constructor():
assert my_registry.get_constructor(good_catsie) == ("cats", "catsie.v1")
def test_parse_args():
args, kwargs = my_registry.parse_args(bad_catsie)
assert args == []
assert kwargs == {"evil": True, "cute": True}
def test_make_promise_schema():
schema = my_registry.make_promise_schema(good_catsie)
assert "evil" in schema.__fields__
assert "cute" in schema.__fields__
def test_validate_promise():
config = {"required": 1, "optional": good_catsie}
filled, _, validated = my_registry._fill(config, DefaultsSchema)
assert filled == config
assert validated == {"required": 1, "optional": "meow"}
def test_fill_validate_promise():
config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}}
filled, _, validated = my_registry._fill(config, DefaultsSchema)
assert filled["optional"]["cute"] is True
def test_fill_invalidate_promise():
config = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}}
with pytest.raises(ConfigValidationError):
my_registry._fill(config, HelloIntsSchema)
config["optional"]["whiskers"] = True
with pytest.raises(ConfigValidationError):
my_registry._fill(config, DefaultsSchema)
def test_create_registry():
with pytest.raises(ValueError):
my_registry.create("cats")
my_registry.create("dogs")
assert hasattr(my_registry, "dogs")
assert len(my_registry.dogs.get_all()) == 0
my_registry.dogs.register("good_boy.v1", func=lambda x: x)
assert len(my_registry.dogs.get_all()) == 1
with pytest.raises(ValueError):
my_registry.create("dogs")
def test_registry_methods():
with pytest.raises(ValueError):
my_registry.get("dfkoofkds", "catsie.v1")
my_registry.cats.register("catsie.v123")(None)
with pytest.raises(ValueError):
my_registry.get("cats", "catsie.v123")
def test_resolve_no_schema():
config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
result = my_registry.resolve({"cfg": config})["cfg"]
assert result["one"] == 1
assert result["two"] == {"three": "scratch!"}
with pytest.raises(ConfigValidationError):
config = {"two": {"three": {"@cats": "catsie.v1", "evil": "true"}}}
my_registry.resolve(config)
def test_resolve_schema():
class TestBaseSubSchema(BaseModel):
three: str
class TestBaseSchema(BaseModel):
one: PositiveInt
two: TestBaseSubSchema
class Config:
extra = "forbid"
class TestSchema(BaseModel):
cfg: TestBaseSchema
config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
my_registry.resolve({"cfg": config}, schema=TestSchema)
config = {"one": -1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
with pytest.raises(ConfigValidationError):
# "one" is not a positive int
my_registry.resolve({"cfg": config}, schema=TestSchema)
config = {"one": 1, "two": {"four": {"@cats": "catsie.v1", "evil": True}}}
with pytest.raises(ConfigValidationError):
# "three" is required in subschema
my_registry.resolve({"cfg": config}, schema=TestSchema)
def test_resolve_schema_coerced():
class TestBaseSchema(BaseModel):
test1: str
test2: bool
test3: float
class TestSchema(BaseModel):
cfg: TestBaseSchema
config = {"test1": 123, "test2": 1, "test3": 5}
filled = my_registry.fill({"cfg": config}, schema=TestSchema)
result = my_registry.resolve({"cfg": config}, schema=TestSchema)
assert result["cfg"] == {"test1": "123", "test2": True, "test3": 5.0}
# This only affects the resolved config, not the filled config
assert filled["cfg"] == config
def test_read_config():
byte_string = EXAMPLE_CONFIG.encode("utf8")
cfg = Config().from_bytes(byte_string)
assert cfg["optimizer"]["beta1"] == 0.9
assert cfg["optimizer"]["learn_rate"]["initial_rate"] == 0.1
assert cfg["pipeline"]["classifier"]["factory"] == "classifier"
assert cfg["pipeline"]["classifier"]["model"]["embedding"]["width"] == 128
def test_optimizer_config():
cfg = Config().from_str(OPTIMIZER_CFG)
optimizer = my_registry.resolve(cfg, validate=True)["optimizer"]
assert optimizer.beta1 == 0.9
def test_config_to_str():
cfg = Config().from_str(OPTIMIZER_CFG)
assert cfg.to_str().strip() == OPTIMIZER_CFG.strip()
cfg = Config({"optimizer": {"foo": "bar"}}).from_str(OPTIMIZER_CFG)
assert cfg.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_to_str_creates_intermediate_blocks():
cfg = Config({"optimizer": {"foo": {"bar": 1}}})
assert (
cfg.to_str().strip()
== """
[optimizer]
[optimizer.foo]
bar = 1
""".strip()
)
def test_config_roundtrip_bytes():
cfg = Config().from_str(OPTIMIZER_CFG)
cfg_bytes = cfg.to_bytes()
new_cfg = Config().from_bytes(cfg_bytes)
assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_roundtrip_disk():
cfg = Config().from_str(OPTIMIZER_CFG)
with make_tempdir() as path:
cfg_path = path / "config.cfg"
cfg.to_disk(cfg_path)
new_cfg = Config().from_disk(cfg_path)
assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture):
cfg = Config().from_str(OPTIMIZER_CFG)
cfg_path = pathy_fixture / "config.cfg"
cfg.to_disk(cfg_path)
new_cfg = Config().from_disk(cfg_path)
assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_to_str_invalid_defaults():
"""Test that an error is raised if a config contains top-level keys without
a section that would otherwise be interpreted as [DEFAULT] (which causes
the values to be included in *all* other sections).
"""
cfg = {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}
with pytest.raises(ConfigValidationError):
Config(cfg).to_str()
config_str = "[DEFAULT]\none = 1"
with pytest.raises(ConfigValidationError):
Config().from_str(config_str)
def test_validation_custom_types():
def complex_args(
rate: StrictFloat,
steps: PositiveInt = 10, # type: ignore
log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR",
):
return None
my_registry.create("complex")
my_registry.complex("complex.v1")(complex_args)
cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"}
my_registry.resolve({"config": cfg})
cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": -1, "log_level": "INFO"}
with pytest.raises(ConfigValidationError):
# steps is not a positive int
my_registry.resolve({"config": cfg})
cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "none"}
with pytest.raises(ConfigValidationError):
# log_level is not a string matching the regex
my_registry.resolve({"config": cfg})
cfg = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"}
with pytest.raises(ConfigValidationError):
# top-level object is promise
my_registry.resolve(cfg)
with pytest.raises(ConfigValidationError):
# top-level object is promise
my_registry.fill(cfg)
cfg = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"}
with pytest.raises(ConfigValidationError):
# two constructors
my_registry.resolve({"config": cfg})
def test_validation_no_validate():
config = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": "false"}}}
result = my_registry.resolve({"cfg": config}, validate=False)
filled = my_registry.fill({"cfg": config}, validate=False)
assert result["cfg"]["one"] == 1
assert result["cfg"]["two"] == {"three": "scratch!"}
assert filled["cfg"]["two"]["three"]["evil"] == "false"
assert filled["cfg"]["two"]["three"]["cute"] is True
def test_validation_fill_defaults():
config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}}
result = my_registry.fill(config, validate=False)
assert len(result["cfg"]["two"]) == 3
with pytest.raises(ConfigValidationError):
# Required arg "evil" is not defined
my_registry.fill(config)
config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v2", "evil": False}}}
# Fill in with new defaults
result = my_registry.fill(config)
assert len(result["cfg"]["two"]) == 4
assert result["cfg"]["two"]["evil"] is False
assert result["cfg"]["two"]["cute"] is True
assert result["cfg"]["two"]["cute_level"] == 1
def test_make_config_positional_args():
@my_registry.cats("catsie.v567")
def catsie_567(*args: Optional[str], foo: str = "bar"):
assert args[0] == "^_^"
assert args[1] == "^(*.*)^"
assert foo == "baz"
return args[0]
args = ["^_^", "^(*.*)^"]
cfg = {"config": {"@cats": "catsie.v567", "foo": "baz", "*": args}}
assert my_registry.resolve(cfg)["config"] == "^_^"
def test_make_config_positional_args_complex():
@my_registry.cats("catsie.v890")
def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]):
assert args[0] == 123
return args[0]
cfg = {"config": {"@cats": "catsie.v890", "*": [123, True, 1, False]}}
assert my_registry.resolve(cfg)["config"] == 123
cfg = {"config": {"@cats": "catsie.v890", "*": [123, "True"]}}
with pytest.raises(ConfigValidationError):
# "True" is not a valid boolean or positive int
my_registry.resolve(cfg)
def test_positional_args_to_from_string():
cfg = """[a]\nb = 1\n* = ["foo","bar"]"""
assert Config().from_str(cfg).to_str() == cfg
cfg = """[a]\nb = 1\n\n[a.*.bar]\ntest = 2\n\n[a.*.foo]\ntest = 1"""
assert Config().from_str(cfg).to_str() == cfg
@my_registry.cats("catsie.v666")
def catsie_666(*args, meow=False):
return args
cfg = """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]"""
filled = my_registry.fill(Config().from_str(cfg)).to_str()
assert filled == """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]\nmeow = false"""
resolved = my_registry.resolve(Config().from_str(cfg))
assert resolved == {"a": ("foo", "bar")}
cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\nx = 1"""
filled = my_registry.fill(Config().from_str(cfg)).to_str()
assert filled == """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\nx = 1"""
resolved = my_registry.resolve(Config().from_str(cfg))
assert resolved == {"a": ({"x": 1},)}
@my_registry.cats("catsie.v777")
def catsie_777(y: int = 1):
return "meow" * y
cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777\""""
filled = my_registry.fill(Config().from_str(cfg)).to_str()
expected = """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 1"""
assert filled == expected
cfg = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 3"""
result = my_registry.resolve(Config().from_str(cfg))
assert result == {"a": ("meowmeowmeow",)}
def test_validation_generators_iterable():
@my_registry.optimizers("test_optimizer.v1")
def test_optimizer_v1(rate: float) -> None:
return None
@my_registry.schedules("test_schedule.v1")
def test_schedule_v1(some_value: float = 1.0) -> Iterable[float]:
while True:
yield some_value
config = {"optimizer": {"@optimizers": "test_optimizer.v1", "rate": 0.1}}
my_registry.resolve(config)
def test_validation_unset_type_hints():
"""Test that unset type hints are handled correctly (and treated as Any)."""
@my_registry.optimizers("test_optimizer.v2")
def test_optimizer_v2(rate, steps: int = 10) -> None:
return None
config = {"test": {"@optimizers": "test_optimizer.v2", "rate": 0.1, "steps": 20}}
my_registry.resolve(config)
def test_validation_bad_function():
@my_registry.optimizers("bad.v1")
def bad() -> None:
raise ValueError("This is an error in the function")
return None
@my_registry.optimizers("good.v1")
def good() -> None:
return None
# Bad function
config = {"test": {"@optimizers": "bad.v1"}}
with pytest.raises(ValueError):
my_registry.resolve(config)
# Bad function call
config = {"test": {"@optimizers": "good.v1", "invalid_arg": 1}}
with pytest.raises(ConfigValidationError):
my_registry.resolve(config)
def test_objects_from_config():
config = {
"optimizer": {
"@optimizers": "my_cool_optimizer.v1",
"beta1": 0.2,
"learn_rate": {
"@schedules": "my_cool_repetitive_schedule.v1",
"base_rate": 0.001,
"repeat": 4,
},
}
}
optimizer = my_registry.resolve(config)["optimizer"]
assert optimizer.beta1 == 0.2
assert optimizer.learn_rate == [0.001] * 4
@pytest.mark.skipif(not has_numpy, reason="needs numpy")
def test_partials_from_config():
"""Test that functions registered with partial applications are handled
correctly (e.g. initializers)."""
name = "uniform_init.v1"
cfg = {"test": {"@initializers": name, "lo": -0.2}}
func = my_registry.resolve(cfg)["test"]
assert hasattr(func, "__call__")
# The partial will still have lo as an arg, just with default
assert len(inspect.signature(func).parameters) == 3
# Make sure returned partial function has correct value set
assert inspect.signature(func).parameters["lo"].default == -0.2
# Actually call the function and verify
assert numpy.asarray(func((2, 3))).shape == (2, 3)
# Make sure validation still works
bad_cfg = {"test": {"@initializers": name, "lo": [0.5]}}
with pytest.raises(ConfigValidationError):
my_registry.resolve(bad_cfg)
bad_cfg = {"test": {"@initializers": name, "lo": -0.2, "other": 10}}
with pytest.raises(ConfigValidationError):
my_registry.resolve(bad_cfg)
def test_partials_from_config_nested():
"""Test that partial functions are passed correctly to other registered
functions that consume them (e.g. initializers -> layers)."""
def test_initializer(a: int, b: int = 1) -> int:
return a * b
@my_registry.initializers("test_initializer.v1")
def configure_test_initializer(b: int = 1) -> Callable[[int], int]:
return partial(test_initializer, b=b)
@my_registry.layers("test_layer.v1")
def test_layer(init: Callable[[int], int], c: int = 1) -> Callable[[int], int]:
return lambda x: x + init(c)
cfg = {
"@layers": "test_layer.v1",
"c": 5,
"init": {"@initializers": "test_initializer.v1", "b": 10},
}
func = my_registry.resolve({"test": cfg})["test"]
assert func(1) == 51
assert func(100) == 150
def test_validate_generator():
"""Test that generator replacement for validation in config doesn't
actually replace the returned value."""
@my_registry.schedules("test_schedule.v2")
def test_schedule():
while True:
yield 10
cfg = {"@schedules": "test_schedule.v2"}
result = my_registry.resolve({"test": cfg})["test"]
assert isinstance(result, GeneratorType)
@my_registry.optimizers("test_optimizer.v2")
def test_optimizer2(rate: Generator) -> Generator:
return rate
cfg = {
"@optimizers": "test_optimizer.v2",
"rate": {"@schedules": "test_schedule.v2"},
}
result = my_registry.resolve({"test": cfg})["test"]
assert isinstance(result, GeneratorType)
@my_registry.optimizers("test_optimizer.v3")
def test_optimizer3(schedules: Dict[str, Generator]) -> Generator:
return schedules["rate"]
cfg = {
"@optimizers": "test_optimizer.v3",
"schedules": {"rate": {"@schedules": "test_schedule.v2"}},
}
result = my_registry.resolve({"test": cfg})["test"]
assert isinstance(result, GeneratorType)
@my_registry.optimizers("test_optimizer.v4")
def test_optimizer4(*schedules: Generator) -> Generator:
return schedules[0]
def test_handle_generic_type():
"""Test that validation can handle checks against arbitrary generic
types in function argument annotations."""
cfg = {"@cats": "generic_cat.v1", "cat": {"@cats": "int_cat.v1", "value_in": 3}}
cat = my_registry.resolve({"test": cfg})["test"]
assert isinstance(cat, Cat)
assert cat.value_in == 3
assert cat.value_out is None
assert cat.name == "generic_cat"
@pytest.mark.parametrize(
"cfg",
[
"[a]\nb = 1\nc = 2\n\n[a.c]\nd = 3",
"[a]\nb = 1\n\n[a.c]\nd = 2\n\n[a.c.d]\ne = 3",
],
)
def test_handle_error_duplicate_keys(cfg):
"""This would cause very cryptic error when interpreting config.
(TypeError: 'X' object does not support item assignment)
"""
with pytest.raises(ConfigValidationError):
Config().from_str(cfg)
@pytest.mark.parametrize(
"cfg,is_valid",
[("[a]\nb = 1\n\n[a.c]\nd = 3", True), ("[a]\nb = 1\n\n[A.c]\nd = 2", False)],
)
def test_cant_expand_undefined_block(cfg, is_valid):
"""Test that you can't expand a block that hasn't been created yet. This
comes up when you typo a name, and if we allow expansion of undefined blocks,
it's very hard to create good errors for those typos.
"""
if is_valid:
Config().from_str(cfg)
else:
with pytest.raises(ConfigValidationError):
Config().from_str(cfg)
def test_fill_config_overrides():
config = {
"cfg": {
"one": 1,
"two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}},
}
}
overrides = {"cfg.two.three.evil": False}
result = my_registry.fill(config, overrides=overrides, validate=True)
assert result["cfg"]["two"]["three"]["evil"] is False
# Test that promises can be overwritten as well
overrides = {"cfg.two.three": 3}
result = my_registry.fill(config, overrides=overrides, validate=True)
assert result["cfg"]["two"]["three"] == 3
# Test that value can be overwritten with promises and that the result is
# interpreted and filled correctly
overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}}
result = my_registry.fill(config, overrides=overrides)
assert result["cfg"]["two"] is None
assert result["cfg"]["one"]["@cats"] == "catsie.v1"
assert result["cfg"]["one"]["evil"] is False
assert result["cfg"]["one"]["cute"] is True
# Overwriting with wrong types should cause validation error
with pytest.raises(ConfigValidationError):
overrides = {"cfg.two.three.evil": 20}
my_registry.fill(config, overrides=overrides, validate=True)
# Overwriting with incomplete promises should cause validation error
with pytest.raises(ConfigValidationError):
overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}}
my_registry.fill(config, overrides=overrides)
# Overrides that don't match config should raise error
with pytest.raises(ConfigValidationError):
overrides = {"cfg.two.three.evil": False, "two.four": True}
my_registry.fill(config, overrides=overrides, validate=True)
with pytest.raises(ConfigValidationError):
overrides = {"cfg.five": False}
my_registry.fill(config, overrides=overrides, validate=True)
def test_resolve_overrides():
config = {
"cfg": {
"one": 1,
"two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}},
}
}
overrides = {"cfg.two.three.evil": False}
result = my_registry.resolve(config, overrides=overrides, validate=True)
assert result["cfg"]["two"]["three"] == "meow"
# Test that promises can be overwritten as well
overrides = {"cfg.two.three": 3}
result = my_registry.resolve(config, overrides=overrides, validate=True)
assert result["cfg"]["two"]["three"] == 3
# Test that value can be overwritten with promises
overrides = {"cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}}
result = my_registry.resolve(config, overrides=overrides)
assert result["cfg"]["one"] == "meow"
assert result["cfg"]["two"] is None
# Overwriting with wrong types should cause validation error
with pytest.raises(ConfigValidationError):
overrides = {"cfg.two.three.evil": 20}
my_registry.resolve(config, overrides=overrides, validate=True)
# Overwriting with incomplete promises should cause validation error
with pytest.raises(ConfigValidationError):
overrides = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}}
my_registry.resolve(config, overrides=overrides)
# Overrides that don't match config should raise error
with pytest.raises(ConfigValidationError):
overrides = {"cfg.two.three.evil": False, "cfg.two.four": True}
my_registry.resolve(config, overrides=overrides, validate=True)
with pytest.raises(ConfigValidationError):
overrides = {"cfg.five": False}
my_registry.resolve(config, overrides=overrides, validate=True)
@pytest.mark.parametrize(
"prop,expected",
[("a.b.c", True), ("a.b", True), ("a", True), ("a.e", True), ("a.b.c.d", False)],
)
def test_is_in_config(prop, expected):
config = {"a": {"b": {"c": 5, "d": 6}, "e": [1, 2]}}
assert my_registry._is_in_config(prop, config) is expected
def test_resolve_prefilled_values():
class Language(object):
def __init__(self):
...
@my_registry.optimizers("prefilled.v1")
def prefilled(nlp: Language, value: int = 10):
return (nlp, value)
# Passing an instance of Language here via the config is bad, since it
# won't serialize to a string, but we still test for it
config = {"test": {"@optimizers": "prefilled.v1", "nlp": Language(), "value": 50}}
resolved = my_registry.resolve(config, validate=True)
result = resolved["test"]
assert isinstance(result[0], Language)
assert result[1] == 50
def test_fill_config_dict_return_type():
"""Test that a registered function returning a dict is handled correctly."""
@my_registry.cats.register("catsie_with_dict.v1")
def catsie_with_dict(evil: StrictBool) -> Dict[str, bool]:
return {"not_evil": not evil}
config = {"test": {"@cats": "catsie_with_dict.v1", "evil": False}, "foo": 10}
result = my_registry.fill({"cfg": config}, validate=True)["cfg"]["test"]
assert result["evil"] is False
assert "not_evil" not in result
result = my_registry.resolve({"cfg": config}, validate=True)["cfg"]["test"]
assert result["not_evil"] is True
@pytest.mark.skipif(not has_numpy, reason="needs numpy")
def test_deepcopy_config():
config = Config({"a": 1, "b": {"c": 2, "d": 3}})
copied = config.copy()
# Same values but not same object
assert config == copied
assert config is not copied
# Check for error if value can't be pickled/deepcopied
config = Config({"a": 1, "b": numpy})
with pytest.raises(ValueError):
config.copy()
def test_config_to_str_simple_promises():
"""Test that references to function registries without arguments are
serialized inline as dict."""
config_str = """[section]\nsubsection = {"@registry":"value"}"""
config = Config().from_str(config_str)
assert config["section"]["subsection"]["@registry"] == "value"
assert config.to_str() == config_str
def test_config_from_str_invalid_section():
config_str = """[a]\nb = null\n\n[a.b]\nc = 1"""
with pytest.raises(ConfigValidationError):
Config().from_str(config_str)
config_str = """[a]\nb = null\n\n[a.b.c]\nd = 1"""
with pytest.raises(ConfigValidationError):
Config().from_str(config_str)
def test_config_to_str_order():
"""Test that Config.to_str orders the sections."""
config = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": {"g": {"h": {"i": 4, "j": 5}}}}
expected = (
"[a]\ne = 3\n\n[a.b]\nc = 1\nd = 2\n\n[f]\n\n[f.g]\n\n[f.g.h]\ni = 4\nj = 5"
)
config = Config(config)
assert config.to_str() == expected
@pytest.mark.parametrize("d", [".", ":"])
def test_config_interpolation(d):
"""Test that config values are interpolated correctly. The parametrized
value is the final divider (${a.b} vs. ${a:b}). Both should now work and be
valid. The double {{ }} in the config strings are required to prevent the
references from being interpreted as an actual f-string variable.
"""
c_str = """[a]\nfoo = "hello"\n\n[b]\nbar = ${foo}"""
with pytest.raises(ConfigValidationError):
Config().from_str(c_str)
c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}"""
assert Config().from_str(c_str)["b"]["bar"] == "hello"
c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}!"""
assert Config().from_str(c_str)["b"]["bar"] == "hello!"
c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = "${{a{d}foo}}!\""""
assert Config().from_str(c_str)["b"]["bar"] == "hello!"
c_str = f"""[a]\nfoo = 15\n\n[b]\nbar = ${{a{d}foo}}!"""
assert Config().from_str(c_str)["b"]["bar"] == "15!"
c_str = f"""[a]\nfoo = ["x", "y"]\n\n[b]\nbar = ${{a{d}foo}}"""
assert Config().from_str(c_str)["b"]["bar"] == ["x", "y"]
# Interpolation within the same section
c_str = f"""[a]\nfoo = "x"\nbar = ${{a{d}foo}}\nbaz = "${{a{d}foo}}y\""""
assert Config().from_str(c_str)["a"]["bar"] == "x"
assert Config().from_str(c_str)["a"]["baz"] == "xy"
def test_config_interpolation_lists():
# Test that lists are preserved correctly
c_str = """[a]\nb = 1\n\n[c]\nd = ["hello ${a.b}", "world"]"""
config = Config().from_str(c_str, interpolate=False)
assert config["c"]["d"] == ["hello ${a.b}", "world"]
config = config.interpolate()
assert config["c"]["d"] == ["hello 1", "world"]
c_str = """[a]\nb = 1\n\n[c]\nd = [${a.b}, "hello ${a.b}", "world"]"""
config = Config().from_str(c_str)
assert config["c"]["d"] == [1, "hello 1", "world"]
config = Config().from_str(c_str, interpolate=False)
# NOTE: This currently doesn't work, because we can't know how to JSON-load
# the uninterpolated list [${a.b}].
# assert config["c"]["d"] == ["${a.b}", "hello ${a.b}", "world"]
# config = config.interpolate()
# assert config["c"]["d"] == [1, "hello 1", "world"]
c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", ${a}]"""
config = Config().from_str(c_str)
assert config["c"]["d"] == ["hello", {"b": 1}]
c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", "hello ${a}"]"""
with pytest.raises(ConfigValidationError):
Config().from_str(c_str)
config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": ["hello ${a.b}"], "y": 2}]"""
config = Config().from_str(config_str)
assert config["c"]["d"] == ["hello", {"x": ["hello 1"], "y": 2}]
config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": [${a.b}], "y": 2}]"""
with pytest.raises(ConfigValidationError):
Config().from_str(c_str)
@pytest.mark.parametrize("d", [".", ":"])
def test_config_interpolation_sections(d):
"""Test that config sections are interpolated correctly. The parametrized
value is the final divider (${a.b} vs. ${a:b}). Both should now work and be
valid. The double {{ }} in the config strings are required to prevent the
references from being interpreted as an actual f-string variable.
"""
# Simple block references
c_str = """[a]\nfoo = "hello"\nbar = "world"\n\n[b]\nc = ${a}"""
config = Config().from_str(c_str)
assert config["b"]["c"] == config["a"]
# References with non-string values
c_str = f"""[a]\nfoo = "hello"\n\n[a.x]\ny = ${{a{d}b}}\n\n[a.b]\nc = 1\nd = [10]"""
config = Config().from_str(c_str)
assert config["a"]["x"]["y"] == config["a"]["b"]
# Multiple references in the same string
c_str = f"""[a]\nx = "string"\ny = 10\n\n[b]\nz = "${{a{d}x}}/${{a{d}y}}\""""
config = Config().from_str(c_str)
assert config["b"]["z"] == "string/10"
# Non-string references in string (converted to string)
c_str = f"""[a]\nx = ["hello", "world"]\n\n[b]\ny = "result: ${{a{d}x}}\""""
config = Config().from_str(c_str)
assert config["b"]["y"] == 'result: ["hello", "world"]'
# References to sections referencing sections
c_str = """[a]\nfoo = "x"\n\n[b]\nbar = ${a}\n\n[c]\nbaz = ${b}"""
config = Config().from_str(c_str)
assert config["b"]["bar"] == config["a"]
assert config["c"]["baz"] == config["b"]
# References to section values referencing other sections
c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b{d}bar}}"""
config = Config().from_str(c_str)
assert config["c"]["baz"] == config["b"]["bar"]
# References to sections with subsections
c_str = """[a]\nfoo = "x"\n\n[a.b]\nbar = 100\n\n[c]\nbaz = ${a}"""
config = Config().from_str(c_str)
assert config["c"]["baz"] == config["a"]
# Infinite recursion
c_str = """[a]\nfoo ="x"\n\n[a.b]\nbar = ${a}"""
config = Config().from_str(c_str)
assert config["a"]["b"]["bar"] == config["a"]
c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b.bar{d}foo}}"""
# We can't reference not-yet interpolated subsections
with pytest.raises(ConfigValidationError):
Config().from_str(c_str)
# Generally invalid references
c_str = f"""[a]\nfoo = ${{b{d}bar}}"""
with pytest.raises(ConfigValidationError):
Config().from_str(c_str)
# We can't reference sections or promises within strings
c_str = """[a]\n\n[a.b]\nfoo = "x: ${c}"\n\n[c]\nbar = 1\nbaz = 2"""
with pytest.raises(ConfigValidationError):
Config().from_str(c_str)
def test_config_from_str_overrides():
config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = 3\n\n[f]\ng = {"x": "y"}"""
# Basic value substitution
overrides = {"a.b": 10, "a.c.d": 20}
config = Config().from_str(config_str, overrides=overrides)
assert config["a"]["b"] == 10
assert config["a"]["c"]["d"] == 20
assert config["a"]["c"]["e"] == 3
# Valid values that previously weren't in config
config = Config().from_str(config_str, overrides={"a.c.f": 100})
assert config["a"]["c"]["d"] == 2
assert config["a"]["c"]["e"] == 3
assert config["a"]["c"]["f"] == 100
# Invalid keys and sections
with pytest.raises(ConfigValidationError):
Config().from_str(config_str, overrides={"f": 10})
# This currently isn't expected to work, because the dict in f.g is not
# interpreted as a section while the config is still just the configparser
with pytest.raises(ConfigValidationError):
Config().from_str(config_str, overrides={"f.g.x": "z"})
# With variables (values)
config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = ${a:b}"""
config = Config().from_str(config_str, overrides={"a.b": 10})
assert config["a"]["b"] == 10
assert config["a"]["c"]["e"] == 10
# With variables (sections)
config_str = """[a]\nb = 1\n\n[a.c]\nd = 2\n[e]\nf = ${a.c}"""
config = Config().from_str(config_str, overrides={"a.c.d": 20})
assert config["a"]["c"]["d"] == 20
assert config["e"]["f"] == {"d": 20}
def test_config_reserved_aliases():
"""Test that the auto-generated pydantic schemas auto-alias reserved
attributes like "validate" that would otherwise cause NameError."""
@my_registry.cats("catsie.with_alias")
def catsie_with_alias(validate: StrictBool = False):
return validate
cfg = {"@cats": "catsie.with_alias", "validate": True}
resolved = my_registry.resolve({"test": cfg})
filled = my_registry.fill({"test": cfg})
assert resolved["test"] is True
assert filled["test"] == cfg
cfg = {"@cats": "catsie.with_alias", "validate": 20}
with pytest.raises(ConfigValidationError):
my_registry.resolve({"test": cfg})
@pytest.mark.skipif(not has_numpy, reason="needs numpy")
@pytest.mark.parametrize("d", [".", ":"])
def test_config_no_interpolation(d):
"""Test that interpolation is correctly preserved. The parametrized
value is the final divider (${a.b} vs. ${a:b}). Both should now work and be
valid. The double {{ }} in the config strings are required to prevent the
references from being interpreted as an actual f-string variable.
"""
c_str = f"""[a]\nb = 1\n\n[c]\nd = ${{a{d}b}}\ne = \"hello${{a{d}b}}"\nf = ${{a}}"""
config = Config().from_str(c_str, interpolate=False)
assert not config.is_interpolated
assert config["c"]["d"] == f"${{a{d}b}}"
assert config["c"]["e"] == f'"hello${{a{d}b}}"'
assert config["c"]["f"] == "${a}"
config2 = Config().from_str(config.to_str(), interpolate=True)
assert config2.is_interpolated
assert config2["c"]["d"] == 1
assert config2["c"]["e"] == "hello1"
assert config2["c"]["f"] == {"b": 1}
config3 = config.interpolate()
assert config3.is_interpolated
assert config3["c"]["d"] == 1
assert config3["c"]["e"] == "hello1"
assert config3["c"]["f"] == {"b": 1}
# Bad non-serializable value
cfg = {"x": {"y": numpy.asarray([[1, 2], [4, 5]], dtype="f"), "z": f"${{x{d}y}}"}}
with pytest.raises(ConfigValidationError):
Config(cfg).interpolate()
def test_config_no_interpolation_registry():
config_str = """[a]\nbad = true\n[b]\n@cats = "catsie.v1"\nevil = ${a:bad}\n\n[c]\n d = ${b}"""
config = Config().from_str(config_str, interpolate=False)
assert not config.is_interpolated
assert config["b"]["evil"] == "${a:bad}"
assert config["c"]["d"] == "${b}"
filled = my_registry.fill(config)
resolved = my_registry.resolve(config)
assert resolved["b"] == "scratch!"
assert resolved["c"]["d"] == "scratch!"
assert filled["b"]["evil"] == "${a:bad}"
assert filled["b"]["cute"] is True
assert filled["c"]["d"] == "${b}"
interpolated = filled.interpolate()
assert interpolated.is_interpolated
assert interpolated["b"]["evil"] is True
assert interpolated["c"]["d"] == interpolated["b"]
config = Config().from_str(config_str, interpolate=True)
assert config.is_interpolated
filled = my_registry.fill(config)
resolved = my_registry.resolve(config)
assert resolved["b"] == "scratch!"
assert resolved["c"]["d"] == "scratch!"
assert filled["b"]["evil"] is True
assert filled["c"]["d"] == filled["b"]
# Resolving a non-interpolated filled config
config = Config().from_str(config_str, interpolate=False)
assert not config.is_interpolated
filled = my_registry.fill(config)
assert not filled.is_interpolated
assert filled["c"]["d"] == "${b}"
resolved = my_registry.resolve(filled)
assert resolved["c"]["d"] == "scratch!"
def test_config_deep_merge():
config = {"a": "hello", "b": {"c": "d"}}
defaults = {"a": "world", "b": {"c": "e", "f": "g"}}
merged = Config(defaults).merge(config)
assert len(merged) == 2
assert merged["a"] == "hello"
assert merged["b"] == {"c": "d", "f": "g"}
config = {"a": "hello", "b": {"@test": "x", "foo": 1}}
defaults = {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100}
merged = Config(defaults).merge(config)
assert len(merged) == 3
assert merged["a"] == "hello"
assert merged["b"] == {"@test": "x", "foo": 1, "bar": 2}
assert merged["c"] == 100
config = {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100}
defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}}
merged = Config(defaults).merge(config)
assert len(merged) == 3
assert merged["a"] == "hello"
assert merged["b"] == {"@test": "x", "foo": 1}
assert merged["c"] == 100
# Test that leaving out the factory just adds to existing
config = {"a": "hello", "b": {"foo": 1}, "c": 100}
defaults = {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}}
merged = Config(defaults).merge(config)
assert len(merged) == 3
assert merged["a"] == "hello"
assert merged["b"] == {"@test": "y", "foo": 1, "bar": 2}
assert merged["c"] == 100
# Test that switching to a different factory prevents the default from being added
config = {"a": "hello", "b": {"@foo": 1}, "c": 100}
defaults = {"a": "world", "b": {"@bar": "y"}}
merged = Config(defaults).merge(config)
assert len(merged) == 3
assert merged["a"] == "hello"
assert merged["b"] == {"@foo": 1}
assert merged["c"] == 100
config = {"a": "hello", "b": {"@foo": 1}, "c": 100}
defaults = {"a": "world", "b": "y"}
merged = Config(defaults).merge(config)
assert len(merged) == 3
assert merged["a"] == "hello"
assert merged["b"] == {"@foo": 1}
assert merged["c"] == 100
def test_config_deep_merge_variables():
config_str = """[a]\nb= 1\nc = 2\n\n[d]\ne = ${a:b}"""
defaults_str = """[a]\nx = 100\n\n[d]\ny = 500"""
config = Config().from_str(config_str, interpolate=False)
defaults = Config().from_str(defaults_str)
merged = defaults.merge(config)
assert merged["a"] == {"b": 1, "c": 2, "x": 100}
assert merged["d"] == {"e": "${a:b}", "y": 500}
assert merged.interpolate()["d"] == {"e": 1, "y": 500}
# With variable in defaults: overwritten by new value
config = Config().from_str("""[a]\nb= 1\nc = 2""")
defaults = Config().from_str("""[a]\nb = 100\nc = ${a:b}""", interpolate=False)
merged = defaults.merge(config)
assert merged["a"]["c"] == 2
@pytest.mark.skipif(not has_numpy, reason="needs numpy")
def test_config_to_str_roundtrip():
cfg = {"cfg": {"foo": False}}
config_str = Config(cfg).to_str()
assert config_str == "[cfg]\nfoo = false"
config = Config().from_str(config_str)
assert dict(config) == cfg
cfg = {"cfg": {"foo": "false"}}
config_str = Config(cfg).to_str()
assert config_str == '[cfg]\nfoo = "false"'
config = Config().from_str(config_str)
assert dict(config) == cfg
# Bad non-serializable value
cfg = {"cfg": {"x": numpy.asarray([[1, 2, 3, 4], [4, 5, 3, 4]], dtype="f")}}
config = Config(cfg)
with pytest.raises(ConfigValidationError):
config.to_str()
# Roundtrip with variables: preserve variables correctly (quoted/unquoted)
config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = "${a:b}\""""
config = Config().from_str(config_str, interpolate=False)
assert config.to_str() == config_str
def test_config_is_interpolated():
"""Test that a config object correctly reports whether it's interpolated."""
config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = ${a}"""
config = Config().from_str(config_str, interpolate=False)
assert not config.is_interpolated
config = config.merge(Config({"x": {"y": "z"}}))
assert not config.is_interpolated
config = Config(config)
assert not config.is_interpolated
config = config.interpolate()
assert config.is_interpolated
config = config.merge(Config().from_str(config_str, interpolate=False))
assert not config.is_interpolated
@pytest.mark.parametrize(
"section_order,expected_str,expected_keys",
[
# fmt: off
([], "[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[h]\ni = 5\n\n[j]\nk = 6", ["a", "h", "j"]),
(["j", "h", "a"], "[j]\nk = 6\n\n[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4", ["j", "h", "a"]),
(["h"], "[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[j]\nk = 6", ["h", "a", "j"])
# fmt: on
],
)
def test_config_serialize_custom_sort(section_order, expected_str, expected_keys):
cfg = {
"j": {"k": 6},
"a": {"b": 1, "d": {"e": 3}, "c": 2, "f": {"g": 4}},
"h": {"i": 5},
}
cfg_str = Config(cfg).to_str()
assert Config(cfg, section_order=section_order).to_str() == expected_str
keys = list(Config(section_order=section_order).from_str(cfg_str).keys())
assert keys == expected_keys
keys = list(Config(cfg, section_order=section_order).keys())
assert keys == expected_keys
def test_config_custom_sort_preserve():
"""Test that sort order is preserved when merging and copying configs,
or when configs are filled and resolved."""
cfg = {"x": {}, "y": {}, "z": {}}
section_order = ["y", "z", "x"]
expected = "[y]\n\n[z]\n\n[x]"
config = Config(cfg, section_order=section_order)
assert config.to_str() == expected
config2 = config.copy()
assert config2.to_str() == expected
config3 = config.merge({"a": {}})
assert config3.to_str() == f"{expected}\n\n[a]"
config4 = Config(config)
assert config4.to_str() == expected
config_str = """[a]\nb = 1\n[c]\n@cats = "catsie.v1"\nevil = true\n\n[t]\n x = 2"""
section_order = ["c", "a", "t"]
config5 = Config(section_order=section_order).from_str(config_str)
assert list(config5.keys()) == section_order
filled = my_registry.fill(config5)
assert filled.section_order == section_order
def test_config_pickle():
config = Config({"foo": "bar"}, section_order=["foo", "bar", "baz"])
data = pickle.dumps(config)
config_new = pickle.loads(data)
assert config_new == {"foo": "bar"}
assert config_new.section_order == ["foo", "bar", "baz"]
def test_config_fill_extra_fields():
"""Test that filling a config from a schema removes extra fields."""
class TestSchemaContent(BaseModel):
a: str
b: int
class Config:
extra = "forbid"
class TestSchema(BaseModel):
cfg: TestSchemaContent
config = Config({"cfg": {"a": "1", "b": 2, "c": True}})
with pytest.raises(ConfigValidationError):
my_registry.fill(config, schema=TestSchema)
filled = my_registry.fill(config, schema=TestSchema, validate=False)["cfg"]
assert filled == {"a": "1", "b": 2}
config2 = config.interpolate()
filled = my_registry.fill(config2, schema=TestSchema, validate=False)["cfg"]
assert filled == {"a": "1", "b": 2}
config3 = Config({"cfg": {"a": "1", "b": 2, "c": True}}, is_interpolated=False)
filled = my_registry.fill(config3, schema=TestSchema, validate=False)["cfg"]
assert filled == {"a": "1", "b": 2}
class TestSchemaContent2(BaseModel):
a: str
b: int
class Config:
extra = "allow"
class TestSchema2(BaseModel):
cfg: TestSchemaContent2
filled = my_registry.fill(config, schema=TestSchema2, validate=False)["cfg"]
assert filled == {"a": "1", "b": 2, "c": True}
def test_config_validation_error_custom():
class Schema(BaseModel):
hello: int
world: int
config = {"hello": 1, "world": "hi!"}
with pytest.raises(ConfigValidationError) as exc_info:
my_registry._fill(config, Schema)
e1 = exc_info.value
assert e1.title == "Config validation error"
assert e1.desc is None
assert not e1.parent
assert e1.show_config is True
assert len(e1.errors) == 1
assert e1.errors[0]["loc"] == ("world",)
assert e1.errors[0]["msg"] == "value is not a valid integer"
assert e1.errors[0]["type"] == "type_error.integer"
assert e1.error_types == set(["type_error.integer"])
# Create a new error with overrides
title = "Custom error"
desc = "Some error description here"
e2 = ConfigValidationError.from_error(e1, title=title, desc=desc, show_config=False)
assert e2.errors == e1.errors
assert e2.error_types == e1.error_types
assert e2.title == title
assert e2.desc == desc
assert e2.show_config is False
assert e1.text != e2.text
def test_config_parsing_error():
config_str = "[a]\nb c"
with pytest.raises(ConfigValidationError):
Config().from_str(config_str)
def test_config_fill_without_resolve():
class BaseSchema(BaseModel):
catsie: int
config = {"catsie": {"@cats": "catsie.v1", "evil": False}}
filled = my_registry.fill(config)
resolved = my_registry.resolve(config)
assert resolved["catsie"] == "meow"
assert filled["catsie"]["cute"] is True
with pytest.raises(ConfigValidationError):
my_registry.resolve(config, schema=BaseSchema)
filled2 = my_registry.fill(config, schema=BaseSchema)
assert filled2["catsie"]["cute"] is True
resolved = my_registry.resolve(filled2)
assert resolved["catsie"] == "meow"
# With unavailable function
class BaseSchema2(BaseModel):
catsie: Any
other: int = 12
config = {"catsie": {"@cats": "dog", "evil": False}}
filled3 = my_registry.fill(config, schema=BaseSchema2)
assert filled3["catsie"] == config["catsie"]
assert filled3["other"] == 12
def test_config_dataclasses():
cat = Cat("testcat", value_in=1, value_out=2)
config = {"cfg": {"@cats": "catsie.v3", "arg": cat}}
result = my_registry.resolve(config)["cfg"]
assert isinstance(result, Cat)
assert result.name == cat.name
assert result.value_in == cat.value_in
assert result.value_out == cat.value_out
@pytest.mark.parametrize(
"greeting,value,expected",
[
# simple substitution should go fine
[342, "${vars.a}", int],
["342", "${vars.a}", str],
["everyone", "${vars.a}", str],
],
)
def test_config_interpolates(greeting, value, expected):
str_cfg = f"""
[project]
my_par = {value}
[vars]
a = "something"
"""
overrides = {"vars.a": greeting}
cfg = Config().from_str(str_cfg, overrides=overrides)
assert type(cfg["project"]["my_par"]) == expected
@pytest.mark.parametrize(
"greeting,value,expected",
[
# fmt: off
# simple substitution should go fine
["hello 342", "${vars.a}", "hello 342"],
["hello everyone", "${vars.a}", "hello everyone"],
["hello tout le monde", "${vars.a}", "hello tout le monde"],
["hello 42", "${vars.a}", "hello 42"],
# substituting an element in a list
["hello 342", "[1, ${vars.a}, 3]", "hello 342"],
["hello everyone", "[1, ${vars.a}, 3]", "hello everyone"],
["hello tout le monde", "[1, ${vars.a}, 3]", "hello tout le monde"],
["hello 42", "[1, ${vars.a}, 3]", "hello 42"],
# substituting part of a string
[342, "hello ${vars.a}", "hello 342"],
["everyone", "hello ${vars.a}", "hello everyone"],
["tout le monde", "hello ${vars.a}", "hello tout le monde"],
pytest.param("42", "hello ${vars.a}", "hello 42", marks=pytest.mark.xfail),
# substituting part of a implicit string inside a list
[342, "[1, hello ${vars.a}, 3]", "hello 342"],
["everyone", "[1, hello ${vars.a}, 3]", "hello everyone"],
["tout le monde", "[1, hello ${vars.a}, 3]", "hello tout le monde"],
pytest.param("42", "[1, hello ${vars.a}, 3]", "hello 42", marks=pytest.mark.xfail),
# substituting part of a explicit string inside a list
[342, "[1, 'hello ${vars.a}', '3']", "hello 342"],
["everyone", "[1, 'hello ${vars.a}', '3']", "hello everyone"],
["tout le monde", "[1, 'hello ${vars.a}', '3']", "hello tout le monde"],
pytest.param("42", "[1, 'hello ${vars.a}', '3']", "hello 42", marks=pytest.mark.xfail),
# more complicated example
[342, "[{'name':'x','script':['hello ${vars.a}']}]", "hello 342"],
["everyone", "[{'name':'x','script':['hello ${vars.a}']}]", "hello everyone"],
["tout le monde", "[{'name':'x','script':['hello ${vars.a}']}]", "hello tout le monde"],
pytest.param("42", "[{'name':'x','script':['hello ${vars.a}']}]", "hello 42", marks=pytest.mark.xfail),
# fmt: on
],
)
def test_config_overrides(greeting, value, expected):
str_cfg = f"""
[project]
commands = {value}
[vars]
a = "world"
"""
overrides = {"vars.a": greeting}
assert "${vars.a}" in str_cfg
cfg = Config().from_str(str_cfg, overrides=overrides)
assert expected in str(cfg)
catalogue-2.1.0/catalogue/tests/util.py 0000664 0000000 0000000 00000007715 14261040534 0020124 0 ustar 00root root 0000000 0000000 """
Registered functions used for config tests.
"""
import contextlib
import dataclasses
import shutil
import tempfile
from pathlib import Path
from typing import Iterable, List, Union, Generator, Callable, Tuple, Generic, TypeVar, Optional, Any
import numpy
from pydantic.types import StrictBool
import catalogue
from catalogue.config import config
from catalogue.config.util import partial
FloatOrSeq = Union[float, List[float], Generator]
InT = TypeVar("InT")
OutT = TypeVar("OutT")
@dataclasses.dataclass
class Cat(Generic[InT, OutT]):
name: str
value_in: InT
value_out: OutT
class my_registry(config.catalogue_registry):
cats = catalogue.registry.create("config_tests", "cats", entry_points=False)
optimizers = catalogue.registry.create("config_tests", "optimizers", entry_points=False)
schedules = catalogue.registry.create("config_tests", "schedules", entry_points=False)
initializers = catalogue.registry.create("config_tests", "initializers", entry_points=False)
layers = catalogue.registry.create("config_tests", "layers", entry_points=False)
@my_registry.cats.register("catsie.v1")
def catsie_v1(evil: StrictBool, cute: bool = True) -> str:
if evil:
return "scratch!"
else:
return "meow"
@my_registry.cats.register("catsie.v2")
def catsie_v2(evil: StrictBool, cute: bool = True, cute_level: int = 1) -> str:
if evil:
return "scratch!"
else:
if cute_level > 2:
return "meow <3"
return "meow"
@my_registry.cats("catsie.v3")
def catsie(arg: Cat) -> Cat:
return arg
@my_registry.optimizers("Adam.v1")
def Adam(
learn_rate: FloatOrSeq = 0.001,
*,
beta1: FloatOrSeq = .001,
beta2: FloatOrSeq = .001,
use_averages: bool = True,
):
"""
Mocks optimizer generation. Note that the returned object is not actually an optimizer. This function is merely used
to illustrate how to use the function registry, e.g. with thinc.
"""
@dataclasses.dataclass
class Optimizer:
learn_rate: FloatOrSeq
beta1: FloatOrSeq
beta2: FloatOrSeq
use_averages: bool
return Optimizer(learn_rate=learn_rate, beta1=beta1, beta2=beta2, use_averages=use_averages)
@my_registry.schedules("warmup_linear.v1")
def warmup_linear(
initial_rate: float, warmup_steps: int, total_steps: int
) -> Iterable[float]:
"""Generate a series, starting from an initial rate, and then with a warmup
period, and then a linear decline. Used for learning rates.
"""
step = 0
while True:
if step < warmup_steps:
factor = step / max(1, warmup_steps)
else:
factor = max(
0.0, (total_steps - step) / max(1.0, total_steps - warmup_steps)
)
yield factor * initial_rate
step += 1
def uniform_init(
shape: Tuple[int, ...], *, lo: float = -0.1, hi: float = 0.1
) -> List[float]:
return numpy.random.uniform(lo, hi, shape).tolist()
@my_registry.initializers("uniform_init.v1")
def configure_uniform_init(
*, lo: float = -0.1, hi: float = 0.1
) -> Callable[[List[float]], List[float]]:
return partial(uniform_init, lo=lo, hi=hi)
@my_registry.cats("generic_cat.v1")
def generic_cat(cat: Cat[int, int]) -> Cat[int, int]:
cat.name = "generic_cat"
return cat
@my_registry.cats("int_cat.v1")
def int_cat(
value_in: Optional[int] = None, value_out: Optional[int] = None
) -> Cat[Optional[int], Optional[int]]:
""" Instantiates cat with integer values. """
return Cat(name="int_cat", value_in=value_in, value_out=value_out)
@my_registry.optimizers.register("my_cool_optimizer.v1")
def make_my_optimizer(learn_rate: List[float], beta1: float):
return Adam(learn_rate, beta1=beta1)
@my_registry.schedules("my_cool_repetitive_schedule.v1")
def decaying(base_rate: float, repeat: int) -> List[float]:
return repeat * [base_rate]
@contextlib.contextmanager
def make_tempdir():
d = Path(tempfile.mkdtemp())
yield d
shutil.rmtree(str(d))
catalogue-2.1.0/requirements.txt 0000664 0000000 0000000 00000000230 14261040534 0016734 0 ustar 00root root 0000000 0000000 zipp>=0.5; python_version < "3.8"
typing-extensions>=3.6.4; python_version < "3.8"
pytest>=4.6.5
mypy
pydantic
types-dataclasses
typing_extensions
srsly catalogue-2.1.0/setup.cfg 0000664 0000000 0000000 00000002460 14261040534 0015300 0 ustar 00root root 0000000 0000000 [metadata]
version = 2.1.0
description = Lightweight function registries for your library
url = https://github.com/explosion/catalogue
author = Explosion
author_email = contact@explosion.ai
license = MIT
long_description = file: README.md
long_description_content_type = text/markdown
classifiers =
Development Status :: 5 - Production/Stable
Environment :: Console
Intended Audience :: Developers
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Operating System :: POSIX :: Linux
Operating System :: MacOS :: MacOS X
Operating System :: Microsoft :: Windows
Programming Language :: Python :: 3
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Topic :: Scientific/Engineering
[options]
zip_safe = true
include_package_data = true
python_requires = >=3.6
install_requires =
zipp>=0.5; python_version < "3.8"
typing-extensions>=3.6.4; python_version < "3.8"
[sdist]
formats = gztar
[flake8]
ignore = E203, E266, E501, E731, W503
max-line-length = 80
select = B,C,E,F,W,T4,B9
exclude =
.env,
.git,
__pycache__,
[mypy]
ignore_missing_imports = True
no_implicit_optional = True
catalogue-2.1.0/setup.py 0000664 0000000 0000000 00000000231 14261040534 0015163 0 ustar 00root root 0000000 0000000 #!/usr/bin/env python
if __name__ == "__main__":
from setuptools import setup, find_packages
setup(name="catalogue", packages=find_packages())