pax_global_header 0000666 0000000 0000000 00000000064 14357240256 0014522 g ustar 00root root 0000000 0000000 52 comment=f950fe02d0de11ad035e7b91848fe3a76d6b7830
confection-0.0.4/ 0000775 0000000 0000000 00000000000 14357240256 0013652 5 ustar 00root root 0000000 0000000 confection-0.0.4/.gitignore 0000664 0000000 0000000 00000001546 14357240256 0015650 0 ustar 00root root 0000000 0000000 tmp/
.pytest_cache
.vscode
.mypy_cache
.prettierrc
.python-version
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
.env/
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
# Sphinx documentation
docs/_build/
# PyBuilder
target/
#Ipython Notebook
.ipynb_checkpoints
# Pycharm project files
*.idea
confection-0.0.4/LICENSE 0000664 0000000 0000000 00000002061 14357240256 0014656 0 ustar 00root root 0000000 0000000 MIT License
Copyright (c) 2019 ExplosionAI GmbH
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
confection-0.0.4/MANIFEST.in 0000664 0000000 0000000 00000000054 14357240256 0015407 0 ustar 00root root 0000000 0000000 include LICENSE
include confection/py.typed
confection-0.0.4/README.md 0000664 0000000 0000000 00000042600 14357240256 0015133 0 ustar 00root root 0000000 0000000
# Confection: The sweetest config system for Python
`confection` :candy: is a lightweight library that offers a **configuration system** letting you conveniently describe arbitrary
trees of objects.
Configuration is a huge challenge for machine-learning code because you may want to expose almost any
detail of any function as a hyperparameter. The setting you want to expose might be arbitrarily far
down in your call stack, so it might need to pass all the way through the CLI or REST API,
through any number of intermediate functions, affecting the interface of everything along the way.
And then once those settings are added, they become hard to remove later. Default values also
become hard to change without breaking backwards compatibility.
To solve this problem, `confection` offers a config system that lets you easily describe arbitrary trees of objects.
The objects can be created via function calls you register using a simple decorator syntax. You can even version the
functions you create, allowing you to make improvements without breaking backwards compatibility. The most similar
config system we’re aware of is [Gin](https://github.com/google/gin-config), which uses a similar syntax, and also
allows you to link the configuration system to functions in your code using a decorator. `confection`'s config system is
simpler and emphasizes a different workflow via a subset of Gin’s functionality.
[Azure Pipelines build status](https://dev.azure.com/explosion-ai/public/_build?definitionId=28)
[Current release version](https://github.com/explosion/confection/releases)
[PyPI version](https://pypi.org/project/confection/)
[conda version](https://anaconda.org/conda-forge/confection)
[Code style: black](https://github.com/ambv/black)
## ⏳ Installation
```bash
pip install confection
```
```bash
conda install -c conda-forge confection
```
## 👩‍💻 Usage
The configuration system parses a `.cfg` file like
```ini
[training]
patience = 10
dropout = 0.2
use_vectors = false
[training.logging]
level = "INFO"
[nlp]
# This uses the value of training.use_vectors
use_vectors = ${training.use_vectors}
lang = "en"
```
and resolves it to a `Dict`:
```json
{
"training": {
"patience": 10,
"dropout": 0.2,
"use_vectors": false,
"logging": {
"level": "INFO"
}
},
"nlp": {
"use_vectors": false,
"lang": "en"
}
}
```
The config is divided into sections, with the section name in square brackets – for
example, `[training]`. Within the sections, config values can be assigned to keys using `=`. Values can also be referenced
from other sections using the dot notation and placeholders indicated by the dollar sign and curly braces. For example,
`${training.use_vectors}` will receive the value of use_vectors in the training block. This is useful for settings that
are shared across components.
The config format has three main differences from Python’s built-in `configparser`:
1. JSON-formatted values. `confection` passes all values through `json.loads` to interpret them. You can use atomic
values like strings, floats, integers or booleans, or you can use complex objects such as lists or maps.
2. Structured sections. `confection` uses a dot notation to build nested sections. If you have a section named
`[section.subsection]`, `confection` will parse that into a nested structure, placing subsection within section.
3. References to registry functions. If a key starts with `@`, `confection` will interpret its value as the name of a
function registry, load the function registered for that name and pass in the rest of the block as arguments. If type
hints are available on the function, the argument values (and return value of the function) will be validated against
them. This lets you express complex configurations, like a training pipeline where `batch_size` is populated by a
function that yields floats.
There’s no pre-defined scheme you have to follow; how you set up the top-level sections is up to you. At the end of
it, you’ll receive a dictionary with the values that you can use in your script – whether it’s complete initialized
functions, or just basic settings.
For instance, let’s say you want to define a new optimizer. You'd define its arguments in `config.cfg` like so:
```ini
[optimizer]
@optimizers = "my_cool_optimizer.v1"
learn_rate = 0.001
gamma = 1e-8
```
To load and parse this configuration:
```python
import dataclasses
from typing import Union, Iterable
import catalogue
from confection import registry, Config
# Create a new registry.
registry.optimizers = catalogue.create("confection", "optimizers", entry_points=False)
# Define a dummy optimizer class.
@dataclasses.dataclass
class MyCoolOptimizer:
learn_rate: float
gamma: float
@registry.optimizers.register("my_cool_optimizer.v1")
def make_my_optimizer(learn_rate: Union[float, Iterable[float]], gamma: float):
return MyCoolOptimizer(learn_rate, gamma)
# Load the config file from disk, resolve it and fetch the instantiated optimizer object.
config = Config().from_disk("./config.cfg")
resolved = registry.resolve(config)
optimizer = resolved["optimizer"] # MyCoolOptimizer(learn_rate=0.001, gamma=1e-08)
```
Under the hood, `confection` will look up the `"my_cool_optimizer.v1"` function in the "optimizers" registry and then
call it with the arguments `learn_rate` and `gamma`. If the function has type annotations, it will also validate the
input. For instance, if `learn_rate` is annotated as a float and the config defines a string, `confection` will raise an
error.
The Thinc documentation offers further information on the configuration system:
- [recursive blocks](https://thinc.ai/docs/usage-config#registry-recursive)
- [defining variable positional arguments](https://thinc.ai/docs/usage-config#registries-args)
- [using interpolation](https://thinc.ai/docs/usage-config#config-interpolation)
- [using custom registries](https://thinc.ai/docs/usage-config#registries-custom)
- [advanced type annotations with Pydantic](https://thinc.ai/docs/usage-config#advanced-types)
- [using base schemas](https://thinc.ai/docs/usage-config#advanced-types-base-schema)
- [filling a configuration with defaults](https://thinc.ai/docs/usage-config#advanced-types-fill-defaults)
## 🎛 API
### class `Config`
This class holds the model and training [configuration](https://thinc.ai/docs/usage-config) and can load and save the
INI-style configuration format from/to a string, file or bytes. The `Config` class is a subclass of `dict` and uses
Python’s `ConfigParser` under the hood.
#### method `Config.__init__`
Initialize a new `Config` object with optional data.
```python
from confection import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
```
| Argument | Type | Description |
| ----------------- | ----------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `data` | `Optional[Union[Dict[str, Any], Config]]` | Optional data to initialize the config with. |
| `section_order` | `Optional[List[str]]` | Top-level section names, in order, used to sort the saved and loaded config. All other sections will be sorted alphabetically. |
| `is_interpolated` | `Optional[bool]` | Whether the config is interpolated or whether it contains variables. Read from the `data` if it’s an instance of `Config` and otherwise defaults to `True`. |
#### method `Config.from_str`
Load the config from a string.
```python
from confection import Config
config_str = """
[training]
patience = 10
dropout = 0.2
"""
config = Config().from_str(config_str)
print(config["training"]) # {'patience': 10, 'dropout': 0.2}
```
| Argument | Type | Description |
| ------------- | ---------------- | -------------------------------------------------------------------------------------------------------------------- |
| `text` | `str` | The string config to load. |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| `overrides` | `Dict[str, Any]` | Overrides for values and sections. Keys are provided in dot notation, e.g. `"training.dropout"` mapped to the value. |
| **RETURNS** | `Config` | The loaded config. |
#### method `Config.to_str`
Serialize the config to a string.
```python
from confection import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
print(config.to_str()) # '[training]\npatience = 10\n\ndropout = 0.2'
```
| Argument | Type | Description |
| ------------- | ------ | --------------------------------------------------------------------------- |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| **RETURNS** | `str` | The string config. |
#### method `Config.to_bytes`
Serialize the config to a byte string.
```python
from confection import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
config_bytes = config.to_bytes()
print(config_bytes) # b'[training]\npatience = 10\n\ndropout = 0.2'
```
| Argument | Type | Description |
| ------------- | ---------------- | -------------------------------------------------------------------------------------------------------------------- |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| `overrides` | `Dict[str, Any]` | Overrides for values and sections. Keys are provided in dot notation, e.g. `"training.dropout"` mapped to the value. |
| **RETURNS** | `bytes` | The serialized config. |
#### method `Config.from_bytes`
Load the config from a byte string.
```python
from confection import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
config_bytes = config.to_bytes()
new_config = Config().from_bytes(config_bytes)
```
| Argument | Type | Description |
| ------------- | -------- | --------------------------------------------------------------------------- |
| `bytes_data` | `bytes` | The data to load. |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| **RETURNS** | `Config` | The loaded config. |
#### method `Config.to_disk`
Serialize the config to a file.
```python
from confection import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
config.to_disk("./config.cfg")
```
| Argument | Type | Description |
| ------------- | ------------------ | --------------------------------------------------------------------------- |
| `path` | `Union[Path, str]` | The file path. |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
#### method `Config.from_disk`
Load the config from a file.
```python
from confection import Config
config = Config({"training": {"patience": 10, "dropout": 0.2}})
config.to_disk("./config.cfg")
new_config = Config().from_disk("./config.cfg")
```
| Argument | Type | Description |
| ------------- | ------------------ | -------------------------------------------------------------------------------------------------------------------- |
| `path` | `Union[Path, str]` | The file path. |
| `interpolate` | `bool` | Whether to interpolate variables like `${section.key}`. Defaults to `True`. |
| `overrides` | `Dict[str, Any]` | Overrides for values and sections. Keys are provided in dot notation, e.g. `"training.dropout"` mapped to the value. |
| **RETURNS** | `Config` | The loaded config. |
#### method `Config.copy`
Deep-copy the config.
| Argument | Type | Description |
| ----------- | -------- | ------------------ |
| **RETURNS** | `Config` | The copied config. |
#### method `Config.interpolate`
Interpolate variables like `${section.value}` or `${section.subsection}` and return a copy of the config with interpolated
values. Can be used if a config is loaded with `interpolate=False`, e.g. via `Config.from_str`.
```python
from confection import Config
config_str = """
[hyper_params]
dropout = 0.2
[training]
dropout = ${hyper_params.dropout}
"""
config = Config().from_str(config_str, interpolate=False)
print(config["training"]) # {'dropout': '${hyper_params.dropout}'}
config = config.interpolate()
print(config["training"]) # {'dropout': 0.2}
```
| Argument | Type | Description |
| ----------- | -------- | ---------------------------------------------- |
| **RETURNS** | `Config` | A copy of the config with interpolated values. |
#### method `Config.merge`
Deep-merge two config objects, using the current config as the default. Only merges sections and dictionaries and not
other values like lists. Values that are provided in the updates are overwritten in the base config, and any new values
or sections are added. If a config value is a variable like `${section.key}` (e.g. if the config was loaded with
`interpolate=False`), **the variable is preferred**, even if the updates provide a different value. This ensures that variable
references aren’t destroyed by a merge.
> :warning: Note that blocks that refer to registered functions using the `@` syntax are only merged if they are
> referring to the same functions. Otherwise, merging could easily produce invalid configs, since different functions
> can take different arguments. If a block refers to a different function, it’s overwritten.
```python
from confection import Config
base_config_str = """
[training]
patience = 10
dropout = 0.2
"""
update_config_str = """
[training]
dropout = 0.1
max_epochs = 2000
"""
base_config = Config().from_str(base_config_str)
update_config = Config().from_str(update_config_str)
merged = Config(base_config).merge(update_config)
print(merged["training"]) # {'patience': 10, 'dropout': 0.1, 'max_epochs': 2000}
```
| Argument | Type | Description |
| ----------- | ------------------------------- | --------------------------------------------------- |
| `updates` | `Union[Dict[str, Any], Config]` | The updates to merge into the config. |
| **RETURNS** | `Config` | A new config instance containing the merged config. |
### Config Attributes
| Attribute | Type | Description |
| ----------------- | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `is_interpolated` | `bool` | Whether the config values have been interpolated. Defaults to `True` and is set to `False` if a config is loaded with `interpolate=False`, e.g. using `Config.from_str`. |
confection-0.0.4/azure-pipelines.yml 0000664 0000000 0000000 00000004635 14357240256 0017521 0 ustar 00root root 0000000 0000000 trigger:
batch: true
branches:
include:
- '*'
jobs:
- job: 'Test'
strategy:
matrix:
Python36Linux:
imageName: 'ubuntu-20.04'
python.version: '3.6'
Python36Windows:
imageName: 'windows-2019'
python.version: '3.6'
Python37Linux:
imageName: 'ubuntu-latest'
python.version: '3.7'
Python37Windows:
imageName: 'windows-latest'
python.version: '3.7'
Python37Mac:
imageName: 'macos-latest'
python.version: '3.7'
Python38Linux:
imageName: 'ubuntu-latest'
python.version: '3.8'
Python38Windows:
imageName: 'windows-latest'
python.version: '3.8'
Python38Mac:
imageName: 'macos-latest'
python.version: '3.8'
Python39Linux:
imageName: 'ubuntu-latest'
python.version: '3.9'
Python39Windows:
imageName: 'windows-latest'
python.version: '3.9'
Python39Mac:
imageName: 'macos-latest'
python.version: '3.9'
Python310Linux:
imageName: 'ubuntu-latest'
python.version: '3.10'
Python310Windows:
imageName: 'windows-latest'
python.version: '3.10'
Python310Mac:
imageName: 'macos-latest'
python.version: '3.10'
Python311Linux:
imageName: 'ubuntu-latest'
python.version: '3.11'
Python311Windows:
imageName: 'windows-latest'
python.version: '3.11'
Python311Mac:
imageName: 'macos-latest'
python.version: '3.11'
maxParallel: 4
pool:
vmImage: $(imageName)
steps:
- task: UsePythonVersion@0
inputs:
versionSpec: '$(python.version)'
architecture: 'x64'
- script: |
pip install -U -r requirements.txt
python setup.py sdist
displayName: 'Build sdist'
- script: python -m mypy confection
displayName: 'Run mypy'
condition: ne(variables['python.version'], '3.6')
- task: DeleteFiles@1
inputs:
contents: 'confection'
displayName: 'Delete source directory'
- bash: |
SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
pip install dist/$SDIST
displayName: 'Install from sdist'
- script: python -m pytest --pyargs confection
displayName: 'Run tests'
- bash: |
pip install hypothesis
python -c "import confection; import hypothesis"
displayName: 'Test for conflicts'
confection-0.0.4/bin/ 0000775 0000000 0000000 00000000000 14357240256 0014422 5 ustar 00root root 0000000 0000000 confection-0.0.4/bin/push-tags.sh 0000664 0000000 0000000 00000000537 14357240256 0016676 0 ustar 00root root 0000000 0000000 #!/usr/bin/env bash
set -e

# Usage: bin/push-tags.sh <branch>
#
# Syncs <branch> with origin and tags the resulting HEAD as v<version>, where
# <version> is read from the "version = ..." line in setup.cfg.

# Abort with a usage message instead of running git without a branch name.
branch="${1:?usage: push-tags.sh <branch>}"

# Insist repository is clean
git diff-index --quiet HEAD

# Quote the branch so names with special characters are passed through intact.
git checkout "$branch"
git pull origin "$branch"
git push origin "$branch"

# Extract the bare version string: drop the "version = " prefix and any
# single or double quotes around the value.
version=$(grep "version = " setup.cfg)
version=${version/version = }
version=${version//\'/}
version=${version//\"/}

git tag "v$version"
git push origin "v$version" confection-0.0.4/confection/ 0000775 0000000 0000000 00000000000 14357240256 0016001 5 ustar 00root root 0000000 0000000 confection-0.0.4/confection/__init__.py 0000664 0000000 0000000 00000132253 14357240256 0020120 0 ustar 00root root 0000000 0000000 from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping
from typing import Iterable, Sequence, cast
from types import GeneratorType
from dataclasses import dataclass
from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH
from configparser import InterpolationMissingOptionError, InterpolationSyntaxError
from configparser import NoSectionError, NoOptionError, InterpolationDepthError
from configparser import ParsingError
from pathlib import Path
from pydantic import BaseModel, create_model, ValidationError, Extra
from pydantic.main import ModelMetaclass
from pydantic.fields import ModelField
import srsly
import catalogue
import inspect
import io
import copy
import re
from .util import Decorator
# Field used for positional arguments, e.g. [section.*.xyz]. The alias is
# required for the schema (shouldn't clash with user-defined arg names)
ARGS_FIELD = "*"
ARGS_FIELD_ALIAS = "VARIABLE_POSITIONAL_ARGS"
# Aliases for fields that would otherwise shadow pydantic attributes. Can be any
# string, so we're using name + space so it looks the same in error messages etc.
RESERVED_FIELDS = {"validate": "validate\u0020"}
# Internal prefix used to mark section references for custom interpolation
SECTION_PREFIX = "__SECTION__:"
# Values that shouldn't be loaded during interpolation because it'd cause
# even explicit string values to be incorrectly parsed as bools/None etc.
JSON_EXCEPTIONS = ("true", "false", "null")
# Regex to detect whether a value contains a variable
VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}")
class CustomInterpolation(ExtendedInterpolation):
    """ExtendedInterpolation variant that accepts both ${a:b} and ${a.b}
    references and turns references to whole sections into quoted
    SECTION_PREFIX marker strings instead of failing.
    """

    def before_read(self, parser, section, option, value):
        """Unquote JSON string values before delegating to the parent hook.

        The value is passed through srsly.json_loads. If that yields a string
        that is not one of JSON_EXCEPTIONS, the decoded string replaces the
        raw value. Any decoding error is deliberately ignored and the raw
        value is used unchanged.
        """
        # If we're dealing with a quoted string as the interpolation value,
        # make sure we load and unquote it so we don't end up with '"value"'
        try:
            json_value = srsly.json_loads(value)
            if isinstance(json_value, str) and json_value not in JSON_EXCEPTIONS:
                value = json_value
        except Exception:
            pass
        return super().before_read(parser, section, option, value)

    def before_get(self, parser, section, option, value, defaults):
        """Interpolate `value` and return the resulting string.

        Collects the pieces produced by self.interpolate (starting at depth 1)
        in a list and joins them.
        """
        # Mostly copy-pasted from the built-in configparser implementation.
        L = []
        self.interpolate(parser, option, L, value, section, defaults, 1)
        return "".join(L)

    def interpolate(self, parser, option, accum, rest, section, map, depth):
        """Expand the variables in `rest`, appending string pieces to `accum`.

        parser: The ConfigParser being read.
        option: Name of the option whose value is being expanded.
        accum: List that receives the expanded string fragments.
        rest: The not-yet-processed remainder of the value.
        section: Name of the section the option belongs to.
        map: Mapping of option names to raw values used for single-part
            references.
        depth: Current recursion depth; exceeding MAX_INTERPOLATION_DEPTH
            raises InterpolationDepthError.

        Raises InterpolationSyntaxError for malformed references and
        InterpolationMissingOptionError if a referenced option or section
        can't be found.
        """
        # Mostly copy-pasted from the built-in configparser implementation.
        # We need to overwrite this method so we can add special handling for
        # block references :( All values produced here should be strings –
        # we need to wait until the whole config is interpreted anyways so
        # filling in incomplete values here is pointless. All we need is the
        # section reference so we can fetch it later.
        rawval = parser.get(section, option, raw=True, fallback=rest)
        if depth > MAX_INTERPOLATION_DEPTH:
            raise InterpolationDepthError(option, section, rawval)
        while rest:
            p = rest.find("$")
            if p < 0:
                # No more variables: the remainder is literal text
                accum.append(rest)
                return
            if p > 0:
                # Emit the literal text in front of the next "$"
                accum.append(rest[:p])
                rest = rest[p:]
            # p is no longer used
            c = rest[1:2]
            if c == "$":
                # "$$" is an escaped, literal "$"
                accum.append("$")
                rest = rest[2:]
            elif c == "{":
                # We want to treat both ${a:b} and ${a.b} the same
                m = self._KEYCRE.match(rest)
                if m is None:
                    err = f"bad interpolation variable reference {rest}"
                    raise InterpolationSyntaxError(option, section, err)
                orig_var = m.group(1)
                # Split into at most [section path, option name]
                path = orig_var.replace(":", ".").rsplit(".", 1)
                rest = rest[m.end() :]
                sect = section
                opt = option
                try:
                    if len(path) == 1:
                        opt = parser.optionxform(path[0])
                        if opt in map:
                            v = map[opt]
                        else:
                            # We have block reference, store it as a special key
                            section_name = parser[parser.optionxform(path[0])]._name
                            v = self._get_section_name(section_name)
                    elif len(path) == 2:
                        sect = path[0]
                        opt = parser.optionxform(path[1])
                        fallback = "__FALLBACK__"
                        v = parser.get(sect, opt, raw=True, fallback=fallback)
                        # If a variable doesn't exist, try again and treat the
                        # reference as a section
                        if v == fallback:
                            v = self._get_section_name(parser[f"{sect}.{opt}"]._name)
                    else:
                        # NOTE(review): rsplit(".", 1) yields at most two
                        # items, so this branch looks unreachable as written.
                        err = f"More than one ':' found: {rest}"
                        raise InterpolationSyntaxError(option, section, err)
                except (KeyError, NoSectionError, NoOptionError):
                    raise InterpolationMissingOptionError(
                        option, section, rawval, orig_var
                    ) from None
                if "$" in v:
                    # The referenced value contains variables itself: expand
                    # it recursively in the context of its own section
                    new_map = dict(parser.items(sect, raw=True))
                    self.interpolate(parser, opt, accum, v, sect, new_map, depth + 1)
                else:
                    accum.append(v)
            else:
                err = "'$' must be followed by '$' or '{', " "found: %r" % (rest,)
                raise InterpolationSyntaxError(option, section, err)

    def _get_section_name(self, name: str) -> str:
        """Generate the name of a section. Note that we use a quoted string here
        so we can use section references within lists and load the list as
        JSON. Since section references can't be used within strings, we don't
        need the quoted vs. unquoted distinction like we do for variables.
        Examples (assuming section = {"foo": 1}):
        - value: ${section.foo} -> value: 1
        - value: "hello ${section.foo}" -> value: "hello 1"
        - value: ${section} -> value: {"foo": 1}
        - value: "${section}" -> value: {"foo": 1}
        - value: "hello ${section}" -> invalid
        """
        return f'"{SECTION_PREFIX}{name}"'
def get_configparser(interpolate: bool = True):
    """Create the ConfigParser used to read config files.

    interpolate: If True, variables are expanded with CustomInterpolation;
        otherwise interpolation is switched off.
    Returns the configured ConfigParser.
    """
    if interpolate:
        handler = CustomInterpolation()
    else:
        handler = None
    parser = ConfigParser(interpolation=handler)
    # Preserve case of keys: https://stackoverflow.com/a/1611877/6400719
    parser.optionxform = str  # type: ignore
    return parser
class Config(dict):
"""This class holds the model and training configuration and can load and
save the TOML-style configuration format from/to a string, file or bytes.
The Config class is a subclass of dict and uses Python's ConfigParser
under the hood.
"""
is_interpolated: bool
def __init__(
self,
data: Optional[Union[Dict[str, Any], "ConfigParser", "Config"]] = None,
*,
is_interpolated: Optional[bool] = None,
section_order: Optional[List[str]] = None,
) -> None:
"""Initialize a new Config object with optional data."""
dict.__init__(self)
if data is None:
data = {}
if not isinstance(data, (dict, Config, ConfigParser)):
raise ValueError(
f"Can't initialize Config with data. Expected dict, Config or "
f"ConfigParser but got: {type(data)}"
)
# Whether the config has been interpolated. We can use this to check
# whether we need to interpolate again when it's resolved. We assume
# that a config is interpolated by default.
if is_interpolated is not None:
self.is_interpolated = is_interpolated
elif isinstance(data, Config):
self.is_interpolated = data.is_interpolated
else:
self.is_interpolated = True
if section_order is not None:
self.section_order = section_order
elif isinstance(data, Config):
self.section_order = data.section_order
else:
self.section_order = []
# Update with data
self.update(self._sort(data))
def interpolate(self) -> "Config":
    """Return an interpolated copy of this config."""
    # Round-trip through the string form: the custom to_str logic has to
    # re-serialize the values before they can be interpolated again.
    # ConfigParser.read_dict would only call str() on all values, which
    # isn't enough.
    serialized = self.to_str()
    return Config().from_str(serialized)
def interpret_config(self, config: "ConfigParser") -> None:
"""Interpret a config, parse nested sections and parse the values
as JSON. Mostly used internally and modifies the config in place.
"""
self._validate_sections(config)
# Sort sections by depth, so that we can iterate breadth-first. This
# allows us to check that we're not expanding an undefined block.
get_depth = lambda item: len(item[0].split("."))
for section, values in sorted(config.items(), key=get_depth):
if section == "DEFAULT":
# Skip [DEFAULT] section so it doesn't cause validation error
continue
parts = section.split(".")
node = self
for part in parts[:-1]:
if part == "*":
node = node.setdefault(part, {})
elif part not in node:
err_title = f"Error parsing config section. Perhaps a section name is wrong?"
err = [{"loc": parts, "msg": f"Section '{part}' is not defined"}]
raise ConfigValidationError(
config=self, errors=err, title=err_title
)
else:
node = node[part]
if not isinstance(node, dict):
# Happens if both value *and* subsection were defined for a key
err = [{"loc": parts, "msg": "found conflicting values"}]
err_cfg = f"{self}\n{({part: dict(values)})}"
raise ConfigValidationError(config=err_cfg, errors=err)
# Set the default section
node = node.setdefault(parts[-1], {})
if not isinstance(node, dict):
# Happens if both value *and* subsection were defined for a key
err = [{"loc": parts, "msg": "found conflicting values"}]
err_cfg = f"{self}\n{({part: dict(values)})}"
raise ConfigValidationError(config=err_cfg, errors=err)
try:
keys_values = list(values.items())
except InterpolationMissingOptionError as e:
raise ConfigValidationError(desc=f"{e}") from None
for key, value in keys_values:
config_v = config.get(section, key)
node[key] = self._interpret_value(config_v)
self.replace_section_refs(self)
def replace_section_refs(
self, config: Union[Dict[str, Any], "Config"], parent: str = ""
) -> None:
"""Replace references to section blocks in the final config."""
for key, value in config.items():
key_parent = f"{parent}.{key}".strip(".")
if isinstance(value, dict):
self.replace_section_refs(value, parent=key_parent)
elif isinstance(value, list):
config[key] = [
self._get_section_ref(v, parent=[parent, key]) for v in value
]
else:
config[key] = self._get_section_ref(value, parent=[parent, key])
def _interpret_value(self, value: Any) -> Any:
    """Parse one raw config value.

    The value is loaded with try_load_json. Lists are interpreted element by
    element. A string result that contains a variable is replaced by the
    original, unparsed value.
    """
    parsed = try_load_json(value)
    if isinstance(parsed, list):
        return [self._interpret_value(item) for item in parsed]
    # If value is a string and it contains a variable, use original value
    # (not interpreted string, which could lead to double quotes:
    # ${x.y} -> "${x.y}" -> "'${x.y}'"). The string check makes sure lists
    # are not kept as strings.
    # NOTE: This currently can't handle uninterpolated values like [${x.y}]!
    if isinstance(parsed, str) and VARIABLE_RE.search(value):
        return value
    return parsed
def _get_section_ref(self, value: Any, *, parent: List[str] = []) -> Any:
"""Get a single section reference."""
if isinstance(value, str) and value.startswith(f'"{SECTION_PREFIX}'):
value = try_load_json(value)
if isinstance(value, str) and value.startswith(SECTION_PREFIX):
parts = value.replace(SECTION_PREFIX, "").split(".")
result = self
for item in parts:
try:
result = result[item]
except (KeyError, TypeError): # This should never happen
err_title = "Error parsing reference to config section"
err_msg = f"Section '{'.'.join(parts)}' is not defined"
err = [{"loc": parts, "msg": err_msg}]
raise ConfigValidationError(
config=self, errors=err, title=err_title
) from None
return result
elif isinstance(value, str) and SECTION_PREFIX in value:
# String value references a section (either a dict or return
# value of promise). We can't allow this, since variables are
# always interpolated *before* configs are resolved.
err_desc = (
"Can't reference whole sections or return values of function "
"blocks inside a string or list\n\nYou can change your variable to "
"reference a value instead. Keep in mind that it's not "
"possible to interpolate the return value of a registered "
"function, since variables are interpolated when the config "
"is loaded, and registered functions are resolved afterwards."
)
err = [{"loc": parent, "msg": "uses section variable in string or list"}]
raise ConfigValidationError(errors=err, desc=err_desc)
return value
def copy(self) -> "Config":
"""Deepcopy the config."""
try:
config = copy.deepcopy(self)
except Exception as e:
raise ValueError(f"Couldn't deep-copy config: {e}") from e
return Config(
config,
is_interpolated=self.is_interpolated,
section_order=self.section_order,
)
def merge(
    self, updates: Union[Dict[str, Any], "Config"], remove_extra: bool = False
) -> "Config":
    """Deep-merge `updates` into a copy of this config.

    updates: The values to merge in; the current config provides the defaults.
    remove_extra: Passed through to deep_merge_configs.
    Returns a new Config that uses this config's section order and is only
    marked as interpolated if both inputs are.
    """
    base = self.copy()
    incoming = Config(updates).copy()
    combined = deep_merge_configs(incoming, base, remove_extra=remove_extra)
    both_interpolated = base.is_interpolated and incoming.is_interpolated
    return Config(
        combined,
        is_interpolated=both_interpolated,
        section_order=base.section_order,
    )
def _sort(
    self, data: Union["Config", "ConfigParser", Dict[str, Any]]
) -> Dict[str, Any]:
    """Order sections according to `self.section_order`.

    Sections are ranked by the position of their top-level name in the
    section order (unknown names go last), then by their dotted name with
    positional-argument names masked, so subsections follow their parent.
    """
    order_index = {name: pos for pos, name in enumerate(self.section_order)}
    unknown_rank = len(order_index)

    def section_rank(entry):
        # entry is a (section name, section content) pair
        section_name = entry[0]
        top_level = section_name.split(".")[0]
        return (
            order_index.get(top_level, unknown_rank),
            _mask_positional_args(section_name),
        )

    return dict(sorted(data.items(), key=section_rank))
def _set_overrides(self, config: "ConfigParser", overrides: Dict[str, Any]) -> None:
    """Write override values into the ConfigParser before it is interpreted.

    Every override key must be of the form "section.option", where the
    section already exists in the parser. The option itself may be new.

    RAISES (ConfigValidationError): If a key has no dot or names a section
        that is not in the parser.
    """
    title = "Error parsing config overrides"
    for dotted_key, new_value in overrides.items():
        problem = [
            {
                "loc": dotted_key.split("."),
                "msg": "not a section value that can be overwritten",
            }
        ]
        # Split on the last dot: everything before it is the section name.
        section, dot, option = dotted_key.rpartition(".")
        if not dot:
            raise ConfigValidationError(errors=problem, title=title)
        if section not in config:
            raise ConfigValidationError(errors=problem, title=title)
        config.set(section, option, try_dump_json(new_value, overrides))
def _validate_sections(self, config: "ConfigParser") -> None:
    """Reject configs that define values outside of any section.

    Top-level properties that are not sections (e.g. when the config was
    built from a dict) would end up in the parser's defaults and thereby be
    included in every other section, which leads to confusing results.

    RAISES (ConfigValidationError): If the parser's defaults are not empty.
    """
    stray_values = config.defaults()
    if not stray_values:
        return
    problems = [
        {"loc": [name], "msg": "not part of a section"} for name in stray_values
    ]
    raise ConfigValidationError(
        errors=problems, title="Found config values without a top-level section"
    )
def from_str(
    self, text: str, *, interpolate: bool = True, overrides: Dict[str, Any] = {}
) -> "Config":
    """Load the config from a string.

    text (str): The config text to parse.
    interpolate (bool): Whether variables should be interpolated.
    overrides (Dict[str, Any]): Values to set, keyed by "section.option".
    RETURNS (Config): The loaded config.
    RAISES (ConfigValidationError): If the text can't be parsed.
    """
    parser = get_configparser(interpolate=interpolate)
    if overrides:
        # With overrides, parse without interpolation first; interpolation
        # is applied further down once the overrides are in place.
        parser = get_configparser(interpolate=False)
    try:
        parser.read_string(text)
    except ParsingError as e:
        desc = f"Make sure the sections and values are formatted correctly.\n\n{e}"
        raise ConfigValidationError(desc=desc) from None
    parser._sections = self._sort(parser._sections)
    self._set_overrides(parser, overrides)
    self.clear()
    self.interpret_config(parser)
    loaded = self
    if overrides and interpolate:
        # Interpolate now. This can't recurse endlessly because the nested
        # from_str call is made with empty overrides.
        loaded = self.interpolate()
    loaded.is_interpolated = interpolate
    return loaded
def to_str(self, *, interpolate: bool = True) -> str:
    """Write the config to a string.

    The nested config is flattened into dotted section names, sorted with
    `_sort`, validated with `_validate_sections` and written out through a
    ConfigParser.
    """
    parser = get_configparser(interpolate=interpolate)
    # Work list of (path, node) pairs. It grows while it is processed, so
    # walk it by index rather than with a plain iterator.
    pending: List[Tuple[tuple, "Config"]] = [(tuple(), self)]
    cursor = 0
    while cursor < len(pending):
        path, node = pending[cursor]
        cursor += 1
        section_name = ".".join(path)
        is_kwarg = path and path[-1] != "*"
        if is_kwarg and not parser.has_section(section_name):
            # Always create sections for non-'*' sections, not only if they
            # have leaf entries, as we don't want to expand blocks that are
            # undefined
            parser.add_section(section_name)
        for key, value in node.items():
            if not hasattr(value, "items"):
                # Plain leaf value
                parser.set(section_name, key, try_dump_json(value, node))
                continue
            # Reference to a function with no arguments: serialize inline
            # as a dict and don't create a new section
            inline_promise = (
                registry.is_promise(value) and len(value) == 1 and is_kwarg
            )
            if inline_promise:
                parser.set(section_name, key, try_dump_json(value, node))
            else:
                pending.append((path + (key,), value))
    # Order so subsections follow their parent (not all sections first and
    # then all subsections)
    parser._sections = self._sort(parser._sections)
    self._validate_sections(parser)
    buffer = io.StringIO()
    parser.write(buffer)
    return buffer.getvalue().strip()
def to_bytes(self, *, interpolate: bool = True) -> bytes:
    """Serialize the config to a UTF-8 encoded byte string.

    interpolate (bool): Forwarded to `to_str`.
    RETURNS (bytes): The encoded output of `to_str`.
    """
    text = self.to_str(interpolate=interpolate)
    return text.encode("utf8")
def from_bytes(
    self,
    bytes_data: bytes,
    *,
    interpolate: bool = True,
    overrides: Dict[str, Any] = {},
) -> "Config":
    """Load the config from a UTF-8 encoded byte string.

    bytes_data (bytes): The data to decode and parse.
    interpolate (bool): Forwarded to `from_str`.
    overrides (Dict[str, Any]): Forwarded to `from_str`.
    RETURNS (Config): Whatever `from_str` returns.
    """
    text = bytes_data.decode("utf8")
    return self.from_str(text, interpolate=interpolate, overrides=overrides)
def to_disk(self, path: Union[str, Path], *, interpolate: bool = True):
    """Serialize the config to a file, encoded as UTF-8.

    path (Union[str, Path]): Target location. Strings are converted to
        Path; other objects are used as they are.
    interpolate (bool): Forwarded to `to_str`.
    """
    target = path if not isinstance(path, str) else Path(path)
    with target.open("w", encoding="utf8") as handle:
        handle.write(self.to_str(interpolate=interpolate))
def from_disk(
    self,
    path: Union[str, Path],
    *,
    interpolate: bool = True,
    overrides: Dict[str, Any] = {},
) -> "Config":
    """Load the config from a UTF-8 encoded file.

    path (Union[str, Path]): Source location. Strings are converted to
        Path; other objects are used as they are.
    interpolate (bool): Forwarded to `from_str`.
    overrides (Dict[str, Any]): Forwarded to `from_str`.
    RETURNS (Config): Whatever `from_str` returns.
    """
    source = path if not isinstance(path, str) else Path(path)
    with source.open("r", encoding="utf8") as handle:
        contents = handle.read()
    return self.from_str(contents, interpolate=interpolate, overrides=overrides)
def _mask_positional_args(name: str) -> List[Optional[str]]:
"""Create a section name representation that masks names
of positional arguments to retain their order in sorts."""
stable_name = cast(List[Optional[str]], name.split("."))
# Remove names of sections that are a positional argument.
for i in range(1, len(stable_name)):
if stable_name[i - 1] == "*":
stable_name[i] = None
return stable_name
def try_load_json(value: str) -> Any:
    """Parse `value` as JSON, falling back to the unchanged input.

    Any exception raised while parsing is swallowed on purpose: a value
    that can't be parsed is returned as it was passed in.
    """
    try:
        parsed = srsly.json_loads(value)
    except Exception:
        return value
    return parsed
def try_dump_json(value: Any, data: Union[Dict[str, dict], Config, str] = "") -> str:
    """Dump a config value as JSON, with a user-friendly error on failure.

    value (Any): The value to serialize.
    data (Union[Dict[str, dict], Config, str]): The config the value belongs
        to. Only used as context for the error.
    RETURNS (str): The serialized value.
    RAISES (ConfigValidationError): If the value can't be serialized.
    """
    if isinstance(value, str):
        # Special case if we have a variable: it's already a string, so
        # don't dump it, to preserve ${x:y} vs. "${x:y}"
        if VARIABLE_RE.search(value):
            return value
        # Work around values that are strings but look like numbers
        if value.replace(".", "", 1).isdigit():
            value = f'"{value}"'
    try:
        return srsly.json_dumps(value)
    except Exception as e:
        err_msg = (
            f"Couldn't serialize config value of type {type(value)}: {e}. Make "
            f"sure all values in your config are JSON-serializable. If you want "
            f"to include Python objects, use a registered function that returns "
            f"the object instead."
        )
        raise ConfigValidationError(config=data, desc=err_msg) from e
def deep_merge_configs(
    config: Union[Dict[str, Any], "Config"],
    defaults: Union[Dict[str, Any], "Config"],
    *,
    remove_extra: bool = False,
) -> Union[Dict[str, Any], "Config"]:
    """Deep merge two configs.

    `config` is updated in place and returned. Values already present in
    `config` win; `defaults` only fills in what is missing. `defaults` itself
    is not modified.

    config (Union[Dict[str, Any], Config]): The config to update.
    defaults (Union[Dict[str, Any], Config]): The fallback values.
    remove_extra (bool): Delete keys from `config` that are not in
        `defaults`. Applied recursively to nested dicts that are merged.
    RETURNS (Union[Dict[str, Any], Config]): The updated `config`.
    """
    if remove_extra:
        # Filter out values in the original config that are not in defaults.
        # Iterate over a snapshot of the keys because we delete while looping.
        for key in list(config.keys()):
            if key not in defaults:
                del config[key]
    for key, value in defaults.items():
        if isinstance(value, dict):
            node = config.setdefault(key, {})
            if not isinstance(node, dict):
                # The config overrides a default block with a plain value
                continue
            value_promises = [k for k in value if k.startswith("@")]
            value_promise = value_promises[0] if value_promises else None
            node_promises = [k for k in node if k.startswith("@")] if node else []
            node_promise = node_promises[0] if node_promises else None
            # We only update the block from defaults if it refers to the same
            # registered function
            if (
                value_promise
                and node_promise
                and (
                    value_promise in node
                    and node[value_promise] != value[value_promise]
                )
            ):
                continue
            if node_promise and (
                node_promise not in value or node[node_promise] != value[node_promise]
            ):
                continue
            # The recursive call updates `node` in place, and `node` is the
            # object stored in config[key], so the return value isn't needed.
            deep_merge_configs(node, value, remove_extra=remove_extra)
        elif key not in config:
            config[key] = value
    return config
class ConfigValidationError(ValueError):
    def __init__(
        self,
        *,
        config: Optional[Union["Config", Dict[str, Dict[str, Any]], str]] = None,
        errors: Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]] = tuple(),
        title: Optional[str] = "Config validation error",
        desc: Optional[str] = None,
        parent: Optional[str] = None,
        show_config: bool = True,
    ) -> None:
        """Custom error for validating configs.

        config (Union[Config, Dict[str, Dict[str, Any]], str]): The
            config the validation error refers to.
        errors (Union[Sequence[Mapping[str, Any]], Iterable[Dict[str, Any]]]):
            A list of errors as dicts with keys "loc" (list of strings
            describing the path of the value), "msg" (validation message
            to show) and optional "type" (mostly internals).
            Same format as produced by pydantic's validation error (e.errors()).
            One-shot iterables such as generators are accepted and are
            stored as a list.
        title (str): The error title.
        desc (str): Optional error description, displayed below the title.
        parent (str): Optional parent to use as prefix for all error locations.
            For example, parent "element" will result in "element -> a -> b".
        show_config (bool): Whether to print the whole config with the error.

        ATTRIBUTES:
        config (Union[Config, Dict[str, Dict[str, Any]], str]): The config.
        errors (Iterable[Dict[str, Any]]): The errors.
        error_types (Set[str]): All "type" values defined in the errors, if
            available. This is most relevant for the pydantic errors that define
            types like "type_error.integer". This attribute makes it easy to
            check if a config validation error includes errors of a certain
            type, e.g. to log additional information or custom help messages.
        title (str): The title.
        desc (str): The description.
        parent (str): The parent.
        show_config (bool): Whether to show the config.
        text (str): The formatted error text.
        """
        self.config = config
        # The errors are iterated more than once (here and in _format), so a
        # one-shot iterable has to be materialized first. Lists and tuples are
        # kept as they are.
        if not isinstance(errors, (list, tuple)):
            errors = list(errors)
        self.errors = errors
        self.title = title
        self.desc = desc
        self.parent = parent
        self.show_config = show_config
        self.error_types = set()
        for error in self.errors:
            err_type = error.get("type")
            if err_type:
                self.error_types.add(err_type)
        self.text = self._format()
        ValueError.__init__(self, self.text)

    @classmethod
    def from_error(
        cls,
        err: "ConfigValidationError",
        title: Optional[str] = None,
        desc: Optional[str] = None,
        parent: Optional[str] = None,
        show_config: Optional[bool] = None,
    ) -> "ConfigValidationError":
        """Create a new ConfigValidationError based on an existing error, e.g.
        to re-raise it with different settings. If no overrides are provided,
        the values from the original error are used.

        err (ConfigValidationError): The original error.
        title (str): Overwrite error title.
        desc (str): Overwrite error description.
        parent (str): Overwrite error parent.
        show_config (bool): Overwrite whether to show config.
        RETURNS (ConfigValidationError): The new error.
        """
        return cls(
            config=err.config,
            errors=err.errors,
            title=title if title is not None else err.title,
            desc=desc if desc is not None else err.desc,
            parent=parent if parent is not None else err.parent,
            show_config=show_config if show_config is not None else err.show_config,
        )

    def _format(self) -> str:
        """Format the error message.

        The message consists of the title, the description, one
        "location<TAB>message" line per error and the config, each only if
        present, joined by newlines and prefixed with two newlines.
        """
        loc_divider = "->"
        data = []
        for error in self.errors:
            err_loc = f" {loc_divider} ".join([str(p) for p in error.get("loc", [])])
            if self.parent:
                err_loc = f"{self.parent} {loc_divider} {err_loc}"
            data.append((err_loc, error.get("msg")))
        result = []
        if self.title:
            result.append(self.title)
        if self.desc:
            result.append(self.desc)
        if data:
            result.append("\n".join([f"{entry[0]}\t{entry[1]}" for entry in data]))
        if self.config and self.show_config:
            result.append(f"{self.config}")
        return "\n\n" + "\n".join(result)
def alias_generator(name: str) -> str:
    """Generate the alias of a field in a promise schema.

    name (str): The field name.
    RETURNS (str): The alias, or the unchanged name if no alias applies.
    """
    # Underscore fields are not allowed in a model, so the positional
    # arguments field is declared under a different name and aliased back
    if name == ARGS_FIELD_ALIAS:
        return ARGS_FIELD
    # Fields that shadow base model attributes are aliased automatically;
    # every other name is used as it is
    return RESERVED_FIELDS.get(name, name)
def copy_model_field(field: ModelField, type_: Any) -> ModelField:
    """Create a copy of a model field with a different type.

    This makes it possible to e.g. accept an Any type even though the
    original field is typed differently. Name, class validators, model
    config, default, default factory and the required flag are carried over.
    """
    carried_over = (
        "class_validators",
        "model_config",
        "default",
        "default_factory",
        "required",
    )
    settings = {attr: getattr(field, attr) for attr in carried_over}
    return ModelField(name=field.name, type_=type_, **settings)
# Schema without any declared fields that accepts arbitrary extra values.
# It is the default schema of registry.resolve / registry.fill and the
# fallback used by registry._fill when a section has no pydantic schema.
class EmptySchema(BaseModel):
    class Config:
        # Keep values that are not declared as fields
        extra = "allow"
        arbitrary_types_allowed = True
# Model config for the schemas that registry.make_promise_schema creates from
# function signatures (it is passed to create_model as "__config__").
class _PromiseSchemaConfig:
    # Arguments that are not in the function signature are an error
    extra = "forbid"
    arbitrary_types_allowed = True
    # Maps field names to aliases, see alias_generator
    alias_generator = alias_generator
@dataclass
class Promise:
    """Record of a registry function reference that has not been called.

    registry._fill creates it in place of the function's return value when
    the config is not being resolved.
    """

    # Name of the registry the function is looked up in
    registry: str
    # Name of the function within that registry
    name: str
    # Positional arguments for the function
    args: List[str]
    # Keyword arguments for the function
    kwargs: Dict[str, Any]
class registry:
    """Namespace of classmethods for looking up registered functions and for
    filling, validating and resolving configs that reference them.

    Individual registries are looked up as attributes of this class by name
    (see `has` and `get`).
    """

    @classmethod
    def has(cls, registry_name: str, func_name: str) -> bool:
        """Check whether a function is available in a registry."""
        if not hasattr(cls, registry_name):
            return False
        reg = getattr(cls, registry_name)
        return func_name in reg

    @classmethod
    def get(cls, registry_name: str, func_name: str) -> Callable:
        """Get a registered function from a given registry.

        RAISES (ValueError): If the registry doesn't exist or the lookup of
            the function name in it returns None.
        """
        if not hasattr(cls, registry_name):
            raise ValueError(f"Unknown registry: '{registry_name}'")
        reg = getattr(cls, registry_name)
        func = reg.get(func_name)
        if func is None:
            raise ValueError(f"Could not find '{func_name}' in '{registry_name}'")
        return func

    @classmethod
    def resolve(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        overrides: Dict[str, Any] = {},
        validate: bool = True,
    ) -> Dict[str, Any]:
        """Return the resolved version of the config, i.e. the first value
        produced by `_make` when called with resolve=True.
        """
        resolved, _ = cls._make(
            config, schema=schema, overrides=overrides, validate=validate, resolve=True
        )
        return resolved

    @classmethod
    def fill(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        overrides: Dict[str, Any] = {},
        validate: bool = True,
    ):
        """Return the filled version of the config, i.e. the second value
        produced by `_make` when called with resolve=False.
        """
        _, filled = cls._make(
            config, schema=schema, overrides=overrides, validate=validate, resolve=False
        )
        return filled

    @classmethod
    def _make(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        *,
        schema: Type[BaseModel] = EmptySchema,
        overrides: Dict[str, Any] = {},
        resolve: bool = True,
        validate: bool = True,
    ) -> Tuple[Dict[str, Any], Config]:
        """Unpack a config dictionary and create two versions of the config:
        a resolved version with objects from the registry created recursively,
        and a filled version with all references to registry functions left
        intact, but filled with all values and defaults based on the type
        annotations. If validate=True, the config will be validated against the
        type annotations of the registered functions referenced in the config
        (if available) and/or the schema (if available).
        """
        # Valid: {"optimizer": {"@optimizers": "my_cool_optimizer", "rate": 1.0}}
        # Invalid: {"@optimizers": "my_cool_optimizer", "rate": 1.0}
        if cls.is_promise(config):
            err_msg = "The top-level config object can't be a reference to a registered function."
            raise ConfigValidationError(config=config, errors=[{"msg": err_msg}])
        # If a Config was loaded with interpolate=False, we assume it needs to
        # be interpolated first, otherwise we take it at face value
        is_interpolated = not isinstance(config, Config) or config.is_interpolated
        section_order = config.section_order if isinstance(config, Config) else None
        orig_config = config
        if not is_interpolated:
            config = Config(orig_config).interpolate()
        filled, _, resolved = cls._fill(
            config, schema, validate=validate, overrides=overrides, resolve=resolve
        )
        filled = Config(filled, section_order=section_order)
        # Check that overrides didn't include invalid properties not in config
        if validate:
            cls._validate_overrides(filled, overrides)
        # Merge the original config back to preserve variables if we started
        # with a config that wasn't interpolated. Here, we prefer variables to
        # allow auto-filling a non-interpolated config without destroying
        # variable references.
        if not is_interpolated:
            filled = filled.merge(
                Config(orig_config, is_interpolated=False), remove_extra=True
            )
        return dict(resolved), filled

    @classmethod
    def _fill(
        cls,
        config: Union[Config, Dict[str, Dict[str, Any]]],
        schema: Type[BaseModel] = EmptySchema,
        *,
        validate: bool = True,
        resolve: bool = True,
        parent: str = "",
        overrides: Dict[str, Dict[str, Any]] = {},
    ) -> Tuple[
        Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any]
    ]:
        """Build three representations of the config:
        1. All promises are preserved (just like config user would provide).
        2. Promises are replaced by their return values. This is the validation
            copy and will be parsed by pydantic. It lets us include hacks to
            work around problems (e.g. handling of generators).
        3. Final copy with promises replaced by their return values.

        Note that overrides are also written into the passed-in `config`.
        """
        filled: Dict[str, Any] = {}
        validation: Dict[str, Any] = {}
        final: Dict[str, Any] = {}
        for key, value in config.items():
            # If the field name is reserved, we use its alias for validation
            v_key = RESERVED_FIELDS.get(key, key)
            key_parent = f"{parent}.{key}".strip(".")
            if key_parent in overrides:
                value = overrides[key_parent]
                config[key] = value
            if cls.is_promise(value):
                if key in schema.__fields__ and not resolve:
                    # If we're not resolving the config, make sure that the field
                    # expecting the promise is typed Any so it doesn't fail
                    # validation if it doesn't receive the function return value
                    field = schema.__fields__[key]
                    schema.__fields__[key] = copy_model_field(field, Any)
                promise_schema = cls.make_promise_schema(value, resolve=resolve)
                filled[key], validation[v_key], final[key] = cls._fill(
                    value,
                    promise_schema,
                    validate=validate,
                    resolve=resolve,
                    parent=key_parent,
                    overrides=overrides,
                )
                reg_name, func_name = cls.get_constructor(final[key])
                args, kwargs = cls.parse_args(final[key])
                if resolve:
                    # Call the function and populate the field value. We can't
                    # just create an instance of the type here, since this
                    # wouldn't work for generics / more complex custom types
                    getter = cls.get(reg_name, func_name)
                    # We don't want to try/except this and raise our own error
                    # here, because we want the traceback if the function fails.
                    getter_result = getter(*args, **kwargs)
                else:
                    # We're not resolving and calling the function, so replace
                    # the getter_result with a Promise class
                    getter_result = Promise(
                        registry=reg_name, name=func_name, args=args, kwargs=kwargs
                    )
                validation[v_key] = getter_result
                final[key] = getter_result
                if isinstance(validation[v_key], GeneratorType):
                    # If value is a generator we can't validate type without
                    # consuming it (which doesn't work if it's infinite – see
                    # schedule for examples). So we skip it.
                    validation[v_key] = []
            elif hasattr(value, "items"):
                field_type = EmptySchema
                if key in schema.__fields__:
                    field = schema.__fields__[key]
                    field_type = field.type_
                    if not isinstance(field.type_, ModelMetaclass):
                        # If we don't have a pydantic schema and just a type
                        field_type = EmptySchema
                filled[key], validation[v_key], final[key] = cls._fill(
                    value,
                    field_type,
                    validate=validate,
                    resolve=resolve,
                    parent=key_parent,
                    overrides=overrides,
                )
                if key == ARGS_FIELD and isinstance(validation[v_key], dict):
                    # If the value of variable positional args is a dict (e.g.
                    # created via config blocks), only use its values
                    validation[v_key] = list(validation[v_key].values())
                    final[key] = list(final[key].values())
            else:
                filled[key] = value
                # Prevent pydantic from consuming generator if part of a union
                validation[v_key] = (
                    value if not isinstance(value, GeneratorType) else []
                )
                final[key] = value
        # Now that we've filled in all of the promises, update with defaults
        # from schema, and validate if validation is enabled
        exclude = []
        if validate:
            try:
                result = schema.parse_obj(validation)
            except ValidationError as e:
                raise ConfigValidationError(
                    config=config, errors=e.errors(), parent=parent
                ) from None
        else:
            # Same as parse_obj, but without validation
            result = schema.construct(**validation)
            # If our schema doesn't allow extra values, we need to filter them
            # manually because .construct doesn't parse anything
            if schema.Config.extra in (Extra.forbid, Extra.ignore):
                fields = schema.__fields__.keys()
                exclude = [k for k in result.__fields_set__ if k not in fields]
        exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()])
        validation.update(result.dict(exclude=exclude_validation))
        filled, final = cls._update_from_parsed(validation, filled, final)
        if exclude:
            filled = {k: v for k, v in filled.items() if k not in exclude}
            validation = {k: v for k, v in validation.items() if k not in exclude}
            final = {k: v for k, v in final.items() if k not in exclude}
        return filled, validation, final

    @classmethod
    def _update_from_parsed(
        cls, validation: Dict[str, Any], filled: Dict[str, Any], final: Dict[str, Any]
    ):
        """Update the final result with the parsed config like converted
        values recursively.

        RETURNS (Tuple[Dict[str, Any], Dict[str, Any]]): The updated `filled`
            and `final` dicts (both are also modified in place).
        """
        for key, value in validation.items():
            if key in RESERVED_FIELDS.values():
                continue  # skip aliases for reserved fields
            if key not in filled:
                filled[key] = value
            if key not in final:
                final[key] = value
            if isinstance(value, dict):
                filled[key], final[key] = cls._update_from_parsed(
                    value, filled[key], final[key]
                )
            # Update final config with parsed value if they're not equal (in
            # value and in type) but not if it's a generator because we had to
            # replace that to validate it correctly
            elif key == ARGS_FIELD:
                continue  # don't substitute if list of positional args
            # Check numpy first, just in case: comparing arrays with != below
            # doesn't produce a single boolean. Use the stringified type so
            # that no numpy import is needed.
            elif str(type(value)) == "<class 'numpy.ndarray'>":
                final[key] = value
            elif (
                value != final[key] or not isinstance(type(value), type(final[key]))
            ) and not isinstance(final[key], GeneratorType):
                final[key] = value
        return filled, final

    @classmethod
    def _validate_overrides(cls, filled: Config, overrides: Dict[str, Any]):
        """Validate overrides against a filled config to make sure there are
        no references to properties that don't exist and weren't used.

        RAISES (ConfigValidationError): If at least one override key is not
            in the filled config. All offending keys are reported together.
        """
        error_msg = "Invalid override: config value doesn't exist"
        errors = []
        for override_key in overrides.keys():
            if not cls._is_in_config(override_key, filled):
                errors.append({"msg": error_msg, "loc": [override_key]})
        if errors:
            raise ConfigValidationError(config=filled, errors=errors)

    @classmethod
    def _is_in_config(cls, prop: str, config: Union[Dict[str, Any], Config]):
        """Check whether a nested config property like "section.subsection.key"
        is in a given config."""
        tree = prop.split(".")
        obj = dict(config)
        # Descend one path component at a time
        while tree:
            key = tree.pop(0)
            if isinstance(obj, dict) and key in obj:
                obj = obj[key]
            else:
                return False
        return True

    @classmethod
    def is_promise(cls, obj: Any) -> bool:
        """Check whether an object is a "promise", i.e. contains a reference
        to a registered function (via a key starting with `"@"`).
        """
        if not hasattr(obj, "keys"):
            return False
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        if len(id_keys):
            return True
        return False

    @classmethod
    def get_constructor(cls, obj: Dict[str, Any]) -> Tuple[str, str]:
        """Get the registry name and function name a promise dict refers to.

        The registry name is the "@" key without the "@", the function name
        is the value stored under that key.

        RAISES (ConfigValidationError): If the dict doesn't contain exactly
            one key starting with "@".
        """
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        if len(id_keys) != 1:
            err_msg = f"A block can only contain one function registry reference. Got: {id_keys}"
            raise ConfigValidationError(config=obj, errors=[{"msg": err_msg}])
        else:
            key = id_keys[0]
            value = obj[key]
            return (key[1:], value)

    @classmethod
    def parse_args(cls, obj: Dict[str, Any]) -> Tuple[List[Any], Dict[str, Any]]:
        """Split a promise dict into positional and keyword arguments.

        Keys starting with "@" and keys that are aliases of reserved fields
        are skipped. The value of the positional args field becomes the
        positional arguments, everything else is a keyword argument.
        """
        args = []
        kwargs = {}
        for key, value in obj.items():
            if not key.startswith("@"):
                if key == ARGS_FIELD:
                    args = value
                elif key in RESERVED_FIELDS.values():
                    continue
                else:
                    kwargs[key] = value
        return args, kwargs

    @classmethod
    def make_promise_schema(
        cls, obj: Dict[str, Any], *, resolve: bool = True
    ) -> Type[BaseModel]:
        """Create a schema for a promise dict (referencing a registry function)
        by inspecting the function signature.

        If the config is not being resolved and the function is not
        registered, EmptySchema is returned instead.
        """
        reg_name, func_name = cls.get_constructor(obj)
        if not resolve and not cls.has(reg_name, func_name):
            return EmptySchema
        func = cls.get(reg_name, func_name)
        # Read the argument annotations and defaults from the function signature
        id_keys = [k for k in obj.keys() if k.startswith("@")]
        sig_args: Dict[str, Any] = {id_keys[0]: (str, ...)}
        for param in inspect.signature(func).parameters.values():
            # If no annotation is specified assume it's anything
            annotation = param.annotation if param.annotation != param.empty else Any
            # If no default value is specified assume that it's required
            default = param.default if param.default != param.empty else ...
            # Handle spread arguments and use their annotation as Sequence[whatever]
            if param.kind == param.VAR_POSITIONAL:
                spread_annot = Sequence[annotation]  # type: ignore
                sig_args[ARGS_FIELD_ALIAS] = (spread_annot, default)
            else:
                name = RESERVED_FIELDS.get(param.name, param.name)
                sig_args[name] = (annotation, default)
        sig_args["__config__"] = _PromiseSchemaConfig
        return create_model("ArgModel", **sig_args)
# Names exported by a star-import of this module
__all__ = ["Config", "registry", "ConfigValidationError"]
confection-0.0.4/confection/py.typed 0000664 0000000 0000000 00000000000 14357240256 0017466 0 ustar 00root root 0000000 0000000 confection-0.0.4/confection/tests/ 0000775 0000000 0000000 00000000000 14357240256 0017143 5 ustar 00root root 0000000 0000000 confection-0.0.4/confection/tests/__init__.py 0000664 0000000 0000000 00000000000 14357240256 0021242 0 ustar 00root root 0000000 0000000 confection-0.0.4/confection/tests/conftest.py 0000664 0000000 0000000 00000000750 14357240256 0021344 0 ustar 00root root 0000000 0000000 import pytest
def pytest_addoption(parser):
    """Register the ``--slow`` command line flag on the given parser."""
    parser.addoption(
        "--slow",
        help="include slow tests",
        action="store_true",
    )
@pytest.fixture()
def pathy_fixture():
    """Yield a ``gs://test-bucket`` Pathy root backed by a temporary folder.

    The test is skipped if pathy can't be imported. The temporary folder is
    removed and the pathy file system override is switched off afterwards,
    also when the setup itself fails part-way.
    """
    pytest.importorskip("pathy")
    import tempfile
    import shutil
    from pathy import use_fs, Pathy

    temp_folder = tempfile.mkdtemp(prefix="thinc-pathy")
    try:
        use_fs(temp_folder)
        root = Pathy("gs://test-bucket")
        root.mkdir(exist_ok=True)
        yield root
    finally:
        # pytest does not run the code after a fixture's `yield` when the
        # fixture raises before yielding, so clean up in a finally block.
        use_fs(False)
        shutil.rmtree(temp_folder)
confection-0.0.4/confection/tests/test_config.py 0000664 0000000 0000000 00000146336 14357240256 0022036 0 ustar 00root root 0000000 0000000 import inspect
import platform
import catalogue
import pytest
from typing import Dict, Optional, Iterable, Callable, Any, Union, List, Tuple
from types import GeneratorType
import pickle
from pydantic import BaseModel, StrictFloat, PositiveInt, constr
from pydantic.types import StrictBool
from confection import ConfigValidationError, Config
from confection.util import Generator, partial
from confection.tests.util import Cat, my_registry, make_tempdir
# Example config with nested sections, references to registered functions
# ("@" keys) and one variable that points at a value in another section.
EXAMPLE_CONFIG = """
[optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
use_averages = true
[optimizer.learn_rate]
@schedules = "warmup_linear.v1"
initial_rate = 0.1
warmup_steps = 10000
total_steps = 100000
[pipeline]
[pipeline.classifier]
name = "classifier"
factory = "classifier"
[pipeline.classifier.model]
@layers = "ClassifierModel.v1"
hidden_depth = 1
hidden_width = 64
token_vector_width = 128
[pipeline.classifier.model.embedding]
@layers = "Embedding.v1"
width = ${pipeline.classifier.model:token_vector_width}
"""
# Optimizer config used by the serialization round-trip tests. Those tests
# compare this text with what the config parser writes out, and the parser
# ends every section with a blank line, so the sections here have to be
# separated by a blank line as well.
OPTIMIZER_CFG = """
[optimizer]
@optimizers = "Adam.v1"
beta1 = 0.9
beta2 = 0.999
use_averages = true

[optimizer.learn_rate]
@schedules = "warmup_linear.v1"
initial_rate = 0.1
warmup_steps = 10000
total_steps = 100000
"""
# Test schema with two required integer fields
class HelloIntsSchema(BaseModel):
    hello: int
    world: int

    class Config:
        # Fields that are not declared above are a validation error
        extra = "forbid"
# Test schema with one required field and one field that has a default
class DefaultsSchema(BaseModel):
    required: int
    optional: str = "default value"

    class Config:
        # Fields that are not declared above are a validation error
        extra = "forbid"
# Test schema that nests the two schemas above, one as a required field and
# one as a field with a default instance
class ComplexSchema(BaseModel):
    outer_req: int
    outer_opt: str = "default value"
    level2_req: HelloIntsSchema
    level2_opt: DefaultsSchema = DefaultsSchema(required=1)
# Promise dicts that reference "catsie.v1" in the "cats" registry, one for
# each combination of the "evil" and "cute" flags
good_catsie = {"@cats": "catsie.v1", "evil": False, "cute": True}
ok_catsie = {"@cats": "catsie.v1", "evil": False, "cute": False}
bad_catsie = {"@cats": "catsie.v1", "evil": True, "cute": True}
worst_catsie = {"@cats": "catsie.v1", "evil": True, "cute": False}
def test_validate_simple_config():
    """A config that matches HelloIntsSchema comes back from _fill unchanged."""
    payload = {"hello": 1, "world": 2}
    filled, _, final = my_registry._fill(payload, HelloIntsSchema)
    assert filled == payload
    assert final == payload
def test_invalidate_simple_config():
    """A string in an int field produces exactly one integer type error."""
    bad_payload = {"hello": 1, "world": "hi!"}
    with pytest.raises(ConfigValidationError) as caught:
        my_registry._fill(bad_payload, HelloIntsSchema)
    failure = caught.value
    assert len(failure.errors) == 1
    assert "type_error.integer" in failure.error_types
def test_invalidate_extra_args():
    """A key that HelloIntsSchema doesn't declare fails validation."""
    with_extra = {"hello": 1, "world": 2, "extra": 3}
    with pytest.raises(ConfigValidationError):
        my_registry._fill(with_extra, HelloIntsSchema)
def test_fill_defaults_simple_config():
    """Schema defaults are filled in; a missing required value is an error."""
    only_required = {"required": 1}
    filled, _, _final = my_registry._fill(only_required, DefaultsSchema)
    assert filled["required"] == 1
    assert filled["optional"] == "default value"
    only_optional = {"optional": "some value"}
    with pytest.raises(ConfigValidationError):
        my_registry._fill(only_optional, DefaultsSchema)
def test_fill_recursive_config():
    """Defaults are filled in on the outer level and in nested schemas."""
    nested = {"outer_req": 1, "level2_req": {"hello": 4, "world": 7}}
    filled, _, _validation = my_registry._fill(nested, ComplexSchema)
    assert filled["outer_req"] == 1
    assert filled["outer_opt"] == "default value"
    required_block = filled["level2_req"]
    assert required_block["hello"] == 4
    assert required_block["world"] == 7
    optional_block = filled["level2_opt"]
    assert optional_block["required"] == 1
    assert optional_block["optional"] == "default value"
def test_is_promise():
    """Only objects with at least one "@" key count as promises."""
    assert my_registry.is_promise(good_catsie)
    for not_a_promise in ({"hello": "world"}, 1):
        assert not my_registry.is_promise(not_a_promise)
    # Two registry references are invalid, but still detected as a promise
    two_references = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"}
    assert my_registry.is_promise(two_references)
def test_get_constructor():
    """The "@" key and its value give the registry and the function name."""
    reg_name, func_name = my_registry.get_constructor(good_catsie)
    assert (reg_name, func_name) == ("cats", "catsie.v1")
def test_parse_args():
    """The non-"@" keys of a promise end up as keyword arguments."""
    positional, keyword = my_registry.parse_args(bad_catsie)
    assert positional == []
    assert keyword == {"evil": True, "cute": True}
def test_make_promise_schema():
    """The promise schema has a field for both catsie arguments."""
    promise_schema = my_registry.make_promise_schema(good_catsie)
    for argument in ("evil", "cute"):
        assert argument in promise_schema.__fields__
def test_validate_promise():
    """The filled config keeps the promise, the final one has its result."""
    payload = {"required": 1, "optional": good_catsie}
    filled, _, final = my_registry._fill(payload, DefaultsSchema)
    assert filled == payload
    assert final == {"required": 1, "optional": "meow"}
def test_fill_validate_promise():
    """An argument missing from the promise is filled in as True."""
    promise = {"@cats": "catsie.v1", "evil": False}
    payload = {"required": 1, "optional": promise}
    filled, _, _final = my_registry._fill(payload, DefaultsSchema)
    assert filled["optional"]["cute"] is True
def test_fill_invalidate_promise():
    """Filling fails for a non-matching schema and for an unknown argument."""
    payload = {"required": 1, "optional": {"@cats": "catsie.v1", "evil": False}}
    with pytest.raises(ConfigValidationError):
        my_registry._fill(payload, HelloIntsSchema)
    # Add an argument to the promise and fill against DefaultsSchema
    payload["optional"]["whiskers"] = True
    with pytest.raises(ConfigValidationError):
        my_registry._fill(payload, DefaultsSchema)
def test_create_registry():
    """A new registry can be attached to my_registry and registered to."""
    dogs = catalogue.create(my_registry.namespace, "dogs", entry_points=False)
    my_registry.dogs = dogs
    assert hasattr(my_registry, "dogs")
    assert len(my_registry.dogs.get_all()) == 0
    my_registry.dogs.register("good_boy.v1", func=lambda x: x)
    assert len(my_registry.dogs.get_all()) == 1
def test_registry_methods():
    """get raises ValueError for an unknown registry and for a None entry."""
    with pytest.raises(ValueError):
        my_registry.get("dfkoofkds", "catsie.v1")
    # Register None under a new name, which get treats as "not found"
    register_as = my_registry.cats.register("catsie.v123")
    register_as(None)
    with pytest.raises(ValueError):
        my_registry.get("cats", "catsie.v123")
def test_resolve_no_schema():
    """Promises are resolved and their arguments validated without a schema."""
    payload = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
    resolved = my_registry.resolve({"cfg": payload})["cfg"]
    assert resolved["one"] == 1
    assert resolved["two"] == {"three": "scratch!"}
    # Here "evil" is the string "true"
    string_flag = {"two": {"three": {"@cats": "catsie.v1", "evil": "true"}}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(string_flag)
def test_resolve_schema():
    """Resolving validates the config against a nested schema."""

    class TestBaseSubSchema(BaseModel):
        three: str

    class TestBaseSchema(BaseModel):
        one: PositiveInt
        two: TestBaseSubSchema

        class Config:
            extra = "forbid"

    class TestSchema(BaseModel):
        cfg: TestBaseSchema

    valid = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
    my_registry.resolve({"cfg": valid}, schema=TestSchema)

    # "one" is not a positive int
    negative = {"one": -1, "two": {"three": {"@cats": "catsie.v1", "evil": True}}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"cfg": negative}, schema=TestSchema)

    # "three" is required in subschema
    wrong_key = {"one": 1, "two": {"four": {"@cats": "catsie.v1", "evil": True}}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"cfg": wrong_key}, schema=TestSchema)
def test_resolve_schema_coerced():
    """Coerced values show up in the resolved config only."""

    class TestBaseSchema(BaseModel):
        test1: str
        test2: bool
        test3: float

    class TestSchema(BaseModel):
        cfg: TestBaseSchema

    raw = {"test1": 123, "test2": 1, "test3": 5}
    filled = my_registry.fill({"cfg": raw}, schema=TestSchema)
    resolved = my_registry.resolve({"cfg": raw}, schema=TestSchema)
    expected = {"test1": "123", "test2": True, "test3": 5.0}
    assert resolved["cfg"] == expected
    # This only affects the resolved config, not the filled config
    assert filled["cfg"] == raw
def test_read_config():
    """Load ``EXAMPLE_CONFIG`` from UTF-8 bytes and spot-check nested values."""
    parsed = Config().from_bytes(EXAMPLE_CONFIG.encode("utf8"))

    optimizer = parsed["optimizer"]
    assert optimizer["beta1"] == 0.9
    assert optimizer["learn_rate"]["initial_rate"] == 0.1

    classifier = parsed["pipeline"]["classifier"]
    assert classifier["factory"] == "classifier"
    assert classifier["model"]["embedding"]["width"] == 128
def test_optimizer_config():
    """Resolve ``OPTIMIZER_CFG`` with validation and check the optimizer's ``beta1``."""
    parsed = Config().from_str(OPTIMIZER_CFG)
    resolved = my_registry.resolve(parsed, validate=True)
    assert resolved["optimizer"].beta1 == 0.9
def test_config_to_str():
    """``to_str`` reproduces the parsed string, with or without pre-existing data."""
    expected = OPTIMIZER_CFG.strip()

    from_empty = Config().from_str(OPTIMIZER_CFG)
    assert from_empty.to_str().strip() == expected

    # Data passed to the constructor does not appear in the output after ``from_str``.
    from_prefilled = Config({"optimizer": {"foo": "bar"}}).from_str(OPTIMIZER_CFG)
    assert from_prefilled.to_str().strip() == expected
def test_config_to_str_creates_intermediate_blocks():
    """Serialising a nested dict emits a header for every intermediate section.

    ``optimizer`` has no direct values of its own, but ``[optimizer]`` is
    still expected in the output ahead of ``[optimizer.foo]``.
    """
    cfg = Config({"optimizer": {"foo": {"bar": 1}}})
    # Sections are separated by a blank line, as in the other to_str
    # expectations in this module (e.g. "[f]\n\n[f.g]"). The expected text is
    # written with explicit "\n" escapes so that the separator cannot be lost
    # to whitespace normalisation of the source file.
    expected = "[optimizer]\n\n[optimizer.foo]\nbar = 1"
    assert cfg.to_str().strip() == expected
def test_config_roundtrip_bytes():
    """A config survives a ``to_bytes`` / ``from_bytes`` round trip unchanged."""
    original = Config().from_str(OPTIMIZER_CFG)
    restored = Config().from_bytes(original.to_bytes())
    assert restored.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_roundtrip_disk():
    """A config survives a ``to_disk`` / ``from_disk`` round trip unchanged."""
    original = Config().from_str(OPTIMIZER_CFG)
    with make_tempdir() as tmp_dir:
        target = tmp_dir / "config.cfg"
        original.to_disk(target)
        restored = Config().from_disk(target)
    assert restored.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture):
    """Round-trip through disk using the path object supplied by ``pathy_fixture``."""
    original = Config().from_str(OPTIMIZER_CFG)
    target = pathy_fixture / "config.cfg"
    original.to_disk(target)
    restored = Config().from_disk(target)
    assert restored.to_str().strip() == OPTIMIZER_CFG.strip()
def test_config_to_str_invalid_defaults():
    """Top-level values outside a section must be rejected.

    Such values would otherwise be interpreted as [DEFAULT], which causes
    them to be included in *all* other sections. Both serialising a dict with
    a bare top-level key and parsing an explicit ``[DEFAULT]`` block should
    raise.
    """
    with_top_level_value = {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}
    with pytest.raises(ConfigValidationError):
        Config(with_top_level_value).to_str()

    explicit_default = "[DEFAULT]\none = 1"
    with pytest.raises(ConfigValidationError):
        Config().from_str(explicit_default)
def test_validation_custom_types():
    """Validate pydantic-specific argument types on a freshly registered function."""

    def complex_args(
        rate: StrictFloat,
        steps: PositiveInt = 10,  # type: ignore
        log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR",
    ):
        return None

    # Attach a new "complex" registry to the shared registry object and
    # register the function above in it.
    my_registry.complex = catalogue.create(
        my_registry.namespace, "complex", entry_points=False
    )
    my_registry.complex("complex.v1")(complex_args)

    # All arguments valid: resolves without raising.
    valid = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"}
    my_registry.resolve({"config": valid})

    # steps is not a positive int
    bad_steps = {"@complex": "complex.v1", "rate": 1.0, "steps": -1, "log_level": "INFO"}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"config": bad_steps})

    # log_level is not a string matching the regex
    bad_level = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "none"}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"config": bad_level})

    # top-level object is promise: rejected by both resolve and fill
    top_level = {"@complex": "complex.v1", "rate": 1.0, "steps": 20, "log_level": "INFO"}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(top_level)
    with pytest.raises(ConfigValidationError):
        my_registry.fill(top_level)

    # two constructors in the same block
    two_constructors = {"@complex": "complex.v1", "rate": 1.0, "@cats": "catsie.v1"}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"config": two_constructors})
def test_validation_no_validate():
    """With ``validate=False`` a wrongly typed argument passes through untouched."""
    raw = {"one": 1, "two": {"three": {"@cats": "catsie.v1", "evil": "false"}}}
    resolved = my_registry.resolve({"cfg": raw}, validate=False)
    filled = my_registry.fill({"cfg": raw}, validate=False)

    assert resolved["cfg"]["one"] == 1
    assert resolved["cfg"]["two"] == {"three": "scratch!"}

    filled_promise = filled["cfg"]["two"]["three"]
    # The string value is kept as given and the "cute" default is filled in.
    assert filled_promise["evil"] == "false"
    assert filled_promise["cute"] is True
def test_validation_fill_defaults():
    """``fill`` adds default arguments of the registered function to the config."""
    unvalidated = {"cfg": {"one": 1, "two": {"@cats": "catsie.v1", "evil": "hello"}}}
    filled = my_registry.fill(unvalidated, validate=False)
    assert len(filled["cfg"]["two"]) == 3

    # The same config ("evil" set to the string "hello") is expected to be
    # rejected once validation is enabled.
    with pytest.raises(ConfigValidationError):
        my_registry.fill(unvalidated)

    # Fill in with the defaults of a different registered function.
    v2_config = {"cfg": {"one": 1, "two": {"@cats": "catsie.v2", "evil": False}}}
    filled = my_registry.fill(v2_config)
    block = filled["cfg"]["two"]
    assert len(block) == 4
    assert block["evil"] is False
    assert block["cute"] is True
    assert block["cute_level"] == 1
def test_make_config_positional_args():
    """Values under the ``"*"`` key are passed to the function as positional args."""

    @my_registry.cats("catsie.v567")
    def catsie_567(*args: Optional[str], foo: str = "bar"):
        assert args[0] == "^_^"
        assert args[1] == "^(*.*)^"
        assert foo == "baz"
        return args[0]

    positional = ["^_^", "^(*.*)^"]
    block = {"@cats": "catsie.v567", "foo": "baz", "*": positional}
    resolved = my_registry.resolve({"config": block})
    assert resolved["config"] == "^_^"
def test_make_config_positional_args_complex():
    """Positional args are validated against the ``*args`` annotation."""

    @my_registry.cats("catsie.v890")
    def catsie_890(*args: Optional[Union[StrictBool, PositiveInt]]):
        assert args[0] == 123
        return args[0]

    accepted = {"config": {"@cats": "catsie.v890", "*": [123, True, 1, False]}}
    assert my_registry.resolve(accepted)["config"] == 123

    # "True" is not a valid boolean or positive int
    rejected = {"config": {"@cats": "catsie.v890", "*": [123, "True"]}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(rejected)
def test_positional_args_to_from_string():
    """Round-trip, fill and resolve configs that use the ``*`` positional-args key."""
    # Plain round trips: "*" as an inline list and as sub-sections.
    inline_list = """[a]\nb = 1\n* = ["foo","bar"]"""
    assert Config().from_str(inline_list).to_str() == inline_list
    as_sections = """[a]\nb = 1\n\n[a.*.bar]\ntest = 2\n\n[a.*.foo]\ntest = 1"""
    assert Config().from_str(as_sections).to_str() == as_sections

    @my_registry.cats("catsie.v666")
    def catsie_666(*args, meow=False):
        return args

    # Inline positional args passed to a registered function.
    source = """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]"""
    filled_str = my_registry.fill(Config().from_str(source)).to_str()
    assert filled_str == """[a]\n@cats = "catsie.v666"\n* = ["foo","bar"]\nmeow = false"""
    resolved = my_registry.resolve(Config().from_str(source))
    assert resolved == {"a": ("foo", "bar")}

    # Positional args given as a sub-section holding a plain dict.
    source = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\nx = 1"""
    filled_str = my_registry.fill(Config().from_str(source)).to_str()
    assert filled_str == """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\nx = 1"""
    resolved = my_registry.resolve(Config().from_str(source))
    assert resolved == {"a": ({"x": 1},)}

    @my_registry.cats("catsie.v777")
    def catsie_777(y: int = 1):
        return "meow" * y

    # Positional args given as a sub-section that is itself a promise.
    source = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777\""""
    filled_str = my_registry.fill(Config().from_str(source)).to_str()
    expected = """[a]\n@cats = "catsie.v666"\nmeow = false\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 1"""
    assert filled_str == expected

    source = """[a]\n@cats = "catsie.v666"\n\n[a.*.foo]\n@cats = "catsie.v777"\ny = 3"""
    resolved = my_registry.resolve(Config().from_str(source))
    assert resolved == {"a": ("meowmeowmeow",)}
def test_validation_generators_iterable():
    """Register an optimizer and a generator schedule, then resolve the optimizer."""

    @my_registry.optimizers("test_optimizer.v1")
    def test_optimizer_v1(rate: float) -> None:
        return None

    # Registered only; the config below does not reference this schedule.
    @my_registry.schedules("test_schedule.v1")
    def test_schedule_v1(some_value: float = 1.0) -> Iterable[float]:
        while True:
            yield some_value

    optimizer_block = {"@optimizers": "test_optimizer.v1", "rate": 0.1}
    my_registry.resolve({"optimizer": optimizer_block})
def test_validation_unset_type_hints():
    """Arguments without a type hint are accepted (and treated as Any)."""

    @my_registry.optimizers("test_optimizer.v2")
    def test_optimizer_v2(rate, steps: int = 10) -> None:
        return None

    block = {"@optimizers": "test_optimizer.v2", "rate": 0.1, "steps": 20}
    my_registry.resolve({"test": block})
def test_validation_bad_function():
    """Distinguish errors raised *by* a function from errors in how it is called.

    An exception raised inside the registered function is expected to
    propagate unchanged (``ValueError``), while passing an argument the
    function does not accept is expected to surface as
    ``ConfigValidationError``.
    """

    @my_registry.optimizers("bad.v1")
    def bad() -> None:
        # Always raises. The original had an unreachable ``return None`` after
        # this statement, which has been removed.
        raise ValueError("This is an error in the function")

    @my_registry.optimizers("good.v1")
    def good() -> None:
        return None

    # Bad function
    config = {"test": {"@optimizers": "bad.v1"}}
    with pytest.raises(ValueError):
        my_registry.resolve(config)

    # Bad function call
    config = {"test": {"@optimizers": "good.v1", "invalid_arg": 1}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(config)
def test_objects_from_config():
    """Resolve an optimizer whose ``learn_rate`` argument is itself a promise."""
    schedule_block = {
        "@schedules": "my_cool_repetitive_schedule.v1",
        "base_rate": 0.001,
        "repeat": 4,
    }
    optimizer_block = {
        "@optimizers": "my_cool_optimizer.v1",
        "beta1": 0.2,
        "learn_rate": schedule_block,
    }
    resolved = my_registry.resolve({"optimizer": optimizer_block})
    optimizer = resolved["optimizer"]
    assert optimizer.beta1 == 0.2
    assert optimizer.learn_rate == [0.001] * 4
def test_partials_from_config():
    """Registered functions that return a ``partial`` resolve to a usable callable
    (e.g. initializers)."""
    numpy = pytest.importorskip("numpy")

    def uniform_init(
        shape: Tuple[int, ...], *, lo: float = -0.1, hi: float = 0.1
    ) -> List[float]:
        return numpy.random.uniform(lo, hi, shape).tolist()

    @my_registry.initializers("uniform_init.v1")
    def configure_uniform_init(
        *, lo: float = -0.1, hi: float = 0.1
    ) -> Callable[[List[float]], List[float]]:
        return partial(uniform_init, lo=lo, hi=hi)

    name = "uniform_init.v1"
    resolved = my_registry.resolve({"test": {"@initializers": name, "lo": -0.2}})
    func = resolved["test"]
    assert hasattr(func, "__call__")

    params = inspect.signature(func).parameters
    # The partial will still have lo as an arg, just with default
    assert len(params) == 3
    # Make sure returned partial function has correct value set
    assert inspect.signature(func).parameters["lo"].default == -0.2
    # Actually call the function and verify
    assert numpy.asarray(func((2, 3))).shape == (2, 3)

    # Make sure validation still works: first a wrongly typed "lo", then an
    # argument the function does not take.
    wrong_type = {"test": {"@initializers": name, "lo": [0.5]}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(wrong_type)
    extra_arg = {"test": {"@initializers": name, "lo": -0.2, "other": 10}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(extra_arg)
def test_partials_from_config_nested():
    """A resolved partial can be handed to another registered function that
    consumes it (e.g. initializers -> layers)."""

    def test_initializer(a: int, b: int = 1) -> int:
        return a * b

    @my_registry.initializers("test_initializer.v1")
    def configure_test_initializer(b: int = 1) -> Callable[[int], int]:
        return partial(test_initializer, b=b)

    @my_registry.layers("test_layer.v1")
    def test_layer(init: Callable[[int], int], c: int = 1) -> Callable[[int], int]:
        return lambda x: x + init(c)

    layer_block = {
        "@layers": "test_layer.v1",
        "c": 5,
        "init": {"@initializers": "test_initializer.v1", "b": 10},
    }
    layer = my_registry.resolve({"test": layer_block})["test"]
    # init(c) == 5 * 10 == 50, so the layer adds 50 to its input.
    assert layer(1) == 51
    assert layer(100) == 150
def test_validate_generator():
    """Resolved generators stay real generator objects, whether they are
    returned directly or passed through other registered functions."""

    @my_registry.schedules("test_schedule.v2")
    def test_schedule():
        while True:
            yield 10

    # Generator returned directly by the resolved promise.
    direct = {"@schedules": "test_schedule.v2"}
    resolved = my_registry.resolve({"test": direct})["test"]
    assert isinstance(resolved, GeneratorType)

    @my_registry.optimizers("test_optimizer.v2")
    def test_optimizer2(rate: Generator) -> Generator:
        return rate

    # Generator passed as a plain argument.
    as_argument = {
        "@optimizers": "test_optimizer.v2",
        "rate": {"@schedules": "test_schedule.v2"},
    }
    resolved = my_registry.resolve({"test": as_argument})["test"]
    assert isinstance(resolved, GeneratorType)

    @my_registry.optimizers("test_optimizer.v3")
    def test_optimizer3(schedules: Dict[str, Generator]) -> Generator:
        return schedules["rate"]

    # Generator nested inside a dict argument.
    inside_dict = {
        "@optimizers": "test_optimizer.v3",
        "schedules": {"rate": {"@schedules": "test_schedule.v2"}},
    }
    resolved = my_registry.resolve({"test": inside_dict})["test"]
    assert isinstance(resolved, GeneratorType)

    # Registered but not resolved anywhere in this test.
    @my_registry.optimizers("test_optimizer.v4")
    def test_optimizer4(*schedules: Generator) -> Generator:
        return schedules[0]
def test_handle_generic_type():
    """Validation copes with arbitrary generic types in argument annotations."""
    inner_cat = {"@cats": "int_cat.v1", "value_in": 3}
    outer_cat = {"@cats": "generic_cat.v1", "cat": inner_cat}
    cat = my_registry.resolve({"test": outer_cat})["test"]
    assert isinstance(cat, Cat)
    assert cat.value_in == 3
    assert cat.value_out is None
    assert cat.name == "generic_cat"
@pytest.mark.parametrize(
    "cfg",
    [
        "[a]\nb = 1\nc = 2\n\n[a.c]\nd = 3",
        "[a]\nb = 1\n\n[a.c]\nd = 2\n\n[a.c.d]\ne = 3",
    ],
)
def test_handle_error_duplicate_keys(cfg):
    """A key that is both a value and a section name must raise a validation error.

    Without the dedicated check this would cause a very cryptic error when
    interpreting the config
    (TypeError: 'X' object does not support item assignment).
    """
    parser = Config()
    with pytest.raises(ConfigValidationError):
        parser.from_str(cfg)
@pytest.mark.parametrize(
    "cfg,is_valid",
    [("[a]\nb = 1\n\n[a.c]\nd = 3", True), ("[a]\nb = 1\n\n[A.c]\nd = 2", False)],
)
def test_cant_expand_undefined_block(cfg, is_valid):
    """Expanding a block whose parent hasn't been created must fail.

    This comes up when a name is typo'd; if expansion of undefined blocks
    were allowed, it would be very hard to create good errors for those typos.
    """
    if not is_valid:
        with pytest.raises(ConfigValidationError):
            Config().from_str(cfg)
        return
    # Valid input only has to parse without raising.
    Config().from_str(cfg)
def test_fill_config_overrides():
    """Exercise ``fill`` with dot-notation and nested-dict overrides."""
    config = {
        "cfg": {
            "one": 1,
            "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}},
        }
    }

    # Override a single argument of the promise.
    filled = my_registry.fill(
        config, overrides={"cfg.two.three.evil": False}, validate=True
    )
    assert filled["cfg"]["two"]["three"]["evil"] is False

    # Test that promises can be overwritten as well
    filled = my_registry.fill(config, overrides={"cfg.two.three": 3}, validate=True)
    assert filled["cfg"]["two"]["three"] == 3

    # Test that value can be overwritten with promises and that the result is
    # interpreted and filled correctly
    promise_override = {
        "cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}
    }
    filled = my_registry.fill(config, overrides=promise_override)
    assert filled["cfg"]["two"] is None
    assert filled["cfg"]["one"]["@cats"] == "catsie.v1"
    assert filled["cfg"]["one"]["evil"] is False
    assert filled["cfg"]["one"]["cute"] is True

    # Overwriting with wrong types should cause validation error
    wrong_type = {"cfg.two.three.evil": 20}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, overrides=wrong_type, validate=True)

    # Overwriting with incomplete promises should cause validation error
    incomplete = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, overrides=incomplete)

    # Overrides that don't match config should raise error
    unmatched = {"cfg.two.three.evil": False, "two.four": True}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, overrides=unmatched, validate=True)
    unknown_key = {"cfg.five": False}
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, overrides=unknown_key, validate=True)
def test_resolve_overrides():
    """Exercise ``resolve`` with dot-notation and nested-dict overrides."""
    config = {
        "cfg": {
            "one": 1,
            "two": {"three": {"@cats": "catsie.v1", "evil": True, "cute": False}},
        }
    }

    # Override a single argument of the promise and check the resolved value.
    resolved = my_registry.resolve(
        config, overrides={"cfg.two.three.evil": False}, validate=True
    )
    assert resolved["cfg"]["two"]["three"] == "meow"

    # Test that promises can be overwritten as well
    resolved = my_registry.resolve(
        config, overrides={"cfg.two.three": 3}, validate=True
    )
    assert resolved["cfg"]["two"]["three"] == 3

    # Test that value can be overwritten with promises
    promise_override = {
        "cfg": {"one": {"@cats": "catsie.v1", "evil": False}, "two": None}
    }
    resolved = my_registry.resolve(config, overrides=promise_override)
    assert resolved["cfg"]["one"] == "meow"
    assert resolved["cfg"]["two"] is None

    # Overwriting with wrong types should cause validation error
    wrong_type = {"cfg.two.three.evil": 20}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(config, overrides=wrong_type, validate=True)

    # Overwriting with incomplete promises should cause validation error
    incomplete = {"cfg": {"one": {"@cats": "catsie.v1"}, "two": None}}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(config, overrides=incomplete)

    # Overrides that don't match config should raise error
    unmatched = {"cfg.two.three.evil": False, "cfg.two.four": True}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(config, overrides=unmatched, validate=True)
    unknown_key = {"cfg.five": False}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(config, overrides=unknown_key, validate=True)
@pytest.mark.parametrize(
    "prop,expected",
    [("a.b.c", True), ("a.b", True), ("a", True), ("a.e", True), ("a.b.c.d", False)],
)
def test_is_in_config(prop, expected):
    """``_is_in_config`` reports whether a dotted path exists in a nested dict."""
    nested = {"a": {"b": {"c": 5, "d": 6}, "e": [1, 2]}}
    found = my_registry._is_in_config(prop, nested)
    assert found is expected
def test_resolve_prefilled_values():
    """An already-constructed object in the config is passed through to the function."""

    class Language(object):
        def __init__(self):
            ...

    @my_registry.optimizers("prefilled.v1")
    def prefilled(nlp: Language, value: int = 10):
        return (nlp, value)

    # Passing an instance of Language here via the config is bad, since it
    # won't serialize to a string, but we still test for it
    block = {"@optimizers": "prefilled.v1", "nlp": Language(), "value": 50}
    resolved = my_registry.resolve({"test": block}, validate=True)
    nlp, value = resolved["test"][0], resolved["test"][1]
    assert isinstance(nlp, Language)
    assert value == 50
def test_fill_config_dict_return_type():
    """A registered function that returns a dict is filled and resolved correctly."""

    @my_registry.cats.register("catsie_with_dict.v1")
    def catsie_with_dict(evil: StrictBool) -> Dict[str, bool]:
        return {"not_evil": not evil}

    inner = {"test": {"@cats": "catsie_with_dict.v1", "evil": False}, "foo": 10}

    # Filling keeps the promise's arguments and does not include the return value.
    filled = my_registry.fill({"cfg": inner}, validate=True)
    filled_block = filled["cfg"]["test"]
    assert filled_block["evil"] is False
    assert "not_evil" not in filled_block

    # Resolving replaces the block with the dict the function returned.
    resolved = my_registry.resolve({"cfg": inner}, validate=True)
    assert resolved["cfg"]["test"]["not_evil"] is True
def test_deepcopy_config():
    """``Config.copy`` returns an equal but distinct object.

    The original began with ``pytest.importorskip("numpy")`` although nothing
    in this test uses numpy, so the test was skipped on environments without
    it. That call has been removed so the test always runs.
    """
    config = Config({"a": 1, "b": {"c": 2, "d": 3}})
    copied = config.copy()
    # Same values but not same object
    assert config == copied
    assert config is not copied
@pytest.mark.skipif(
    platform.python_implementation() == "PyPy", reason="copy does not fail for pypy"
)
def test_deepcopy_config_pickle():
    """Copying a config that holds an uncopyable value raises ``ValueError``."""
    numpy = pytest.importorskip("numpy")
    # Check for error if value can't be pickled/deepcopied: the numpy module
    # object itself is stored as the value.
    uncopyable = Config({"a": 1, "b": numpy})
    with pytest.raises(ValueError):
        uncopyable.copy()
def test_config_to_str_simple_promises():
    """References to function registries without arguments are serialized
    inline as a dict and round-trip unchanged."""
    source = """[section]\nsubsection = {"@registry":"value"}"""
    parsed = Config().from_str(source)
    assert parsed["section"]["subsection"]["@registry"] == "value"
    assert parsed.to_str() == source
def test_config_from_str_invalid_section():
    """A section may not be nested under a key that already holds ``null``."""
    # Direct child section of the null-valued key.
    direct_child = """[a]\nb = null\n\n[a.b]\nc = 1"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(direct_child)

    # Deeper descendant of the null-valued key.
    grandchild = """[a]\nb = null\n\n[a.b.c]\nd = 1"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(grandchild)
def test_config_to_str_order():
    """``Config.to_str`` writes sections in order, parents before children."""
    data = {"a": {"b": {"c": 1, "d": 2}, "e": 3}, "f": {"g": {"h": {"i": 4, "j": 5}}}}
    expected = (
        "[a]\ne = 3\n\n[a.b]\nc = 1\nd = 2\n\n[f]\n\n[f.g]\n\n[f.g.h]\ni = 4\nj = 5"
    )
    serialized = Config(data).to_str()
    assert serialized == expected
@pytest.mark.parametrize("d", [".", ":"])
def test_config_interpolation(d):
    """Config values are interpolated correctly.

    ``d`` is the final divider of a reference (${a.b} vs. ${a:b}); both forms
    should work. The doubled {{ }} in the f-strings keep the references from
    being interpreted as f-string variables.
    """
    # A reference without a section is invalid.
    c_str = """[a]\nfoo = "hello"\n\n[b]\nbar = ${foo}"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)

    # Whole-value reference.
    c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}"""
    parsed = Config().from_str(c_str)
    assert parsed["b"]["bar"] == "hello"

    # Reference followed by literal text, unquoted and quoted.
    c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = ${{a{d}foo}}!"""
    parsed = Config().from_str(c_str)
    assert parsed["b"]["bar"] == "hello!"
    c_str = f"""[a]\nfoo = "hello"\n\n[b]\nbar = "${{a{d}foo}}!\""""
    parsed = Config().from_str(c_str)
    assert parsed["b"]["bar"] == "hello!"

    # A number referenced alongside text becomes part of the string.
    c_str = f"""[a]\nfoo = 15\n\n[b]\nbar = ${{a{d}foo}}!"""
    parsed = Config().from_str(c_str)
    assert parsed["b"]["bar"] == "15!"

    # A list referenced on its own stays a list.
    c_str = f"""[a]\nfoo = ["x", "y"]\n\n[b]\nbar = ${{a{d}foo}}"""
    parsed = Config().from_str(c_str)
    assert parsed["b"]["bar"] == ["x", "y"]

    # Interpolation within the same section
    c_str = f"""[a]\nfoo = "x"\nbar = ${{a{d}foo}}\nbaz = "${{a{d}foo}}y\""""
    assert Config().from_str(c_str)["a"]["bar"] == "x"
    assert Config().from_str(c_str)["a"]["baz"] == "xy"
def test_config_interpolation_lists():
    """Interpolation inside list values (and dicts nested in lists)."""
    # Test that lists are preserved correctly
    c_str = """[a]\nb = 1\n\n[c]\nd = ["hello ${a.b}", "world"]"""
    raw = Config().from_str(c_str, interpolate=False)
    assert raw["c"]["d"] == ["hello ${a.b}", "world"]
    interpolated = raw.interpolate()
    assert interpolated["c"]["d"] == ["hello 1", "world"]

    c_str = """[a]\nb = 1\n\n[c]\nd = [${a.b}, "hello ${a.b}", "world"]"""
    parsed = Config().from_str(c_str)
    assert parsed["c"]["d"] == [1, "hello 1", "world"]
    # Parsing without interpolation must at least not raise; the result is
    # not inspected.
    Config().from_str(c_str, interpolate=False)
    # NOTE: This currently doesn't work, because we can't know how to JSON-load
    # the uninterpolated list [${a.b}].
    # assert config["c"]["d"] == ["${a.b}", "hello ${a.b}", "world"]
    # config = config.interpolate()
    # assert config["c"]["d"] == [1, "hello 1", "world"]

    # A whole section referenced as a list item.
    c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", ${a}]"""
    parsed = Config().from_str(c_str)
    assert parsed["c"]["d"] == ["hello", {"b": 1}]

    # A section referenced inside a string within a list is rejected.
    c_str = """[a]\nb = 1\n\n[c]\nd = ["hello", "hello ${a}"]"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)

    # A reference inside a list nested in a dict nested in a list.
    config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": ["hello ${a.b}"], "y": 2}]"""
    parsed = Config().from_str(config_str)
    assert parsed["c"]["d"] == ["hello", {"x": ["hello 1"], "y": 2}]

    # NOTE(review): ``config_str`` is reassigned here but the call below parses
    # ``c_str`` (the "hello ${a}" case from above) again, so the new string is
    # never exercised. Kept as-is to preserve behaviour -- confirm whether
    # ``config_str`` was intended and whether it is really expected to raise.
    config_str = """[a]\nb = 1\n\n[c]\nd = ["hello", {"x": [${a.b}], "y": 2}]"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)
@pytest.mark.parametrize("d", [".", ":"])
def test_config_interpolation_sections(d):
    """Config sections are interpolated correctly.

    ``d`` is the final divider of a reference (${a.b} vs. ${a:b}); both forms
    should work. The doubled {{ }} in the f-strings keep the references from
    being interpreted as f-string variables.
    """
    # Simple block references
    c_str = """[a]\nfoo = "hello"\nbar = "world"\n\n[b]\nc = ${a}"""
    parsed = Config().from_str(c_str)
    assert parsed["b"]["c"] == parsed["a"]

    # References with non-string values
    c_str = f"""[a]\nfoo = "hello"\n\n[a.x]\ny = ${{a{d}b}}\n\n[a.b]\nc = 1\nd = [10]"""
    parsed = Config().from_str(c_str)
    assert parsed["a"]["x"]["y"] == parsed["a"]["b"]

    # Multiple references in the same string
    c_str = f"""[a]\nx = "string"\ny = 10\n\n[b]\nz = "${{a{d}x}}/${{a{d}y}}\""""
    parsed = Config().from_str(c_str)
    assert parsed["b"]["z"] == "string/10"

    # Non-string references in string (converted to string)
    c_str = f"""[a]\nx = ["hello", "world"]\n\n[b]\ny = "result: ${{a{d}x}}\""""
    parsed = Config().from_str(c_str)
    assert parsed["b"]["y"] == 'result: ["hello", "world"]'

    # References to sections referencing sections
    c_str = """[a]\nfoo = "x"\n\n[b]\nbar = ${a}\n\n[c]\nbaz = ${b}"""
    parsed = Config().from_str(c_str)
    assert parsed["b"]["bar"] == parsed["a"]
    assert parsed["c"]["baz"] == parsed["b"]

    # References to section values referencing other sections
    c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b{d}bar}}"""
    parsed = Config().from_str(c_str)
    assert parsed["c"]["baz"] == parsed["b"]["bar"]

    # References to sections with subsections
    c_str = """[a]\nfoo = "x"\n\n[a.b]\nbar = 100\n\n[c]\nbaz = ${a}"""
    parsed = Config().from_str(c_str)
    assert parsed["c"]["baz"] == parsed["a"]

    # Infinite recursion
    c_str = """[a]\nfoo ="x"\n\n[a.b]\nbar = ${a}"""
    parsed = Config().from_str(c_str)
    assert parsed["a"]["b"]["bar"] == parsed["a"]

    # We can't reference not-yet interpolated subsections
    c_str = f"""[a]\nfoo = "x"\n\n[b]\nbar = ${{a}}\n\n[c]\nbaz = ${{b.bar{d}foo}}"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)

    # Generally invalid references
    c_str = f"""[a]\nfoo = ${{b{d}bar}}"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)

    # We can't reference sections or promises within strings
    c_str = """[a]\n\n[a.b]\nfoo = "x: ${c}"\n\n[c]\nbar = 1\nbaz = 2"""
    with pytest.raises(ConfigValidationError):
        Config().from_str(c_str)
def test_config_from_str_overrides():
    """``from_str`` applies dot-notation overrides before interpolation."""
    source = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = 3\n\n[f]\ng = {"x": "y"}"""

    # Basic value substitution
    parsed = Config().from_str(source, overrides={"a.b": 10, "a.c.d": 20})
    assert parsed["a"]["b"] == 10
    assert parsed["a"]["c"]["d"] == 20
    assert parsed["a"]["c"]["e"] == 3

    # Valid values that previously weren't in config
    parsed = Config().from_str(source, overrides={"a.c.f": 100})
    assert parsed["a"]["c"]["d"] == 2
    assert parsed["a"]["c"]["e"] == 3
    assert parsed["a"]["c"]["f"] == 100

    # Invalid keys and sections
    with pytest.raises(ConfigValidationError):
        Config().from_str(source, overrides={"f": 10})
    # This currently isn't expected to work, because the dict in f.g is not
    # interpreted as a section while the config is still just the configparser
    with pytest.raises(ConfigValidationError):
        Config().from_str(source, overrides={"f.g.x": "z"})

    # With variables (values)
    source = """[a]\nb = 1\n\n[a.c]\nd = 2\ne = ${a:b}"""
    parsed = Config().from_str(source, overrides={"a.b": 10})
    assert parsed["a"]["b"] == 10
    assert parsed["a"]["c"]["e"] == 10

    # With variables (sections)
    source = """[a]\nb = 1\n\n[a.c]\nd = 2\n[e]\nf = ${a.c}"""
    parsed = Config().from_str(source, overrides={"a.c.d": 20})
    assert parsed["a"]["c"]["d"] == 20
    assert parsed["e"]["f"] == {"d": 20}
def test_config_reserved_aliases():
    """Auto-generated pydantic schemas alias reserved attribute names such as
    "validate" that would otherwise cause NameError."""

    @my_registry.cats("catsie.with_alias")
    def catsie_with_alias(validate: StrictBool = False):
        return validate

    good = {"@cats": "catsie.with_alias", "validate": True}
    resolved = my_registry.resolve({"test": good})
    filled = my_registry.fill({"test": good})
    assert resolved["test"] is True
    assert filled["test"] == good

    # 20 is not a strict bool, so validation is expected to fail.
    bad = {"@cats": "catsie.with_alias", "validate": 20}
    with pytest.raises(ConfigValidationError):
        my_registry.resolve({"test": bad})
@pytest.mark.parametrize("d", [".", ":"])
def test_config_no_interpolation(d):
    """Uninterpolated references are preserved until interpolation is requested.

    ``d`` is the final divider of a reference (${a.b} vs. ${a:b}); both forms
    should work. The doubled {{ }} in the f-strings keep the references from
    being interpreted as f-string variables.
    """
    numpy = pytest.importorskip("numpy")
    c_str = f"""[a]\nb = 1\n\n[c]\nd = ${{a{d}b}}\ne = \"hello${{a{d}b}}"\nf = ${{a}}"""

    # Parsed without interpolation: references are kept as text.
    raw = Config().from_str(c_str, interpolate=False)
    assert not raw.is_interpolated
    assert raw["c"]["d"] == f"${{a{d}b}}"
    assert raw["c"]["e"] == f'"hello${{a{d}b}}"'
    assert raw["c"]["f"] == "${a}"

    # Serialise and re-parse with interpolation turned on.
    reparsed = Config().from_str(raw.to_str(), interpolate=True)
    assert reparsed.is_interpolated
    assert reparsed["c"]["d"] == 1
    assert reparsed["c"]["e"] == "hello1"
    assert reparsed["c"]["f"] == {"b": 1}

    # Interpolate the raw config directly.
    interpolated = raw.interpolate()
    assert interpolated.is_interpolated
    assert interpolated["c"]["d"] == 1
    assert interpolated["c"]["e"] == "hello1"
    assert interpolated["c"]["f"] == {"b": 1}

    # Bad non-serializable value
    cfg = {"x": {"y": numpy.asarray([[1, 2], [4, 5]], dtype="f"), "z": f"${{x{d}y}}"}}
    with pytest.raises(ConfigValidationError):
        Config(cfg).interpolate()
def test_config_no_interpolation_registry():
    """``fill`` and ``resolve`` work on both interpolated and raw configs."""
    config_str = """[a]\nbad = true\n[b]\n@cats = "catsie.v1"\nevil = ${a:bad}\n\n[c]\n d = ${b}"""

    # Raw (uninterpolated) config: references are still plain strings.
    raw = Config().from_str(config_str, interpolate=False)
    assert not raw.is_interpolated
    assert raw["b"]["evil"] == "${a:bad}"
    assert raw["c"]["d"] == "${b}"
    filled = my_registry.fill(raw)
    resolved = my_registry.resolve(raw)
    assert resolved["b"] == "scratch!"
    assert resolved["c"]["d"] == "scratch!"
    # Filling keeps the references and adds the "cute" default.
    assert filled["b"]["evil"] == "${a:bad}"
    assert filled["b"]["cute"] is True
    assert filled["c"]["d"] == "${b}"
    interpolated = filled.interpolate()
    assert interpolated.is_interpolated
    assert interpolated["b"]["evil"] is True
    assert interpolated["c"]["d"] == interpolated["b"]

    # Config interpolated at parse time.
    parsed = Config().from_str(config_str, interpolate=True)
    assert parsed.is_interpolated
    filled = my_registry.fill(parsed)
    resolved = my_registry.resolve(parsed)
    assert resolved["b"] == "scratch!"
    assert resolved["c"]["d"] == "scratch!"
    assert filled["b"]["evil"] is True
    assert filled["c"]["d"] == filled["b"]

    # Resolving a non-interpolated filled config
    raw = Config().from_str(config_str, interpolate=False)
    assert not raw.is_interpolated
    filled = my_registry.fill(raw)
    assert not filled.is_interpolated
    assert filled["c"]["d"] == "${b}"
    resolved = my_registry.resolve(filled)
    assert resolved["c"]["d"] == "scratch!"
def test_config_deep_merge():
    """Check ``Config(defaults).merge(config)`` across a table of cases.

    Each case is ``(config, defaults, expected)``, where ``expected`` maps
    every top-level key of the merged result to the value it must hold.
    """
    cases = [
        # Plain nested dicts: keys from both sides, config wins on conflicts.
        (
            {"a": "hello", "b": {"c": "d"}},
            {"a": "world", "b": {"c": "e", "f": "g"}},
            {"a": "hello", "b": {"c": "d", "f": "g"}},
        ),
        # Same "@test" value on both sides: arguments are merged.
        (
            {"a": "hello", "b": {"@test": "x", "foo": 1}},
            {"a": "world", "b": {"@test": "x", "foo": 100, "bar": 2}, "c": 100},
            {"a": "hello", "b": {"@test": "x", "foo": 1, "bar": 2}, "c": 100},
        ),
        # Different "@test" value: the default block's arguments are dropped.
        (
            {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100},
            {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}},
            {"a": "hello", "b": {"@test": "x", "foo": 1}, "c": 100},
        ),
        # Test that leaving out the factory just adds to existing
        (
            {"a": "hello", "b": {"foo": 1}, "c": 100},
            {"a": "world", "b": {"@test": "y", "foo": 100, "bar": 2}},
            {"a": "hello", "b": {"@test": "y", "foo": 1, "bar": 2}, "c": 100},
        ),
        # Test that switching to a different factory prevents the default from
        # being added
        (
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
            {"a": "world", "b": {"@bar": "y"}},
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
        ),
        # Default is a plain value rather than a block: config block replaces it.
        (
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
            {"a": "world", "b": "y"},
            {"a": "hello", "b": {"@foo": 1}, "c": 100},
        ),
    ]
    for config, defaults, expected in cases:
        merged = Config(defaults).merge(config)
        assert len(merged) == len(expected)
        for key, value in expected.items():
            assert merged[key] == value
def test_config_deep_merge_variables():
    """Merging keeps uninterpolated variables and lets new values replace them."""
    config_str = """[a]\nb= 1\nc = 2\n\n[d]\ne = ${a:b}"""
    defaults_str = """[a]\nx = 100\n\n[d]\ny = 500"""
    overriding = Config().from_str(config_str, interpolate=False)
    base = Config().from_str(defaults_str)
    merged = base.merge(overriding)
    assert merged["a"] == {"b": 1, "c": 2, "x": 100}
    # The reference is still raw after merging and resolves on interpolate().
    assert merged["d"] == {"e": "${a:b}", "y": 500}
    assert merged.interpolate()["d"] == {"e": 1, "y": 500}

    # With variable in defaults: overwritten by new value
    overriding = Config().from_str("""[a]\nb= 1\nc = 2""")
    base = Config().from_str("""[a]\nb = 100\nc = ${a:b}""", interpolate=False)
    merged = base.merge(overriding)
    assert merged["a"]["c"] == 2
def test_config_to_str_roundtrip():
    """``to_str`` / ``from_str`` preserve value types and raw variables."""
    numpy = pytest.importorskip("numpy")

    # A boolean is written unquoted and read back as a boolean.
    data = {"cfg": {"foo": False}}
    serialized = Config(data).to_str()
    assert serialized == "[cfg]\nfoo = false"
    assert dict(Config().from_str(serialized)) == data

    # The string "false" is written quoted and read back as a string.
    data = {"cfg": {"foo": "false"}}
    serialized = Config(data).to_str()
    assert serialized == '[cfg]\nfoo = "false"'
    assert dict(Config().from_str(serialized)) == data

    # Bad non-serializable value
    data = {"cfg": {"x": numpy.asarray([[1, 2, 3, 4], [4, 5, 3, 4]], dtype="f")}}
    unserializable = Config(data)
    with pytest.raises(ConfigValidationError):
        unserializable.to_str()

    # Roundtrip with variables: preserve variables correctly (quoted/unquoted)
    config_str = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = "${a:b}\""""
    raw = Config().from_str(config_str, interpolate=False)
    assert raw.to_str() == config_str
def test_config_is_interpolated():
    """Verify the ``is_interpolated`` flag across load, merge, copy and interpolate."""
    raw = """[a]\nb = 1\n\n[c]\nd = ${a:b}\ne = \"hello${a:b}"\nf = ${a}"""
    cfg = Config().from_str(raw, interpolate=False)
    assert not cfg.is_interpolated
    # Still reported as uninterpolated after merging in another config ...
    cfg = cfg.merge(Config({"x": {"y": "z"}}))
    assert not cfg.is_interpolated
    # ... and after being wrapped in a new Config.
    cfg = Config(cfg)
    assert not cfg.is_interpolated
    cfg = cfg.interpolate()
    assert cfg.is_interpolated
    # Merging an uninterpolated config back in flips the flag again.
    cfg = cfg.merge(Config().from_str(raw, interpolate=False))
    assert not cfg.is_interpolated
@pytest.mark.parametrize(
    "section_order,expected_str,expected_keys",
    [
        # fmt: off
        ([], "[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[h]\ni = 5\n\n[j]\nk = 6", ["a", "h", "j"]),
        (["j", "h", "a"], "[j]\nk = 6\n\n[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4", ["j", "h", "a"]),
        (["h"], "[h]\ni = 5\n\n[a]\nb = 1\nc = 2\n\n[a.d]\ne = 3\n\n[a.f]\ng = 4\n\n[j]\nk = 6", ["h", "a", "j"])
        # fmt: on
    ],
)
def test_config_serialize_custom_sort(section_order, expected_str, expected_keys):
    """Check that ``section_order`` drives both serialization and key order."""
    data = {
        "j": {"k": 6},
        "a": {"b": 1, "d": {"e": 3}, "c": 2, "f": {"g": 4}},
        "h": {"i": 5},
    }
    # Serialized without a custom order; parsed again below with one.
    default_str = Config(data).to_str()
    ordered = Config(data, section_order=section_order)
    assert ordered.to_str() == expected_str
    parsed = Config(section_order=section_order).from_str(default_str)
    assert list(parsed.keys()) == expected_keys
    constructed = Config(data, section_order=section_order)
    assert list(constructed.keys()) == expected_keys
def test_config_custom_sort_preserve():
    """Test that sort order is preserved when merging and copying configs,
    or when configs are filled and resolved."""
    order = ["y", "z", "x"]
    serialized = "[y]\n\n[z]\n\n[x]"
    original = Config({"x": {}, "y": {}, "z": {}}, section_order=order)
    assert original.to_str() == serialized
    # Copying keeps the custom order.
    assert original.copy().to_str() == serialized
    # Sections added by a merge are appended after the ordered ones.
    merged = original.merge({"a": {}})
    assert merged.to_str() == f"{serialized}\n\n[a]"
    # Wrapping in a new Config keeps the order as well.
    assert Config(original).to_str() == serialized
    config_str = """[a]\nb = 1\n[c]\n@cats = "catsie.v1"\nevil = true\n\n[t]\n x = 2"""
    order = ["c", "a", "t"]
    parsed = Config(section_order=order).from_str(config_str)
    assert list(parsed.keys()) == order
    filled = my_registry.fill(parsed)
    assert filled.section_order == order
def test_config_pickle():
    """Pickling and unpickling a Config keeps its data and section order."""
    order = ["foo", "bar", "baz"]
    original = Config({"foo": "bar"}, section_order=order)
    restored = pickle.loads(pickle.dumps(original))
    assert restored == {"foo": "bar"}
    assert restored.section_order == ["foo", "bar", "baz"]
def test_config_fill_extra_fields():
    """Test that filling a config from a schema removes extra fields."""

    class TestSchemaContent(BaseModel):
        a: str
        b: int

        class Config:
            extra = "forbid"

    class TestSchema(BaseModel):
        cfg: TestSchemaContent

    def fill_unvalidated(cfg, schema):
        # Fill without validation and return only the "cfg" section.
        return my_registry.fill(cfg, schema=schema, validate=False)["cfg"]

    without_extra = {"a": "1", "b": 2}
    config = Config({"cfg": {"a": "1", "b": 2, "c": True}})
    # With validation on, the extra field "c" is rejected.
    with pytest.raises(ConfigValidationError):
        my_registry.fill(config, schema=TestSchema)
    assert fill_unvalidated(config, TestSchema) == without_extra
    config2 = config.interpolate()
    assert fill_unvalidated(config2, TestSchema) == without_extra
    config3 = Config({"cfg": {"a": "1", "b": 2, "c": True}}, is_interpolated=False)
    assert fill_unvalidated(config3, TestSchema) == without_extra

    class TestSchemaContent2(BaseModel):
        a: str
        b: int

        class Config:
            extra = "allow"

    class TestSchema2(BaseModel):
        cfg: TestSchemaContent2

    # A schema that allows extras keeps the additional field.
    assert fill_unvalidated(config, TestSchema2) == {"a": "1", "b": 2, "c": True}
def test_config_validation_error_custom():
    """Inspect a ConfigValidationError and derive a customized copy from it."""

    class Schema(BaseModel):
        hello: int
        world: int

    bad_config = {"hello": 1, "world": "hi!"}
    with pytest.raises(ConfigValidationError) as exc_info:
        my_registry._fill(bad_config, Schema)
    original = exc_info.value
    assert original.title == "Config validation error"
    assert original.desc is None
    assert not original.parent
    assert original.show_config is True
    assert len(original.errors) == 1
    first_error = original.errors[0]
    assert first_error["loc"] == ("world",)
    assert first_error["msg"] == "value is not a valid integer"
    assert first_error["type"] == "type_error.integer"
    assert original.error_types == {"type_error.integer"}
    # Create a new error with overrides
    custom_title = "Custom error"
    custom_desc = "Some error description here"
    derived = ConfigValidationError.from_error(
        original, title=custom_title, desc=custom_desc, show_config=False
    )
    assert derived.errors == original.errors
    assert derived.error_types == original.error_types
    assert derived.title == custom_title
    assert derived.desc == custom_desc
    assert derived.show_config is False
    assert original.text != derived.text
def test_config_parsing_error():
    """A line that cannot be parsed surfaces as ConfigValidationError."""
    malformed = "[a]\nb c"
    loader = Config()
    with pytest.raises(ConfigValidationError):
        loader.from_str(malformed)
def test_config_fill_without_resolve():
    """Compare ``fill`` and ``resolve`` on the same config, with and without a schema."""

    class BaseSchema(BaseModel):
        catsie: int

    cat_config = {"catsie": {"@cats": "catsie.v1", "evil": False}}
    filled_plain = my_registry.fill(cat_config)
    resolved_plain = my_registry.resolve(cat_config)
    assert resolved_plain["catsie"] == "meow"
    # Filling adds the function's default argument to the block.
    assert filled_plain["catsie"]["cute"] is True
    # Resolving against BaseSchema raises, while filling against it succeeds.
    with pytest.raises(ConfigValidationError):
        my_registry.resolve(cat_config, schema=BaseSchema)
    filled_with_schema = my_registry.fill(cat_config, schema=BaseSchema)
    assert filled_with_schema["catsie"]["cute"] is True
    resolved_again = my_registry.resolve(filled_with_schema)
    assert resolved_again["catsie"] == "meow"

    # With unavailable function
    class BaseSchema2(BaseModel):
        catsie: Any
        other: int = 12

    dog_config = {"catsie": {"@cats": "dog", "evil": False}}
    filled_unknown = my_registry.fill(dog_config, schema=BaseSchema2)
    assert filled_unknown["catsie"] == dog_config["catsie"]
    assert filled_unknown["other"] == 12
def test_config_dataclasses():
    """A dataclass instance passed as an argument comes back from resolve."""
    original = Cat("testcat", value_in=1, value_out=2)
    block = {"@cats": "catsie.v3", "arg": original}
    resolved = my_registry.resolve({"cfg": block})["cfg"]
    assert isinstance(resolved, Cat)
    assert resolved.name == original.name
    assert resolved.value_in == original.value_in
    assert resolved.value_out == original.value_out
@pytest.mark.parametrize(
    "greeting,value,expected",
    [
        # simple substitution should go fine
        [342, "${vars.a}", int],
        ["342", "${vars.a}", str],
        ["everyone", "${vars.a}", str],
    ],
)
def test_config_interpolates(greeting, value, expected):
    """Check the Python type of a value substituted from an overridden variable."""
    str_cfg = f"""
[project]
my_par = {value}
[vars]
a = "something"
"""
    # The override replaces vars.a before it is substituted into my_par.
    cfg = Config().from_str(str_cfg, overrides={"vars.a": greeting})
    interpolated = cfg["project"]["my_par"]
    assert type(interpolated) is expected
@pytest.mark.parametrize(
    "greeting,value,expected",
    [
        # fmt: off
        # simple substitution should go fine
        ["hello 342", "${vars.a}", "hello 342"],
        ["hello everyone", "${vars.a}", "hello everyone"],
        ["hello tout le monde", "${vars.a}", "hello tout le monde"],
        ["hello 42", "${vars.a}", "hello 42"],
        # substituting an element in a list
        ["hello 342", "[1, ${vars.a}, 3]", "hello 342"],
        ["hello everyone", "[1, ${vars.a}, 3]", "hello everyone"],
        ["hello tout le monde", "[1, ${vars.a}, 3]", "hello tout le monde"],
        ["hello 42", "[1, ${vars.a}, 3]", "hello 42"],
        # substituting part of a string
        [342, "hello ${vars.a}", "hello 342"],
        ["everyone", "hello ${vars.a}", "hello everyone"],
        ["tout le monde", "hello ${vars.a}", "hello tout le monde"],
        pytest.param("42", "hello ${vars.a}", "hello 42", marks=pytest.mark.xfail),
        # substituting part of a implicit string inside a list
        [342, "[1, hello ${vars.a}, 3]", "hello 342"],
        ["everyone", "[1, hello ${vars.a}, 3]", "hello everyone"],
        ["tout le monde", "[1, hello ${vars.a}, 3]", "hello tout le monde"],
        pytest.param("42", "[1, hello ${vars.a}, 3]", "hello 42", marks=pytest.mark.xfail),
        # substituting part of a explicit string inside a list
        [342, "[1, 'hello ${vars.a}', '3']", "hello 342"],
        ["everyone", "[1, 'hello ${vars.a}', '3']", "hello everyone"],
        ["tout le monde", "[1, 'hello ${vars.a}', '3']", "hello tout le monde"],
        pytest.param("42", "[1, 'hello ${vars.a}', '3']", "hello 42", marks=pytest.mark.xfail),
        # more complicated example
        [342, "[{'name':'x','script':['hello ${vars.a}']}]", "hello 342"],
        ["everyone", "[{'name':'x','script':['hello ${vars.a}']}]", "hello everyone"],
        ["tout le monde", "[{'name':'x','script':['hello ${vars.a}']}]", "hello tout le monde"],
        pytest.param("42", "[{'name':'x','script':['hello ${vars.a}']}]", "hello 42", marks=pytest.mark.xfail),
        # fmt: on
    ],
)
def test_config_overrides(greeting, value, expected):
    """Check that an overridden variable shows up in the interpolated config."""
    str_cfg = f"""
[project]
commands = {value}
[vars]
a = "world"
"""
    # Sanity check: the variable reference made it into the config text.
    assert "${vars.a}" in str_cfg
    cfg = Config().from_str(str_cfg, overrides={"vars.a": greeting})
    rendered = str(cfg)
    assert expected in rendered
confection-0.0.4/confection/tests/util.py 0000664 0000000 0000000 00000007053 14357240256 0020477 0 ustar 00root root 0000000 0000000 """
Registered functions used for config tests.
"""
import contextlib
import dataclasses
import shutil
import tempfile
from pathlib import Path
from typing import (
Iterable,
List,
Union,
Generator,
Generic,
TypeVar,
Optional,
)
from pydantic.types import StrictBool
import catalogue
import confection
# Type variables for the two payload slots of ``Cat``.
InT = TypeVar("InT")
OutT = TypeVar("OutT")

# A single float, a list of floats, or a generator.
FloatOrSeq = Union[float, List[float], Generator]


@dataclasses.dataclass
class Cat(Generic[InT, OutT]):
    """Generic dataclass with a name and two independently typed values."""

    name: str
    value_in: InT
    value_out: OutT
my_registry_namespace = "config_tests"


class my_registry(confection.registry):
    """Registry used by the config tests.

    Every catalogue below is created under the same namespace with entry
    points disabled.
    """

    # Reuse the module-level constant so the namespace is spelled only once.
    namespace = my_registry_namespace
    cats = catalogue.create(namespace, "cats", entry_points=False)
    optimizers = catalogue.create(namespace, "optimizers", entry_points=False)
    schedules = catalogue.create(namespace, "schedules", entry_points=False)
    initializers = catalogue.create(namespace, "initializers", entry_points=False)
    layers = catalogue.create(namespace, "layers", entry_points=False)
@my_registry.cats.register("catsie.v1")
def catsie_v1(evil: StrictBool, cute: bool = True) -> str:
    """Return ``"scratch!"`` for an evil cat and ``"meow"`` otherwise.

    ``cute`` is accepted but not used by the function body.
    """
    return "scratch!" if evil else "meow"
@my_registry.cats.register("catsie.v2")
def catsie_v2(evil: StrictBool, cute: bool = True, cute_level: int = 1) -> str:
    """Return the cat's sound; non-evil cats above cute level 2 add a heart.

    ``cute`` is accepted but not used by the function body.
    """
    if evil:
        return "scratch!"
    if cute_level > 2:
        return "meow <3"
    return "meow"
@my_registry.cats("catsie.v3")
def catsie(arg: Cat) -> Cat:
    """Identity function: hand back the ``Cat`` that was passed in."""
    cat = arg
    return cat
@my_registry.optimizers("Adam.v1")
def Adam(
    learn_rate: FloatOrSeq = 0.001,
    *,
    beta1: FloatOrSeq = 0.001,
    beta2: FloatOrSeq = 0.001,
    use_averages: bool = True,
):
    """
    Build a stand-in optimizer object. The result is a plain dataclass that
    only records the arguments, not a working optimizer; it exists to show how
    the function registry is used, e.g. with thinc.
    """

    @dataclasses.dataclass
    class Optimizer:
        learn_rate: FloatOrSeq
        beta1: FloatOrSeq
        beta2: FloatOrSeq
        use_averages: bool

    # Field order matches the declaration above, so positional passing is
    # equivalent to naming each field.
    return Optimizer(learn_rate, beta1, beta2, use_averages)
@my_registry.schedules("warmup_linear.v1")
def warmup_linear(
    initial_rate: float, warmup_steps: int, total_steps: int
) -> Iterable[float]:
    """Yield an endless series of rates: a linear ramp from zero during the
    warmup steps, then a linear decline that is clamped at zero. Used for
    learning rates.
    """
    current = 0
    while True:
        if current >= warmup_steps:
            # Decay phase: the fraction of post-warmup steps still remaining.
            remaining = total_steps - current
            scale = max(0.0, remaining / max(1.0, total_steps - warmup_steps))
        else:
            # Warmup phase; max() guards against dividing by zero.
            scale = current / max(1, warmup_steps)
        yield scale * initial_rate
        current += 1
@my_registry.cats("generic_cat.v1")
def generic_cat(cat: Cat[int, int]) -> Cat[int, int]:
    """Rename the given cat to ``"generic_cat"`` in place and return it."""
    setattr(cat, "name", "generic_cat")
    return cat
@my_registry.cats("int_cat.v1")
def int_cat(
    value_in: Optional[int] = None, value_out: Optional[int] = None
) -> Cat[Optional[int], Optional[int]]:
    """Build a cat named ``"int_cat"`` carrying the two optional integers."""
    new_cat = Cat("int_cat", value_in, value_out)
    return new_cat
@my_registry.optimizers.register("my_cool_optimizer.v1")
def make_my_optimizer(learn_rate: List[float], beta1: float):
    """Create the mock Adam optimizer from a learn rate and ``beta1``."""
    optimizer = Adam(learn_rate, beta1=beta1)
    return optimizer
@my_registry.schedules("my_cool_repetitive_schedule.v1")
def decaying(base_rate: float, repeat: int) -> List[float]:
    """Return a list holding ``base_rate`` repeated ``repeat`` times."""
    rates = [base_rate] * repeat
    return rates
@contextlib.contextmanager
def make_tempdir():
    """Create a temporary directory, yield its path, and remove it on exit.

    Yields:
        Path: the newly created temporary directory.

    The directory is removed when the ``with`` block exits, whether it
    finishes normally or raises.
    """
    d = Path(tempfile.mkdtemp())
    try:
        yield d
    finally:
        # Without the finally, an exception in the with-body would skip the
        # cleanup and leave the directory behind.
        shutil.rmtree(str(d))
confection-0.0.4/confection/util.py 0000664 0000000 0000000 00000002715 14357240256 0017335 0 ustar 00root root 0000000 0000000 import functools
import sys
from typing import TypeVar, Callable, Any, Iterator
if sys.version_info < (3, 8):
# Ignoring type for mypy to avoid "Incompatible import" error (https://github.com/python/mypy/issues/4427).
from typing_extensions import Protocol # type: ignore
else:
from typing import Protocol
# Type of the object being decorated; the decorator returns the same type.
_DIn = TypeVar("_DIn")


class Decorator(Protocol):
    """Protocol for a callable that takes a name and returns a decorator
    which hands back its argument with the signature unchanged."""

    def __call__(self, name: str) -> Callable[[_DIn], _DIn]: ...
# Mirrors the approach functools.partial appears to take, so that the return
# type of the wrapped function is retained
PartialT = TypeVar("PartialT")


def partial(
    func: Callable[..., PartialT], *args: Any, **kwargs: Any
) -> Callable[..., PartialT]:
    """Bind arguments with functools.partial and copy the wrapped function's
    docstring onto the result. A place for further workarounds if needed.
    """
    bound = functools.partial(func, *args, **kwargs)
    # functools.partial objects do not carry over the target's docstring.
    bound.__doc__ = func.__doc__
    return bound
class Generator(Iterator):
    """Custom generator type for annotating arguments that accept generators,
    giving pydantic a validator to call (it doesn't support
    iterators/iterables otherwise).
    """

    @classmethod
    def __get_validators__(cls):
        # Validator hook: yields the single validator defined below.
        yield cls.validate

    @classmethod
    def validate(cls, v):
        # Rejected only when the value has neither __iter__ nor __next__.
        looks_iterable = hasattr(v, "__iter__") or hasattr(v, "__next__")
        if looks_iterable:
            return v
        raise TypeError("not a valid iterator")
confection-0.0.4/pyproject.toml 0000664 0000000 0000000 00000000126 14357240256 0016565 0 ustar 00root root 0000000 0000000 [build-system]
# Packages that must be available in order to build the project.
requires = [
"setuptools",
]
# Build backend used for the build.
build-backend = "setuptools.build_meta"
confection-0.0.4/requirements.txt 0000664 0000000 0000000 00000000477 14357240256 0017146 0 ustar 00root root 0000000 0000000 pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0
typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8"
srsly>=2.4.0,<3.0.0
# Development requirements (not part of install_requires in setup.cfg)
pathy>=0.3.5
pytest>=5.2.0,!=7.1.0
# mypy is skipped on aarch64 machines and on Python < 3.7 (see the markers).
mypy>=0.980,<0.990; platform_machine != 'aarch64' and python_version >= '3.7'
# Only installed on Python < 3.7 (see the marker).
types-dataclasses>=0.1.3; python_version < '3.7'
numpy>=1.15.0
confection-0.0.4/setup.cfg 0000664 0000000 0000000 00000002566 14357240256 0015504 0 ustar 00root root 0000000 0000000 [metadata]
# Release version of the package.
version = 0.0.4
description = The sweetest config system for Python
url = https://github.com/explosion/confection
author = Explosion
author_email = contact@explosion.ai
license = MIT
# The long description is loaded from README.md and declared as Markdown.
long_description = file: README.md
long_description_content_type = text/markdown
# Trove classifiers; the Python versions listed run from 3.6 to 3.11.
classifiers =
Development Status :: 5 - Production/Stable
Environment :: Console
Intended Audience :: Developers
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Operating System :: POSIX :: Linux
Operating System :: MacOS :: MacOS X
Operating System :: Microsoft :: Windows
Programming Language :: Python :: 3
Programming Language :: Python :: 3.6
Programming Language :: Python :: 3.7
Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Topic :: Scientific/Engineering
# Packaging options.
[options]
zip_safe = true
include_package_data = true
# Minimum supported Python version.
python_requires = >=3.6
# Runtime dependencies; the same three pins open requirements.txt.
install_requires =
pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0
typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8"
srsly>=2.4.0,<3.0.0
# Source distribution format.
[sdist]
formats = gztar
# Lint settings for flake8.
[flake8]
ignore = E203, E266, E501, E731, W503
max-line-length = 80
select = B,C,E,F,W,T4,B9
exclude =
.env,
.git,
__pycache__,
# Type-checker settings for mypy.
[mypy]
ignore_missing_imports = True
no_implicit_optional = True
confection-0.0.4/setup.py 0000664 0000000 0000000 00000000232 14357240256 0015361 0 ustar 00root root 0000000 0000000 #!/usr/bin/env python
if __name__ == "__main__":
    # setuptools is imported only when this file is run as a script.
    from setuptools import find_packages, setup

    discovered_packages = find_packages()
    setup(name="confection", packages=discovered_packages)