srsly-release-v2.5.1/.buildkite/sdist.yml
steps:
-
command: "./bin/build-sdist.sh"
label: ":dizzy: :python:"
artifact_paths: "dist/*.tar.gz"
srsly-release-v2.5.1/.github/workflows/cibuildwheel.yml
name: Build
on:
push:
tags:
# ytf did they invent their own syntax that's almost regex?
# ** matches 'zero or more of any character'
- 'release-v[0-9]+.[0-9]+.[0-9]+**'
- 'prerelease-v[0-9]+.[0-9]+.[0-9]+**'
jobs:
build_wheels:
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
# macos-13 is an intel runner, macos-14 is apple silicon
os: [ubuntu-latest, windows-latest, macos-13, macos-14, ubuntu-24.04-arm]
steps:
- uses: actions/checkout@v4
- name: Build wheels
uses: pypa/cibuildwheel@v2.21.3
env:
CIBW_SOME_OPTION: value
with:
package-dir: .
output-dir: wheelhouse
config-file: "{package}/pyproject.toml"
- uses: actions/upload-artifact@v4
with:
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
path: ./wheelhouse/*.whl
build_sdist:
name: Build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build sdist
run: pipx run build --sdist
- uses: actions/upload-artifact@v4
with:
name: cibw-sdist
path: dist/*.tar.gz
create_release:
needs: [build_wheels, build_sdist]
runs-on: ubuntu-latest
permissions:
contents: write
checks: write
actions: read
issues: read
packages: write
pull-requests: read
repository-projects: read
statuses: read
steps:
- name: Get the tag name and determine if it's a prerelease
id: get_tag_info
run: |
FULL_TAG=${GITHUB_REF#refs/tags/}
if [[ $FULL_TAG == release-* ]]; then
TAG_NAME=${FULL_TAG#release-}
IS_PRERELEASE=false
elif [[ $FULL_TAG == prerelease-* ]]; then
TAG_NAME=${FULL_TAG#prerelease-}
IS_PRERELEASE=true
else
echo "Tag does not match expected patterns" >&2
exit 1
fi
echo "FULL_TAG=$TAG_NAME" >> $GITHUB_ENV
echo "TAG_NAME=$TAG_NAME" >> $GITHUB_ENV
echo "IS_PRERELEASE=$IS_PRERELEASE" >> $GITHUB_ENV
- uses: actions/download-artifact@v4
with:
# unpacks all CIBW artifacts into dist/
pattern: cibw-*
path: dist
merge-multiple: true
- name: Create Draft Release
id: create_release
uses: softprops/action-gh-release@v2
if: startsWith(github.ref, 'refs/tags/')
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
name: ${{ env.TAG_NAME }}
draft: true
prerelease: ${{ env.IS_PRERELEASE }}
files: "./dist/*"
srsly-release-v2.5.1/.github/workflows/publish_pypi.yml
# The cibuildwheel workflow triggers on creation of a release; this one
# triggers on publication.
# The expected workflow is to create a draft release and let the wheels
# upload, and then hit 'publish', which uploads to PyPI.
on:
release:
types:
- published
jobs:
upload_pypi:
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/srsly
permissions:
id-token: write
contents: read
if: github.event_name == 'release' && github.event.action == 'published'
# or, alternatively, upload to PyPI on every tag starting with 'v' (remove on: release above to use this)
# if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
steps:
- uses: robinraju/release-downloader@v1
with:
tag: ${{ github.event.release.tag_name }}
fileName: '*'
out-file-path: 'dist'
- uses: pypa/gh-action-pypi-publish@release/v1
srsly-release-v2.5.1/.github/workflows/tests.yml
name: tests
on:
push:
tags-ignore:
- '**'
paths-ignore:
- "*.md"
- ".github/cibuildwheel.yml"
- ".github/publish_pypi.yml"
pull_request:
types: [opened, synchronize, reopened, edited]
paths-ignore:
- "*.md"
- ".github/cibuildwheel.yml"
- ".github/publish_pypi.yml"
env:
MODULE_NAME: 'srsly'
RUN_MYPY: 'false'
jobs:
tests:
name: Test
if: github.repository_owner == 'explosion'
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
python_version: ["3.9", "3.10", "3.11", "3.12"]
runs-on: ${{ matrix.os }}
steps:
- name: Check out repo
uses: actions/checkout@v3
- name: Configure Python version
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python_version }}
architecture: x64
- name: Build sdist
run: |
python -m pip install -U build pip setuptools
python -m pip install -U -r requirements.txt
python -m build --sdist
- name: Run mypy
shell: bash
if: ${{ env.RUN_MYPY == 'true' }}
run: |
python -m mypy $MODULE_NAME
- name: Delete source directory
shell: bash
run: |
rm -rf $MODULE_NAME
- name: Uninstall all packages
run: |
python -m pip freeze > installed.txt
python -m pip uninstall -y -r installed.txt
- name: Install from sdist
shell: bash
run: |
SDIST=$(python -c "import os;print(os.listdir('./dist')[-1])" 2>&1)
python -m pip install dist/$SDIST
- name: Test import
shell: bash
run: |
python -c "import $MODULE_NAME" -Werror
- name: Install test requirements
run: |
python -m pip install -U -r requirements.txt
- name: Run tests
shell: bash
run: |
python -m pytest --pyargs $MODULE_NAME -Werror
srsly-release-v2.5.1/.gitignore
.env/
.env*
.vscode/
cythonize.json
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# Cython intermediate files
*.cpp
# Vim files
*.sw*
srsly-release-v2.5.1/LICENSE
The MIT License (MIT)
Copyright (C) 2018 ExplosionAI UG (haftungsbeschränkt)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
srsly-release-v2.5.1/MANIFEST.in
recursive-include srsly *.h *.pyx *.pxd *.cc *.c *.cpp *.json
include LICENSE
include README.md
srsly-release-v2.5.1/README.md
# srsly: Modern high-performance serialization utilities for Python
This package bundles some of the best Python serialization libraries into one
standalone package, with a high-level API that makes it easy to write code
that's correct across platforms and Pythons. This allows us to provide all the
serialization utilities we need in a single binary wheel. Currently supports
**JSON**, **JSONL**, **MessagePack**, **Pickle** and **YAML**.
[](https://github.com/explosion/srsly/actions/workflows/tests.yml)
[](https://pypi.python.org/pypi/srsly)
[](https://anaconda.org/conda-forge/srsly)
[](https://github.com/explosion/srsly)
[](https://github.com/explosion/wheelwright/releases)
## Motivation
Serialization is hard, especially across Python versions and multiple platforms.
After dealing with many subtle bugs over the years (encodings, locales, large
files) our libraries like [spaCy](https://github.com/explosion/spaCy) and
[Prodigy](https://prodi.gy) had steadily grown a number of utility functions to
wrap the multiple serialization formats we need to support (especially `json`,
`msgpack` and `pickle`). These wrapping functions ended up duplicated across our
codebases, so we wanted to put them in one place.
At the same time, we noticed that having a lot of small dependencies was making
maintenance harder, and making installation slower. To solve this, we've made
`srsly` standalone, by including the component packages directly within it. This
way we can provide all the serialization utilities we need in a single binary
wheel.
`srsly` currently includes forks of the following packages:
- [`ujson`](https://github.com/esnme/ultrajson)
- [`msgpack`](https://github.com/msgpack/msgpack-python)
- [`msgpack-numpy`](https://github.com/lebedov/msgpack-numpy)
- [`cloudpickle`](https://github.com/cloudpipe/cloudpickle)
- [`ruamel.yaml`](https://github.com/pycontribs/ruamel-yaml) (without unsafe
implementations!)
## Installation
> ⚠️ Note that `v2.x` is only compatible with **Python 3.6+**. For 2.7+
> compatibility, use `v1.x`.
`srsly` can be installed from pip. Before installing, make sure that your `pip`,
`setuptools` and `wheel` are up to date.
```bash
python -m pip install -U pip setuptools wheel
python -m pip install srsly
```
Or from conda via conda-forge:
```bash
conda install -c conda-forge srsly
```
Alternatively, you can also compile the library from source. You'll need to make
sure that you have a development environment with a Python distribution
including header files, a compiler (XCode command-line tools on macOS / OS X or
Visual C++ build tools on Windows), pip and git installed.
Install from source:
```bash
# clone the repo
git clone https://github.com/explosion/srsly
cd srsly
# create a virtual environment
python -m venv .env
source .env/bin/activate
# update pip
python -m pip install -U pip setuptools wheel
# compile and install from source
python -m pip install .
```
For developers, install requirements separately and then install in editable
mode without build isolation:
```bash
# install in editable mode
python -m pip install -r requirements.txt
python -m pip install --no-build-isolation --editable .
# run test suite
python -m pytest --pyargs srsly
```
## API
### JSON
> 📦 The underlying module is exposed via `srsly.ujson`. However, we normally
> interact with it via the utility functions only.
#### function `srsly.json_dumps`
Serialize an object to a JSON string. Falls back to `json` if `sort_keys=True`
is used (until it's fixed in `ujson`).
```python
data = {"foo": "bar", "baz": 123}
json_string = srsly.json_dumps(data)
```
| Argument | Type | Description |
| ----------- | ---- | ------------------------------------------------------ |
| `data` | - | The JSON-serializable data to output. |
| `indent` | int | Number of spaces used to indent JSON. Defaults to `0`. |
| `sort_keys` | bool | Sort dictionary keys. Defaults to `False`. |
| **RETURNS** | str | The serialized string. |
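For a quick look at these options, here's a minimal sketch (the exact spacing of the output follows `ujson`, or the `json` fallback when `sort_keys=True`):
```python
data = {"b": 2, "a": 1}
srsly.json_dumps(data)                  # compact output, keys in insertion order
srsly.json_dumps(data, indent=2)        # pretty-printed with a 2-space indent
srsly.json_dumps(data, sort_keys=True)  # keys sorted, using the json fallback
```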
#### function `srsly.json_loads`
Deserialize unicode or bytes to a Python object.
```python
data = '{"foo": "bar", "baz": 123}'
obj = srsly.json_loads(data)
```
| Argument | Type | Description |
| ----------- | ----------- | ------------------------------- |
| `data` | str / bytes | The data to deserialize. |
| **RETURNS** | - | The deserialized Python object. |
#### function `srsly.write_json`
Create a JSON file and dump contents or write to standard output.
```python
data = {"foo": "bar", "baz": 123}
srsly.write_json("/path/to/file.json", data)
```
| Argument | Type | Description |
| -------- | ------------ | ------------------------------------------------------ |
| `path` | str / `Path` | The file path or `"-"` to write to stdout. |
| `data` | - | The JSON-serializable data to output. |
| `indent` | int | Number of spaces used to indent JSON. Defaults to `2`. |
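Per the table above, passing `"-"` as the path writes to standard output instead of a file:
```python
data = {"foo": "bar", "baz": 123}
srsly.write_json("-", data)  # prints the JSON to stdout
```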
#### function `srsly.read_json`
Load JSON from a file or standard input.
```python
data = srsly.read_json("/path/to/file.json")
```
| Argument | Type | Description |
| ----------- | ------------ | ------------------------------------------ |
| `path` | str / `Path` | The file path or `"-"` to read from stdin. |
| **RETURNS** | dict / list | The loaded JSON content. |
#### function `srsly.write_gzip_json`
Create a gzipped JSON file and dump contents.
```python
data = {"foo": "bar", "baz": 123}
srsly.write_gzip_json("/path/to/file.json.gz", data)
```
| Argument | Type | Description |
| -------- | ------------ | ------------------------------------------------------ |
| `path` | str / `Path` | The file path. |
| `data` | - | The JSON-serializable data to output. |
| `indent` | int | Number of spaces used to indent JSON. Defaults to `2`. |
#### function `srsly.write_gzip_jsonl`
Create a gzipped JSONL file and dump contents.
```python
data = [{"foo": "bar"}, {"baz": 123}]
srsly.write_gzip_json("/path/to/file.jsonl.gz", data)
```
| Argument | Type | Description |
| ----------------- | ------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | The file path. |
| `lines` | - | The JSON-serializable contents of each line. |
| `append` | bool | Whether or not to append to the location. Appending to .gz files is generally not recommended, as it doesn't allow the algorithm to take advantage of all data when compressing - files may hence be poorly compressed. |
| `append_new_line` | bool | Whether or not to write a new line before appending to the file. |
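A minimal write-then-read sketch using this function together with `srsly.read_gzip_jsonl` (the path is just an example):
```python
lines = [{"foo": "bar"}, {"baz": 123}]
srsly.write_gzip_jsonl("/path/to/file.jsonl.gz", lines)
assert list(srsly.read_gzip_jsonl("/path/to/file.jsonl.gz")) == lines
```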
#### function `srsly.read_gzip_json`
Load gzipped JSON from a file.
```python
data = srsly.read_gzip_json("/path/to/file.json.gz")
```
| Argument | Type | Description |
| ----------- | ------------ | ------------------------ |
| `path` | str / `Path` | The file path. |
| **RETURNS** | dict / list | The loaded JSON content. |
#### function `srsly.read_gzip_jsonl`
Load gzipped JSONL from a file.
```python
data = srsly.read_gzip_jsonl("/path/to/file.jsonl.gz")
```
| Argument | Type | Description |
| ----------- | ------------ | ------------------------- |
| `path` | str / `Path` | The file path. |
| **RETURNS** | dict / list | The loaded JSONL content. |
#### function `srsly.write_jsonl`
Create a JSONL file (newline-delimited JSON) and dump contents line by line, or
write to standard output.
```python
data = [{"foo": "bar"}, {"baz": 123}]
srsly.write_jsonl("/path/to/file.jsonl", data)
```
| Argument | Type | Description |
| ----------------- | ------------ | ---------------------------------------------------------------------------------------------------------------------- |
| `path` | str / `Path` | The file path or `"-"` to write to stdout. |
| `lines` | iterable | The JSON-serializable lines. |
| `append` | bool | Append to an existing file. Will open it in `"a"` mode and insert a newline before writing lines. Defaults to `False`. |
| `append_new_line` | bool | Defines whether a new line should first be written when appending to an existing file. Defaults to `True`. |
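For example, to add more records to an existing file without overwriting it (a small sketch of the `append` behaviour described above; the path is hypothetical):
```python
srsly.write_jsonl("/path/to/file.jsonl", [{"foo": "bar"}])
# Opens the file in "a" mode and writes a newline before the new lines
srsly.write_jsonl("/path/to/file.jsonl", [{"baz": 123}], append=True)
```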
#### function `srsly.read_jsonl`
Read a JSONL file (newline-delimited JSON) or JSONL data from standard
input and yield contents line by line. Blank lines will always be skipped.
```python
data = srsly.read_jsonl("/path/to/file.jsonl")
```
| Argument | Type | Description |
| ---------- | ---------- | -------------------------------------------------------------------- |
| `path` | str / Path | The file path or `"-"` to read from stdin. |
| `skip` | bool | Skip broken lines and don't raise `ValueError`. Defaults to `False`. |
| **YIELDS** | - | The loaded JSON contents of each line. |
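The contents are yielded lazily, so wrap the call in `list()` if you need them all at once. A small sketch that also tolerates broken lines via `skip` (the path is hypothetical):
```python
records = list(srsly.read_jsonl("/path/to/file.jsonl", skip=True))
```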
#### function `srsly.is_json_serializable`
Check if a Python object is JSON-serializable.
```python
assert srsly.is_json_serializable({"hello": "world"}) is True
assert srsly.is_json_serializable(lambda x: x) is False
```
| Argument | Type | Description |
| ----------- | ---- | ---------------------------------------- |
| `obj` | - | The object to check. |
| **RETURNS** | bool | Whether the object is JSON-serializable. |
### msgpack
> 📦 The underlying module is exposed via `srsly.msgpack`. However, we normally
> interact with it via the utility functions only.
#### function `srsly.msgpack_dumps`
Serialize an object to a msgpack byte string.
```python
data = {"foo": "bar", "baz": 123}
msg = srsly.msgpack_dumps(data)
```
| Argument | Type | Description |
| ----------- | ----- | ---------------------- |
| `data` | - | The data to serialize. |
| **RETURNS** | bytes | The serialized bytes. |
#### function `srsly.msgpack_loads`
Deserialize msgpack bytes to a Python object.
```python
msg = b"\x82\xa3foo\xa3bar\xa3baz{"
data = srsly.msgpack_loads(msg)
```
| Argument | Type | Description |
| ----------- | ----- | --------------------------------------------------------------------------------------- |
| `data` | bytes | The data to deserialize. |
| `use_list`  | bool  | Deserialize msgpack arrays as lists instead of tuples. Lists can make deserialization slower. Defaults to `True`. |
| **RETURNS** | - | The deserialized Python object. |
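For instance, with `use_list=False` msgpack arrays come back as tuples rather than lists (a short sketch using the documented option):
```python
msg = srsly.msgpack_dumps({"values": [1, 2, 3]})
assert srsly.msgpack_loads(msg)["values"] == [1, 2, 3]
assert srsly.msgpack_loads(msg, use_list=False)["values"] == (1, 2, 3)
```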
#### function `srsly.write_msgpack`
Create a msgpack file and dump contents.
```python
data = {"foo": "bar", "baz": 123}
srsly.write_msgpack("/path/to/file.msg", data)
```
| Argument | Type | Description |
| -------- | ------------ | ---------------------- |
| `path` | str / `Path` | The file path. |
| `data` | - | The data to serialize. |
#### function `srsly.read_msgpack`
Load a msgpack file.
```python
data = srsly.read_msgpack("/path/to/file.msg")
```
| Argument | Type | Description |
| ----------- | ------------ | --------------------------------------------------------------------------------------- |
| `path` | str / `Path` | The file path. |
| `use_list`  | bool         | Deserialize msgpack arrays as lists instead of tuples. Lists can make deserialization slower. Defaults to `True`. |
| **RETURNS** | - | The loaded and deserialized content. |
### pickle
> 📦 The underlying module is exposed via `srsly.cloudpickle`. However, we
> normally interact with it via the utility functions only.
#### function `srsly.pickle_dumps`
Serialize a Python object with pickle.
```python
data = {"foo": "bar", "baz": 123}
pickled_data = srsly.pickle_dumps(data)
```
| Argument | Type | Description |
| ----------- | ----- | ------------------------------------------------------ |
| `data` | - | The object to serialize. |
| `protocol` | int | Protocol to use. `-1` for highest. Defaults to `None`. |
| **RETURNS** | bytes | The serialized object. |
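For example, to use the highest protocol available on the current Python version, pass `-1` as documented above:
```python
data = {"foo": "bar", "baz": 123}
pickled_data = srsly.pickle_dumps(data, protocol=-1)
```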
#### function `srsly.pickle_loads`
Deserialize bytes with pickle.
```python
pickled_data = b"\x80\x04\x95\x19\x00\x00\x00\x00\x00\x00\x00}\x94(\x8c\x03foo\x94\x8c\x03bar\x94\x8c\x03baz\x94K{u."
data = srsly.pickle_loads(pickled_data)
```
| Argument | Type | Description |
| ----------- | ----- | ------------------------------- |
| `data` | bytes | The data to deserialize. |
| **RETURNS** | - | The deserialized Python object. |
### YAML
> 📦 The underlying module is exposed via `srsly.ruamel_yaml`. However, we
> normally interact with it via the utility functions only.
#### function `srsly.yaml_dumps`
Serialize an object to a YAML string. See the
[`ruamel.yaml` docs](https://yaml.readthedocs.io/en/latest/detail.html?highlight=indentation#indentation-of-block-sequences)
for details on the indentation format.
```python
data = {"foo": "bar", "baz": 123}
yaml_string = srsly.yaml_dumps(data)
```
| Argument | Type | Description |
| ----------------- | ---- | ------------------------------------------ |
| `data` | - | The JSON-serializable data to output. |
| `indent_mapping` | int | Mapping indentation. Defaults to `2`. |
| `indent_sequence` | int | Sequence indentation. Defaults to `4`. |
| `indent_offset` | int | Indentation offset. Defaults to `2`. |
| `sort_keys` | bool | Sort dictionary keys. Defaults to `False`. |
| **RETURNS** | str | The serialized string. |
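As a small sketch of the indentation options (the output shape follows the ruamel.yaml block-sequence rules linked above):
```python
data = {"items": ["a", "b"]}
# Defaults: mappings indented by 2, sequences by 4 with the dash offset by 2
print(srsly.yaml_dumps(data))
# Sequence dashes flush with the parent key instead
print(srsly.yaml_dumps(data, indent_sequence=2, indent_offset=0))
```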
#### function `srsly.yaml_loads`
Deserialize unicode or a file object to a Python object.
```python
data = 'foo: bar\nbaz: 123'
obj = srsly.yaml_loads(data)
```
| Argument | Type | Description |
| ----------- | ---------- | ------------------------------- |
| `data` | str / file | The data to deserialize. |
| **RETURNS** | - | The deserialized Python object. |
#### function `srsly.write_yaml`
Create a YAML file and dump contents or write to standard output.
```python
data = {"foo": "bar", "baz": 123}
srsly.write_yaml("/path/to/file.yml", data)
```
| Argument | Type | Description |
| ----------------- | ------------ | ------------------------------------------ |
| `path` | str / `Path` | The file path or `"-"` to write to stdout. |
| `data` | - | The JSON-serializable data to output. |
| `indent_mapping` | int | Mapping indentation. Defaults to `2`. |
| `indent_sequence` | int | Sequence indentation. Defaults to `4`. |
| `indent_offset` | int | Indentation offset. Defaults to `2`. |
| `sort_keys` | bool | Sort dictionary keys. Defaults to `False`. |
#### function `srsly.read_yaml`
Load YAML from a file or standard input.
```python
data = srsly.read_yaml("/path/to/file.yml")
```
| Argument | Type | Description |
| ----------- | ------------ | ------------------------------------------ |
| `path` | str / `Path` | The file path or `"-"` to read from stdin. |
| **RETURNS** | dict / list | The loaded YAML content. |
#### function `srsly.is_yaml_serializable`
Check if a Python object is YAML-serializable.
```python
assert srsly.is_yaml_serializable({"hello": "world"}) is True
assert srsly.is_yaml_serializable(lambda x: x) is False
```
| Argument | Type | Description |
| ----------- | ---- | ---------------------------------------- |
| `obj` | - | The object to check. |
| **RETURNS** | bool | Whether the object is YAML-serializable. |
srsly-release-v2.5.1/bin/push-tag.sh
#!/usr/bin/env bash
set -e
# Insist repository is clean
git diff-index --quiet HEAD
git checkout $1
git pull origin $1
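# Read the version string from about.py and strip the surrounding quotes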
version=$(grep "__version__ = " srsly/about.py)
version=${version/__version__ = }
version=${version/\'/}
version=${version/\'/}
version=${version/\"/}
version=${version/\"/}
git tag "v$version"
git push origin "v$version"
srsly-release-v2.5.1/pyproject.toml
[build-system]
requires = [
"setuptools",
"cython>=0.25",
]
build-backend = "setuptools.build_meta"
[tool.cibuildwheel]
build = "*"
skip = "pp* cp36* cp37* cp38*"
test-skip = ""
free-threaded-support = false
archs = ["native"]
build-frontend = "default"
config-settings = {}
dependency-versions = "pinned"
environment = {}
environment-pass = []
build-verbosity = 0
before-all = ""
before-build = ""
repair-wheel-command = ""
test-command = ""
before-test = ""
test-requires = []
test-extras = []
container-engine = "docker"
manylinux-x86_64-image = "manylinux2014"
manylinux-i686-image = "manylinux2014"
manylinux-aarch64-image = "manylinux2014"
manylinux-ppc64le-image = "manylinux2014"
manylinux-s390x-image = "manylinux2014"
manylinux-pypy_x86_64-image = "manylinux2014"
manylinux-pypy_i686-image = "manylinux2014"
manylinux-pypy_aarch64-image = "manylinux2014"
musllinux-x86_64-image = "musllinux_1_2"
musllinux-i686-image = "musllinux_1_2"
musllinux-aarch64-image = "musllinux_1_2"
musllinux-ppc64le-image = "musllinux_1_2"
musllinux-s390x-image = "musllinux_1_2"
[tool.cibuildwheel.linux]
repair-wheel-command = "auditwheel repair -w {dest_dir} {wheel}"
[tool.cibuildwheel.macos]
repair-wheel-command = "delocate-wheel --require-archs {delocate_archs} -w {dest_dir} -v {wheel}"
[tool.cibuildwheel.windows]
[tool.cibuildwheel.pyodide]
srsly-release-v2.5.1/requirements.txt
catalogue>=2.0.3,<2.1.0
# Development requirements
cython>=0.29.1
pytest>=4.6.5
pytest-timeout>=1.3.3
mock>=2.0.0,<3.0.0
numpy>=1.15.0
psutil
srsly-release-v2.5.1/setup.cfg
[metadata]
description = Modern high-performance serialization utilities for Python
url = https://github.com/explosion/srsly
author = Explosion
author_email = contact@explosion.ai
license = MIT
long_description = file: README.md
long_description_content_type = text/markdown
classifiers =
Development Status :: 5 - Production/Stable
Environment :: Console
Intended Audience :: Developers
Intended Audience :: Science/Research
License :: OSI Approved :: MIT License
Operating System :: POSIX :: Linux
Operating System :: MacOS :: MacOS X
Operating System :: Microsoft :: Windows
Programming Language :: Cython
Programming Language :: Python :: 3
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Programming Language :: Python :: 3.13
Topic :: Scientific/Engineering
[options]
zip_safe = true
include_package_data = true
python_requires = >=3.9,<3.14
setup_requires =
cython>=0.29.1
install_requires =
catalogue>=2.0.3,<2.1.0
[options.entry_points]
# If spaCy is installed in the same environment as srsly, it will automatically
# have these readers available
spacy_readers =
srsly.read_json.v1 = srsly:read_json
srsly.read_jsonl.v1 = srsly:read_jsonl
srsly.read_yaml.v1 = srsly:read_yaml
srsly.read_msgpack.v1 = srsly:read_msgpack
[bdist_wheel]
universal = false
[sdist]
formats = gztar
[flake8]
ignore = E203, E266, E501, E731, W503, E741
max-line-length = 80
select = B,C,E,F,W,T4,B9
exclude =
srsly/__init__.py
srsly/msgpack/__init__.py
srsly/cloudpickle/__init__.py
[mypy]
ignore_missing_imports = True
[mypy-srsly.cloudpickle.*]
ignore_errors=True
srsly-release-v2.5.1/setup.py
#!/usr/bin/env python
import sys
from setuptools.command.build_ext import build_ext
from sysconfig import get_path
from setuptools import Extension, setup, find_packages
from pathlib import Path
from Cython.Build import cythonize
from Cython.Compiler import Options
import contextlib
import os
# Preserve `__doc__` on functions and classes
# http://docs.cython.org/en/latest/src/userguide/source_files_and_compilation.html#compiler-options
Options.docstrings = True
PACKAGE_DATA = {"": ["*.pyx", "*.pxd", "*.c", "*.h", "*.cpp"]}
PACKAGES = find_packages()
# msgpack has this whacky build where it only builds _cmsgpack which textually includes
# _packer and _unpacker. I refactored this.
MOD_NAMES = ["srsly.msgpack._epoch", "srsly.msgpack._packer", "srsly.msgpack._unpacker"]
COMPILE_OPTIONS = {
"msvc": ["/Ox", "/EHsc"],
"mingw32": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"],
"other": ["-O2", "-Wno-strict-prototypes", "-Wno-unused-function"],
}
COMPILER_DIRECTIVES = {
"language_level": -3,
"embedsignature": True,
"annotation_typing": False,
}
LINK_OPTIONS = {"msvc": [], "mingw32": [], "other": ["-lstdc++", "-lm"]}
if sys.byteorder == "big":
macros = [("__BIG_ENDIAN__", "1")]
else:
macros = [("__LITTLE_ENDIAN__", "1")]
# By subclassing build_extensions we have the actual compiler that will be used
# which is really known only after finalize_options
# http://stackoverflow.com/questions/724664/python-distutils-how-to-get-a-compiler-that-is-going-to-be-used
class build_ext_options:
def build_options(self):
if hasattr(self.compiler, "initialize"):
self.compiler.initialize()
self.compiler.platform = sys.platform[:6]
for e in self.extensions:
e.extra_compile_args += COMPILE_OPTIONS.get(
self.compiler.compiler_type, COMPILE_OPTIONS["other"]
)
e.extra_link_args += LINK_OPTIONS.get(
self.compiler.compiler_type, LINK_OPTIONS["other"]
)
class build_ext_subclass(build_ext, build_ext_options):
def build_extensions(self):
build_ext_options.build_options(self)
build_ext.build_extensions(self)
def clean(path):
n_cleaned = 0
for name in MOD_NAMES:
name = name.replace(".", "/")
for ext in ["so", "html", "cpp", "c"]:
file_path = path / f"{name}.{ext}"
if file_path.exists():
file_path.unlink()
n_cleaned += 1
print(f"Cleaned {n_cleaned} files")
@contextlib.contextmanager
def chdir(new_dir):
old_dir = os.getcwd()
try:
os.chdir(new_dir)
sys.path.insert(0, new_dir)
yield
finally:
del sys.path[0]
os.chdir(old_dir)
def setup_package():
root = Path(__file__).parent
if len(sys.argv) > 1 and sys.argv[1] == "clean":
return clean(root)
with (root / "srsly" / "about.py").open("r") as f:
about = {}
exec(f.read(), about)
with chdir(str(root)):
include_dirs = [get_path("include"), ".", "srsly"]
ext_modules = []
for name in MOD_NAMES:
mod_path = name.replace(".", "/") + ".pyx"
ext_modules.append(
Extension(
name,
[mod_path],
language="c++",
include_dirs=include_dirs,
define_macros=macros,
)
)
ext_modules.append(
Extension(
"srsly.ujson.ujson",
sources=[
"./srsly/ujson/ujson.c",
"./srsly/ujson/objToJSON.c",
"./srsly/ujson/JSONtoObj.c",
"./srsly/ujson/lib/ultrajsonenc.c",
"./srsly/ujson/lib/ultrajsondec.c",
],
include_dirs=["./srsly/ujson", "./srsly/ujson/lib"],
extra_compile_args=["-D_GNU_SOURCE"],
)
)
print("Cythonizing sources")
ext_modules = cythonize(
ext_modules, compiler_directives=COMPILER_DIRECTIVES, language_level=2
)
setup(
name="srsly",
packages=PACKAGES,
version=about["__version__"],
ext_modules=ext_modules,
cmdclass={"build_ext": build_ext_subclass},
package_data=PACKAGE_DATA,
)
if __name__ == "__main__":
setup_package()
srsly-release-v2.5.1/srsly/__init__.py
from ._json_api import read_json, read_gzip_json, write_json, write_gzip_json
from ._json_api import read_gzip_jsonl, write_gzip_jsonl
from ._json_api import read_jsonl, write_jsonl
from ._json_api import json_dumps, json_loads, is_json_serializable
from ._msgpack_api import read_msgpack, write_msgpack, msgpack_dumps, msgpack_loads
from ._msgpack_api import msgpack_encoders, msgpack_decoders
from ._pickle_api import pickle_dumps, pickle_loads
from ._yaml_api import read_yaml, write_yaml, yaml_dumps, yaml_loads
from ._yaml_api import is_yaml_serializable
from .about import __version__
srsly-release-v2.5.1/srsly/_json_api.py
from typing import Union, Iterable, Sequence, Any, Optional, Iterator
import sys
import json as _builtin_json
import gzip
from . import ujson
from .util import force_path, force_string, FilePath, JSONInput, JSONOutput
def json_dumps(
data: JSONInput, indent: Optional[int] = 0, sort_keys: bool = False
) -> str:
"""Serialize an object to a JSON string.
data: The JSON-serializable data.
indent (int): Number of spaces used to indent JSON.
sort_keys (bool): Sort dictionary keys. Falls back to json module for now.
RETURNS (str): The serialized string.
"""
if sort_keys:
indent = None if indent == 0 else indent
result = _builtin_json.dumps(
data, indent=indent, separators=(",", ":"), sort_keys=sort_keys
)
else:
result = ujson.dumps(data, indent=indent, escape_forward_slashes=False)
return result
def json_loads(data: Union[str, bytes]) -> JSONOutput:
"""Deserialize unicode or bytes to a Python object.
data (str / bytes): The data to deserialize.
RETURNS: The deserialized Python object.
"""
# Avoid transforming the string '-' into the int '0'
if data == "-":
raise ValueError("Expected object or value")
return ujson.loads(data)
def read_json(path: FilePath) -> JSONOutput:
"""Load JSON from file or standard input.
path (FilePath): The file path. "-" for reading from stdin.
RETURNS (JSONOutput): The loaded JSON content.
"""
if path == "-": # reading from sys.stdin
data = sys.stdin.read()
return ujson.loads(data)
file_path = force_path(path)
with file_path.open("r", encoding="utf8") as f:
return ujson.load(f)
def read_gzip_json(path: FilePath) -> JSONOutput:
"""Load JSON from a gzipped file.
    path (FilePath): The file path.
RETURNS (JSONOutput): The loaded JSON content.
"""
file_path = force_string(path)
with gzip.open(file_path, "r") as f:
return ujson.load(f)
def read_gzip_jsonl(path: FilePath, skip: bool = False) -> Iterator[JSONOutput]:
"""Read a gzipped .jsonl file and yield contents line by line.
Blank lines will always be skipped.
path (FilePath): The file path.
skip (bool): Skip broken lines and don't raise ValueError.
YIELDS (JSONOutput): The unpacked, deserialized Python objects.
"""
with gzip.open(force_path(path), "r") as f:
for line in _yield_json_lines(f, skip=skip):
yield line
def write_json(path: FilePath, data: JSONInput, indent: int = 2) -> None:
"""Create a .json file and dump contents or write to standard
output.
location (FilePath): The file path. "-" for writing to stdout.
data (JSONInput): The JSON-serializable data to output.
indent (int): Number of spaces used to indent JSON.
"""
json_data = json_dumps(data, indent=indent)
if path == "-": # writing to stdout
print(json_data)
else:
file_path = force_path(path, require_exists=False)
with file_path.open("w", encoding="utf8") as f:
f.write(json_data)
def write_gzip_json(path: FilePath, data: JSONInput, indent: int = 2) -> None:
"""Create a .json.gz file and dump contents.
path (FilePath): The file path.
data (JSONInput): The JSON-serializable data to output.
indent (int): Number of spaces used to indent JSON.
"""
json_data = json_dumps(data, indent=indent)
file_path = force_string(path)
with gzip.open(file_path, "w") as f:
f.write(json_data.encode("utf-8"))
def write_gzip_jsonl(
path: FilePath,
lines: Iterable[JSONInput],
append: bool = False,
append_new_line: bool = True,
) -> None:
"""Create a .jsonl.gz file and dump contents.
    path (FilePath): The file path.
lines (Sequence[JSONInput]): The JSON-serializable contents of each line.
append (bool): Whether or not to append to the location. Appending to .gz files is generally not recommended, as it
doesn't allow the algorithm to take advantage of all data when compressing - files may hence be poorly
compressed.
append_new_line (bool): Whether or not to write a new line before appending
to the file.
"""
mode = "a" if append else "w"
file_path = force_path(path, require_exists=False)
with gzip.open(file_path, mode=mode) as f:
if append and append_new_line:
f.write("\n".encode("utf-8"))
f.writelines([(json_dumps(line) + "\n").encode("utf-8") for line in lines])
def read_jsonl(path: FilePath, skip: bool = False) -> Iterable[JSONOutput]:
"""Read a .jsonl file or standard input and yield contents line by line.
Blank lines will always be skipped.
path (FilePath): The file path. "-" for reading from stdin.
skip (bool): Skip broken lines and don't raise ValueError.
YIELDS (JSONOutput): The loaded JSON contents of each line.
"""
if path == "-": # reading from sys.stdin
for line in _yield_json_lines(sys.stdin, skip=skip):
yield line
else:
file_path = force_path(path)
with file_path.open("r", encoding="utf8") as f:
for line in _yield_json_lines(f, skip=skip):
yield line
def write_jsonl(
path: FilePath,
lines: Iterable[JSONInput],
append: bool = False,
append_new_line: bool = True,
) -> None:
"""Create a .jsonl file and dump contents or write to standard output.
location (FilePath): The file path. "-" for writing to stdout.
lines (Sequence[JSONInput]): The JSON-serializable contents of each line.
append (bool): Whether or not to append to the location.
append_new_line (bool): Whether or not to write a new line before appending
to the file.
"""
if path == "-": # writing to stdout
for line in lines:
print(json_dumps(line))
else:
mode = "a" if append else "w"
file_path = force_path(path, require_exists=False)
with file_path.open(mode, encoding="utf-8") as f:
if append and append_new_line:
f.write("\n")
for line in lines:
f.write(json_dumps(line) + "\n")
def is_json_serializable(obj: Any) -> bool:
"""Check if a Python object is JSON-serializable.
obj: The object to check.
RETURNS (bool): Whether the object is JSON-serializable.
"""
if hasattr(obj, "__call__"):
# Check this separately here to prevent infinite recursions
return False
try:
ujson.dumps(obj)
return True
except (TypeError, OverflowError):
return False
def _yield_json_lines(
stream: Iterable[str], skip: bool = False
) -> Iterable[JSONOutput]:
    # Use enumerate so the reported line number matches the actual line in the
    # stream, including blank lines that are skipped
    for line_no, line in enumerate(stream, start=1):
        line = line.strip()
        if line == "":
            continue
        try:
            yield ujson.loads(line)
        except ValueError:
            if skip:
                continue
            raise ValueError(f"Invalid JSON on line {line_no}: {line}")
srsly-release-v2.5.1/srsly/_msgpack_api.py
import gc
from . import msgpack
from .msgpack import msgpack_encoders, msgpack_decoders # noqa: F401
from .util import force_path, FilePath, JSONInputBin, JSONOutputBin
def msgpack_dumps(data: JSONInputBin) -> bytes:
"""Serialize an object to a msgpack byte string.
data: The data to serialize.
RETURNS (bytes): The serialized bytes.
"""
return msgpack.dumps(data, use_bin_type=True)
def msgpack_loads(data: bytes, use_list: bool = True) -> JSONOutputBin:
"""Deserialize msgpack bytes to a Python object.
data (bytes): The data to deserialize.
    use_list (bool): Deserialize msgpack arrays as lists instead of tuples.
        Lists can make deserialization slower.
RETURNS: The deserialized Python object.
"""
    # msgpack-python docs suggest disabling gc before unpacking large messages
    gc.disable()
    try:
        msg = msgpack.loads(data, raw=False, use_list=use_list)
    finally:
        # Re-enable gc even if deserialization fails
        gc.enable()
    return msg
def write_msgpack(path: FilePath, data: JSONInputBin) -> None:
"""Create a msgpack file and dump contents.
    path (FilePath): The file path.
data (JSONInputBin): The data to serialize.
"""
file_path = force_path(path, require_exists=False)
with file_path.open("wb") as f:
msgpack.dump(data, f, use_bin_type=True)
def read_msgpack(path: FilePath, use_list: bool = True) -> JSONOutputBin:
"""Load a msgpack file.
    path (FilePath): The file path.
    use_list (bool): Deserialize msgpack arrays as lists instead of tuples.
        Lists can make deserialization slower.
RETURNS (JSONOutputBin): The loaded and deserialized content.
"""
file_path = force_path(path)
with file_path.open("rb") as f:
        # msgpack-python docs suggest disabling gc before unpacking large messages
        gc.disable()
        try:
            msg = msgpack.load(f, raw=False, use_list=use_list)
        finally:
            # Re-enable gc even if deserialization fails
            gc.enable()
        return msg
srsly-release-v2.5.1/srsly/_pickle_api.py
from typing import Optional
from . import cloudpickle
from .util import JSONInput, JSONOutput
def pickle_dumps(data: JSONInput, protocol: Optional[int] = None) -> bytes:
"""Serialize a Python object with pickle.
data: The object to serialize.
protocol (int): Protocol to use. -1 for highest.
RETURNS (bytes): The serialized object.
"""
return cloudpickle.dumps(data, protocol=protocol)
def pickle_loads(data: bytes) -> JSONOutput:
"""Deserialize bytes with pickle.
data (bytes): The data to deserialize.
RETURNS: The deserialized Python object.
"""
return cloudpickle.loads(data)
srsly-release-v2.5.1/srsly/_yaml_api.py
from typing import Union, IO, Any
from io import StringIO
import sys
from .ruamel_yaml import YAML
from .ruamel_yaml.representer import RepresenterError
from .util import force_path, FilePath, YAMLInput, YAMLOutput
class CustomYaml(YAML):
def __init__(self, typ="safe", pure=True):
YAML.__init__(self, typ=typ, pure=pure)
self.default_flow_style = False
self.allow_unicode = True
self.encoding = "utf-8"
# https://yaml.readthedocs.io/en/latest/example.html#output-of-dump-as-a-string
def dump(self, data, stream=None, **kw):
inefficient = False
if stream is None:
inefficient = True
stream = StringIO()
YAML.dump(self, data, stream, **kw)
if inefficient:
return stream.getvalue()
def yaml_dumps(
data: YAMLInput,
indent_mapping: int = 2,
indent_sequence: int = 4,
indent_offset: int = 2,
sort_keys: bool = False,
) -> str:
"""Serialize an object to a YAML string. See the ruamel.yaml docs on
indentation for more details on the expected format.
https://yaml.readthedocs.io/en/latest/detail.html?highlight=indentation#indentation-of-block-sequences
data: The YAML-serializable data.
indent_mapping (int): Mapping indentation.
indent_sequence (int): Sequence indentation.
indent_offset (int): Indentation offset.
sort_keys (bool): Sort dictionary keys.
RETURNS (str): The serialized string.
"""
yaml = CustomYaml()
yaml.sort_base_mapping_type_on_output = sort_keys
yaml.indent(mapping=indent_mapping, sequence=indent_sequence, offset=indent_offset)
return yaml.dump(data)
def yaml_loads(data: Union[str, IO]) -> YAMLOutput:
"""Deserialize unicode or a file object a Python object.
data (str / file): The data to deserialize.
RETURNS: The deserialized Python object.
"""
yaml = CustomYaml()
try:
return yaml.load(data)
except Exception as e:
raise ValueError(f"Invalid YAML: {e}")
def read_yaml(path: FilePath) -> YAMLOutput:
"""Load YAML from file or standard input.
location (FilePath): The file path. "-" for reading from stdin.
RETURNS (YAMLOutput): The loaded content.
"""
if path == "-": # reading from sys.stdin
data = sys.stdin.read()
return yaml_loads(data)
file_path = force_path(path)
with file_path.open("r", encoding="utf8") as f:
return yaml_loads(f)
def write_yaml(
path: FilePath,
data: YAMLInput,
indent_mapping: int = 2,
indent_sequence: int = 4,
indent_offset: int = 2,
sort_keys: bool = False,
) -> None:
"""Create a .json file and dump contents or write to standard
output.
location (FilePath): The file path. "-" for writing to stdout.
data (YAMLInput): The JSON-serializable data to output.
indent_mapping (int): Mapping indentation.
indent_sequence (int): Sequence indentation.
indent_offset (int): Indentation offset.
sort_keys (bool): Sort dictionary keys.
"""
yaml_data = yaml_dumps(
data,
indent_mapping=indent_mapping,
indent_sequence=indent_sequence,
indent_offset=indent_offset,
sort_keys=sort_keys,
)
if path == "-": # writing to stdout
print(yaml_data)
else:
file_path = force_path(path, require_exists=False)
with file_path.open("w", encoding="utf8") as f:
f.write(yaml_data)
def is_yaml_serializable(obj: Any) -> bool:
"""Check if a Python object is YAML-serializable (strict).
obj: The object to check.
RETURNS (bool): Whether the object is YAML-serializable.
"""
try:
yaml_dumps(obj)
return True
except RepresenterError:
return False
srsly-release-v2.5.1/srsly/about.py
__version__ = "2.5.1"
srsly-release-v2.5.1/srsly/cloudpickle/__init__.py
from .cloudpickle import * # noqa
from .cloudpickle_fast import CloudPickler, dumps, dump # noqa
# Conform to the convention used by python serialization libraries, which
# expose their Pickler subclass at top-level under the "Pickler" name.
Pickler = CloudPickler
__version__ = '2.2.0'
srsly-release-v2.5.1/srsly/cloudpickle/cloudpickle.py
"""
This class is defined to override standard pickle functionality
The goals of it follow:
-Serialize lambdas and nested functions to compiled byte code
-Deal with main module correctly
-Deal with other non-serializable objects
It does not include an unpickler, as standard python unpickling suffices.
This module was extracted from the `cloud` package, developed by `PiCloud, Inc.
`_.
Copyright (c) 2012, Regents of the University of California.
Copyright (c) 2009 `PiCloud, Inc. `_.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the University of California, Berkeley nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import builtins
import dis
import opcode
import platform
import sys
import types
import weakref
import uuid
import threading
import typing
import warnings
from .compat import pickle
from collections import OrderedDict
from typing import ClassVar, Generic, Union, Tuple, Callable
from pickle import _getattribute
from importlib._bootstrap import _find_spec
try: # pragma: no branch
import typing_extensions as _typing_extensions
from typing_extensions import Literal, Final
except ImportError:
_typing_extensions = Literal = Final = None
if sys.version_info >= (3, 8):
from types import CellType
else:
def f():
a = 1
def g():
return a
return g
CellType = type(f().__closure__[0])
# cloudpickle is meant for inter process communication: we expect all
# communicating processes to run the same Python version hence we favor
# communication speed over compatibility:
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
# Names of modules whose resources should be treated as dynamic.
_PICKLE_BY_VALUE_MODULES = set()
# Track the provenance of reconstructed dynamic classes to make it possible to
# reconstruct instances from the matching singleton class definition when
# appropriate and preserve the usual "isinstance" semantics of Python objects.
_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary()
_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary()
_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock()
PYPY = platform.python_implementation() == "PyPy"
builtin_code_type = None
if PYPY:
# builtin-code objects only exist in pypy
builtin_code_type = type(float.__new__.__code__)
_extract_code_globals_cache = weakref.WeakKeyDictionary()
def _get_or_create_tracker_id(class_def):
with _DYNAMIC_CLASS_TRACKER_LOCK:
class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def)
if class_tracker_id is None:
class_tracker_id = uuid.uuid4().hex
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
_DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def
return class_tracker_id
def _lookup_class_or_track(class_tracker_id, class_def):
if class_tracker_id is not None:
with _DYNAMIC_CLASS_TRACKER_LOCK:
class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault(
class_tracker_id, class_def)
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
return class_def
def register_pickle_by_value(module):
"""Register a module to make it functions and classes picklable by value.
By default, functions and classes that are attributes of an importable
module are to be pickled by reference, that is relying on re-importing
the attribute from the module at load time.
If `register_pickle_by_value(module)` is called, all its functions and
classes are subsequently to be pickled by value, meaning that they can
be loaded in Python processes where the module is not importable.
This is especially useful when developing a module in a distributed
execution environment: restarting the client Python process with the new
source code is enough: there is no need to re-install the new version
of the module on all the worker nodes nor to restart the workers.
Note: this feature is considered experimental. See the cloudpickle
README.md file for more details and limitations.
"""
if not isinstance(module, types.ModuleType):
raise ValueError(
f"Input should be a module object, got {str(module)} instead"
)
# In the future, cloudpickle may need a way to access any module registered
# for pickling by value in order to introspect relative imports inside
# functions pickled by value. (see
# https://github.com/cloudpipe/cloudpickle/pull/417#issuecomment-873684633).
# This access can be ensured by checking that module is present in
# sys.modules at registering time and assuming that it will still be in
# there when accessed during pickling. Another alternative would be to
# store a weakref to the module. Even though cloudpickle does not implement
# this introspection yet, in order to avoid a possible breaking change
# later, we still enforce the presence of module inside sys.modules.
if module.__name__ not in sys.modules:
raise ValueError(
f"{module} was not imported correctly, have you used an "
f"`import` statement to access it?"
)
_PICKLE_BY_VALUE_MODULES.add(module.__name__)
def unregister_pickle_by_value(module):
"""Unregister that the input module should be pickled by value."""
if not isinstance(module, types.ModuleType):
raise ValueError(
f"Input should be a module object, got {str(module)} instead"
)
if module.__name__ not in _PICKLE_BY_VALUE_MODULES:
raise ValueError(f"{module} is not registered for pickle by value")
else:
_PICKLE_BY_VALUE_MODULES.remove(module.__name__)
def list_registry_pickle_by_value():
return _PICKLE_BY_VALUE_MODULES.copy()
def _is_registered_pickle_by_value(module):
module_name = module.__name__
if module_name in _PICKLE_BY_VALUE_MODULES:
return True
while True:
parent_name = module_name.rsplit(".", 1)[0]
if parent_name == module_name:
break
if parent_name in _PICKLE_BY_VALUE_MODULES:
return True
module_name = parent_name
return False
def _whichmodule(obj, name):
"""Find the module an object belongs to.
This function differs from ``pickle.whichmodule`` in two ways:
- it does not mangle the cases where obj's module is __main__ and obj was
not found in any module.
- Errors arising during module introspection are ignored, as those errors
are considered unwanted side effects.
"""
if sys.version_info[:2] < (3, 7) and isinstance(obj, typing.TypeVar): # pragma: no branch # noqa
# Workaround bug in old Python versions: prior to Python 3.7,
# T.__module__ would always be set to "typing" even when the TypeVar T
# would be defined in a different module.
if name is not None and getattr(typing, name, None) is obj:
# Built-in TypeVar defined in typing such as AnyStr
return 'typing'
else:
# User defined or third-party TypeVar: __module__ attribute is
            # irrelevant, thus trigger an exhaustive search for obj in all
# modules.
module_name = None
else:
module_name = getattr(obj, '__module__', None)
if module_name is not None:
return module_name
# Protect the iteration by using a copy of sys.modules against dynamic
# modules that trigger imports of other modules upon calls to getattr or
# other threads importing at the same time.
for module_name, module in sys.modules.copy().items():
# Some modules such as coverage can inject non-module objects inside
# sys.modules
if (
module_name == '__main__' or
module is None or
not isinstance(module, types.ModuleType)
):
continue
try:
if _getattribute(module, name)[0] is obj:
return module_name
except Exception:
pass
return None
def _should_pickle_by_reference(obj, name=None):
"""Test whether an function or a class should be pickled by reference
Pickling by reference means by that the object (typically a function or a
class) is an attribute of a module that is assumed to be importable in the
target Python environment. Loading will therefore rely on importing the
module and then calling `getattr` on it to access the function or class.
Pickling by reference is the only option to pickle functions and classes
in the standard library. In cloudpickle the alternative option is to
pickle by value (for instance for interactively or locally defined
functions and classes or for attributes of modules that have been
explicitly registered to be pickled by value.
"""
if isinstance(obj, types.FunctionType) or issubclass(type(obj), type):
module_and_name = _lookup_module_and_qualname(obj, name=name)
if module_and_name is None:
return False
module, name = module_and_name
return not _is_registered_pickle_by_value(module)
elif isinstance(obj, types.ModuleType):
# We assume that sys.modules is primarily used as a cache mechanism for
        # the Python import machinery. Checking if a module has been added to
        # sys.modules is therefore a cheap and simple heuristic to tell us
# whether we can assume that a given module could be imported by name
# in another Python process.
if _is_registered_pickle_by_value(obj):
return False
return obj.__name__ in sys.modules
else:
raise TypeError(
"cannot check importability of {} instances".format(
type(obj).__name__)
)
def _lookup_module_and_qualname(obj, name=None):
if name is None:
name = getattr(obj, '__qualname__', None)
if name is None: # pragma: no cover
# This used to be needed for Python 2.7 support but is probably not
# needed anymore. However we keep the __name__ introspection in case
# users of cloudpickle rely on this old behavior for unknown reasons.
name = getattr(obj, '__name__', None)
module_name = _whichmodule(obj, name)
if module_name is None:
# In this case, obj.__module__ is None AND obj was not found in any
# imported module. obj is thus treated as dynamic.
return None
if module_name == "__main__":
return None
# Note: if module_name is in sys.modules, the corresponding module is
# assumed importable at unpickling time. See #357
module = sys.modules.get(module_name, None)
if module is None:
# The main reason why obj's module would not be imported is that this
# module has been dynamically created, using for example
# types.ModuleType. The other possibility is that module was removed
# from sys.modules after obj was created/imported. But this case is not
# supported, as the standard pickle does not support it either.
return None
try:
obj2, parent = _getattribute(module, name)
except AttributeError:
# obj was not found inside the module it points to
return None
if obj2 is not obj:
return None
return module, name
def _extract_code_globals(co):
"""
Find all globals names read or written to by codeblock co
"""
out_names = _extract_code_globals_cache.get(co)
if out_names is None:
# We use a dict with None values instead of a set to get a
# deterministic order (assuming Python 3.6+) and avoid introducing
# non-deterministic pickle bytes as a results.
out_names = {name: None for name in _walk_global_ops(co)}
# Declaring a function inside another one using the "def ..."
# syntax generates a constant code object corresponding to the one
# of the nested function. As the nested function may itself need
# global variables, we need to introspect its code, extract its
# globals (by looking for code objects in its co_consts attribute) and
# add the result to code_globals
if co.co_consts:
for const in co.co_consts:
if isinstance(const, types.CodeType):
out_names.update(_extract_code_globals(const))
_extract_code_globals_cache[co] = out_names
return out_names
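# Example (illustrative sketch; ``math`` and ``OFFSET`` are hypothetical
# globals): only names accessed through global-related opcodes are collected,
# so attribute accesses such as ``math.pi`` do not add extra entries.
#
#   >>> def f():
#   ...     return math.pi + OFFSET
#   >>> sorted(_extract_code_globals(f.__code__))
#   ['OFFSET', 'math']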
def _find_imported_submodules(code, top_level_dependencies):
"""
Find currently imported submodules used by a function.
Submodules used by a function need to be detected and referenced for the
function to work correctly at depickling time. Because submodules can be
referenced as attribute of their parent package (``package.submodule``), we
need a special introspection technique that does not rely on GLOBAL-related
opcodes to find references of them in a code object.
Example:
```
import concurrent.futures
import cloudpickle
def func():
x = concurrent.futures.ThreadPoolExecutor
if __name__ == '__main__':
cloudpickle.dumps(func)
```
The globals extracted by cloudpickle in the function's state include the
concurrent package, but not its submodule (here, concurrent.futures), which
is the module used by func. _find_imported_submodules will detect the usage
of concurrent.futures. Saving this module alongside func will ensure
that calling func once depickled does not fail due to concurrent.futures
not being imported.
"""
subimports = []
# check if any known dependency is an imported package
for x in top_level_dependencies:
if (isinstance(x, types.ModuleType) and
hasattr(x, '__package__') and x.__package__):
# check if the package has any currently loaded sub-imports
prefix = x.__name__ + '.'
# A concurrent thread could mutate sys.modules,
# make sure we iterate over a copy to avoid exceptions
for name in list(sys.modules):
# Older versions of pytest will add a "None" module to
# sys.modules.
if name is not None and name.startswith(prefix):
# check whether the function can address the sub-module
tokens = set(name[len(prefix):].split('.'))
if not tokens - set(code.co_names):
subimports.append(sys.modules[name])
return subimports
def cell_set(cell, value):
"""Set the value of a closure cell.
The point of this function is to set the cell_contents attribute of a cell
after its creation. This operation is necessary in case the cell contains a
reference to the function the cell belongs to, as when calling the
function's constructor
``f = types.FunctionType(code, globals, name, argdefs, closure)``,
closure will not be able to contain the yet-to-be-created f.
In Python3.7, cell_contents is writeable, so setting the contents of a cell
can be done simply using
>>> cell.cell_contents = value
In earlier Python3 versions, the cell_contents attribute of a cell is read
only, but this limitation can be worked around by leveraging the Python 3
``nonlocal`` keyword.
In Python2 however, this attribute is read only, and there is no
``nonlocal`` keyword. For this reason, we need to come up with more
complicated hacks to set this attribute.
The chosen approach is to create a function with a STORE_DEREF opcode,
which sets the content of a closure variable. Typically:
>>> def inner(value):
... lambda: cell # the lambda makes cell a closure
... cell = value # cell is a closure, so this triggers a STORE_DEREF
(Note that in Python2, a STORE_DEREF can never be triggered from an inner
function. The function g for example here
>>> def f(var):
... def g():
... var += 1
... return g
will not modify the closure variable ``var`` in place, but will instead try to
load a local variable ``var`` and increment it. As g does not assign the local
variable ``var`` any initial value, calling f(1)() will fail at runtime.)
Our objective is to set the value of a given cell ``cell``. So we need to
somehow reference our ``cell`` object into the ``inner`` function so that
this object (and not the dummy cell of the lambda function) gets affected
by the STORE_DEREF operation.
In inner, ``cell`` is referenced as a cell variable (an enclosing variable
that is referenced by the inner function). If we create a new function
cell_set with the exact same code as ``inner``, but with ``cell`` marked as
a free variable instead, the STORE_DEREF will be applied on its closure -
``cell``, which we can specify explicitly during construction! The new
cell_set variable thus actually sets the contents of a specified cell!
Note: we do not make use of the ``nonlocal`` keyword to set the contents of
a cell in early python3 versions to limit possible syntax errors in case
test and checker libraries decide to parse the whole file.
"""
if sys.version_info[:2] >= (3, 7): # pragma: no branch
cell.cell_contents = value
else:
_cell_set = types.FunctionType(
_cell_set_template_code, {}, '_cell_set', (), (cell,),)
_cell_set(value)
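# Example (illustrative sketch): fill a cell after its creation, as needed when
# reconstructing recursive closures (``_make_empty_cell`` is defined below).
#
#   >>> cell = _make_empty_cell()
#   >>> cell_set(cell, 42)
#   >>> cell.cell_contents
#   42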
def _make_cell_set_template_code():
def _cell_set_factory(value):
lambda: cell
cell = value
co = _cell_set_factory.__code__
_cell_set_template_code = types.CodeType(
co.co_argcount,
co.co_kwonlyargcount, # Python 3 only argument
co.co_nlocals,
co.co_stacksize,
co.co_flags,
co.co_code,
co.co_consts,
co.co_names,
co.co_varnames,
co.co_filename,
co.co_name,
co.co_firstlineno,
co.co_lnotab,
co.co_cellvars, # co_freevars is initialized with co_cellvars
(), # co_cellvars is made empty
)
return _cell_set_template_code
if sys.version_info[:2] < (3, 7):
_cell_set_template_code = _make_cell_set_template_code()
# relevant opcodes
STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
EXTENDED_ARG = dis.EXTENDED_ARG
_BUILTIN_TYPE_NAMES = {}
for k, v in types.__dict__.items():
if type(v) is type:
_BUILTIN_TYPE_NAMES[v] = k
def _builtin_type(name):
if name == "ClassType": # pragma: no cover
# Backward compat to load pickle files generated with cloudpickle
# < 1.3 even if loading pickle files from older versions is not
# officially supported.
return type
return getattr(types, name)
def _walk_global_ops(code):
"""
Yield referenced name for all global-referencing instructions in *code*.
"""
for instr in dis.get_instructions(code):
op = instr.opcode
if op in GLOBAL_OPS:
yield instr.argval
def _extract_class_dict(cls):
"""Retrieve a copy of the dict of a class without the inherited methods"""
clsdict = dict(cls.__dict__) # copy dict proxy to a dict
if len(cls.__bases__) == 1:
inherited_dict = cls.__bases__[0].__dict__
else:
inherited_dict = {}
for base in reversed(cls.__bases__):
inherited_dict.update(base.__dict__)
to_remove = []
for name, value in clsdict.items():
try:
base_value = inherited_dict[name]
if value is base_value:
to_remove.append(name)
except KeyError:
pass
for name in to_remove:
clsdict.pop(name)
return clsdict
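# Example (illustrative sketch with hypothetical classes A and B): attributes
# inherited unchanged from the base class are filtered out, so only B's own
# attributes remain (besides a few class-specific dunders).
#
#   >>> class A:
#   ...     def f(self):
#   ...         return 1
#   >>> class B(A):
#   ...     x = 2
#   >>> sorted(k for k in _extract_class_dict(B) if not k.startswith('__'))
#   ['x']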
if sys.version_info[:2] < (3, 7): # pragma: no branch
def _is_parametrized_type_hint(obj):
# This is very cheap but might generate false positives. So try to
# narrow it down as much as possible.
type_module = getattr(type(obj), '__module__', None)
from_typing_extensions = type_module == 'typing_extensions'
from_typing = type_module == 'typing'
# general typing Constructs
is_typing = getattr(obj, '__origin__', None) is not None
# typing_extensions.Literal
is_literal = (
(getattr(obj, '__values__', None) is not None)
and from_typing_extensions
)
# typing_extensions.Final
is_final = (
(getattr(obj, '__type__', None) is not None)
and from_typing_extensions
)
# typing.ClassVar
is_classvar = (
(getattr(obj, '__type__', None) is not None) and from_typing
)
# typing.Union/Tuple for old Python 3.5
is_union = getattr(obj, '__union_params__', None) is not None
is_tuple = getattr(obj, '__tuple_params__', None) is not None
is_callable = (
getattr(obj, '__result__', None) is not None and
getattr(obj, '__args__', None) is not None
)
return any((is_typing, is_literal, is_final, is_classvar, is_union,
is_tuple, is_callable))
def _create_parametrized_type_hint(origin, args):
return origin[args]
else:
_is_parametrized_type_hint = None
_create_parametrized_type_hint = None
def parametrized_type_hint_getinitargs(obj):
# The distorted type check semantic for typing constructs becomes:
# ``type(obj) is type(TypeHint)``, which means "obj is a
# parametrized TypeHint"
if type(obj) is type(Literal): # pragma: no branch
initargs = (Literal, obj.__values__)
elif type(obj) is type(Final): # pragma: no branch
initargs = (Final, obj.__type__)
elif type(obj) is type(ClassVar):
initargs = (ClassVar, obj.__type__)
elif type(obj) is type(Generic):
initargs = (obj.__origin__, obj.__args__)
elif type(obj) is type(Union):
initargs = (Union, obj.__args__)
elif type(obj) is type(Tuple):
initargs = (Tuple, obj.__args__)
elif type(obj) is type(Callable):
(*args, result) = obj.__args__
if len(args) == 1 and args[0] is Ellipsis:
args = Ellipsis
else:
args = list(args)
initargs = (Callable, (args, result))
else: # pragma: no cover
raise pickle.PicklingError(
f"Cloudpickle Error: Unknown type {type(obj)}"
)
return initargs
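# Illustrative sketch (only meaningful on Python < 3.7, where the helpers
# above are defined): for a parametrized Union the returned init args can be
# fed back to ``_create_parametrized_type_hint`` to rebuild the hint, e.g.
# ``parametrized_type_hint_getinitargs(Union[int, str])`` would return
# ``(Union, (int, str))``, and ``_create_parametrized_type_hint(Union,
# (int, str))`` would evaluate to ``Union[int, str]`` again.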
# Tornado support
def is_tornado_coroutine(func):
"""
Return whether *func* is a Tornado coroutine function.
Running coroutines are not supported.
"""
if 'tornado.gen' not in sys.modules:
return False
gen = sys.modules['tornado.gen']
if not hasattr(gen, "is_coroutine_function"):
# Tornado version is too old
return False
return gen.is_coroutine_function(func)
def _rebuild_tornado_coroutine(func):
from tornado import gen
return gen.coroutine(func)
# including pickles unloading functions in this namespace
load = pickle.load
loads = pickle.loads
def subimport(name):
# We cannot do simply: `return __import__(name)`: Indeed, if ``name`` is
# the name of a submodule, __import__ will return the top-level root module
# of this submodule. For instance, __import__('os.path') returns the `os`
# module.
__import__(name)
return sys.modules[name]
def dynamic_subimport(name, vars):
mod = types.ModuleType(name)
mod.__dict__.update(vars)
mod.__dict__['__builtins__'] = builtins.__dict__
return mod
def _gen_ellipsis():
return Ellipsis
def _gen_not_implemented():
return NotImplemented
def _get_cell_contents(cell):
try:
return cell.cell_contents
except ValueError:
# sentinel used by ``_fill_function`` which will leave the cell empty
return _empty_cell_value
def instance(cls):
"""Create a new instance of a class.
Parameters
----------
cls : type
The class to create an instance of.
Returns
-------
instance : cls
A new instance of ``cls``.
"""
return cls()
@instance
class _empty_cell_value:
"""sentinel for empty closures
"""
@classmethod
def __reduce__(cls):
return cls.__name__
def _fill_function(*args):
"""Fills in the rest of function data into the skeleton function object
The skeleton itself is created by _make_skel_func().
"""
if len(args) == 2:
func = args[0]
state = args[1]
elif len(args) == 5:
# Backwards compat for cloudpickle v0.4.0, after which the `module`
# argument was introduced
func = args[0]
keys = ['globals', 'defaults', 'dict', 'closure_values']
state = dict(zip(keys, args[1:]))
elif len(args) == 6:
# Backwards compat for cloudpickle v0.4.1, after which the function
# state was passed as a dict to the _fill_function it-self.
func = args[0]
keys = ['globals', 'defaults', 'dict', 'module', 'closure_values']
state = dict(zip(keys, args[1:]))
else:
raise ValueError(f'Unexpected _fill_function arguments: {args!r}')
# - At pickling time, any dynamic global variable used by func is
# serialized by value (in state['globals']).
# - At unpickling time, func's __globals__ attribute is initialized by
# first retrieving an empty isolated namespace that will be shared
# with other functions pickled from the same original module
# by the same CloudPickler instance and then updated with the
# content of state['globals'] to populate the shared isolated
# namespace with all the global variables that are specifically
# referenced for this function.
func.__globals__.update(state['globals'])
func.__defaults__ = state['defaults']
func.__dict__ = state['dict']
if 'annotations' in state:
func.__annotations__ = state['annotations']
if 'doc' in state:
func.__doc__ = state['doc']
if 'name' in state:
func.__name__ = state['name']
if 'module' in state:
func.__module__ = state['module']
if 'qualname' in state:
func.__qualname__ = state['qualname']
if 'kwdefaults' in state:
func.__kwdefaults__ = state['kwdefaults']
# _cloudpickle_submodules is a list of submodules that must be loaded for
# the pickled function to work correctly at unpickling time. Now that these
# submodules are depickled (hence imported), they can be removed from the
# object's state (the object state only served as a reference holder to
# these submodules)
if '_cloudpickle_submodules' in state:
state.pop('_cloudpickle_submodules')
cells = func.__closure__
if cells is not None:
for cell, value in zip(cells, state['closure_values']):
if value is not _empty_cell_value:
cell_set(cell, value)
return func
def _make_function(code, globals, name, argdefs, closure):
# Setting __builtins__ in globals is needed for nogil CPython.
globals["__builtins__"] = __builtins__
return types.FunctionType(code, globals, name, argdefs, closure)
def _make_empty_cell():
if False:
# trick the compiler into creating an empty cell in our lambda
cell = None
raise AssertionError('this route should not be executed')
return (lambda: cell).__closure__[0]
def _make_cell(value=_empty_cell_value):
cell = _make_empty_cell()
if value is not _empty_cell_value:
cell_set(cell, value)
return cell
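# Example (illustrative sketch): cells created here behave like regular
# closure cells; reading an empty cell raises ValueError.
#
#   >>> _make_cell("spam").cell_contents
#   'spam'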
def _make_skel_func(code, cell_count, base_globals=None):
""" Creates a skeleton function object that contains just the provided
code and the correct number of cells in func_closure. All other
func attributes (e.g. func_globals) are empty.
"""
# This function is deprecated and should be removed in cloudpickle 1.7
warnings.warn(
"A pickle file created using an old (<=1.4.1) version of cloudpickle "
"is currently being loaded. This is not supported by cloudpickle and "
"will break in cloudpickle 1.7", category=UserWarning
)
# This is backward-compatibility code: for cloudpickle versions between
# 0.5.4 and 0.7, base_globals could be a string or None. base_globals
# should now always be a dictionary.
if base_globals is None or isinstance(base_globals, str):
base_globals = {}
base_globals['__builtins__'] = __builtins__
closure = (
tuple(_make_empty_cell() for _ in range(cell_count))
if cell_count >= 0 else
None
)
return types.FunctionType(code, base_globals, None, None, closure)
def _make_skeleton_class(type_constructor, name, bases, type_kwargs,
class_tracker_id, extra):
"""Build dynamic class with an empty __dict__ to be filled once memoized
If class_tracker_id is not None, try to lookup an existing class definition
matching that id. If none is found, track a newly reconstructed class
definition under that id so that other instances stemming from the same
class id will also reuse this class definition.
The "extra" variable is meant to be a dict (or None) that can be used for
forward compatibility should the need arise.
"""
skeleton_class = types.new_class(
name, bases, {'metaclass': type_constructor},
lambda ns: ns.update(type_kwargs)
)
return _lookup_class_or_track(class_tracker_id, skeleton_class)
def _rehydrate_skeleton_class(skeleton_class, class_dict):
"""Put attributes from `class_dict` back on `skeleton_class`.
See CloudPickler.save_dynamic_class for more info.
"""
registry = None
for attrname, attr in class_dict.items():
if attrname == "_abc_impl":
registry = attr
else:
setattr(skeleton_class, attrname, attr)
if registry is not None:
for subclass in registry:
skeleton_class.register(subclass)
return skeleton_class
def _make_skeleton_enum(bases, name, qualname, members, module,
class_tracker_id, extra):
"""Build dynamic enum with an empty __dict__ to be filled once memoized
The creation of the enum class is inspired by the code of
EnumMeta._create_.
If class_tracker_id is not None, try to lookup an existing enum definition
matching that id. If none is found, track a newly reconstructed enum
definition under that id so that other instances stemming from the same
class id will also reuse this enum definition.
The "extra" variable is meant to be a dict (or None) that can be used for
forward compatibility should the need arise.
"""
# enums always inherit from their base Enum class at the last position in
# the list of base classes:
enum_base = bases[-1]
metacls = enum_base.__class__
classdict = metacls.__prepare__(name, bases)
for member_name, member_value in members.items():
classdict[member_name] = member_value
enum_class = metacls.__new__(metacls, name, bases, classdict)
enum_class.__module__ = module
enum_class.__qualname__ = qualname
return _lookup_class_or_track(class_tracker_id, enum_class)
def _make_typevar(name, bound, constraints, covariant, contravariant,
class_tracker_id):
tv = typing.TypeVar(
name, *constraints, bound=bound,
covariant=covariant, contravariant=contravariant
)
if class_tracker_id is not None:
return _lookup_class_or_track(class_tracker_id, tv)
else: # pragma: nocover
# Only for Python 3.5.3 compat.
return tv
def _decompose_typevar(obj):
return (
obj.__name__, obj.__bound__, obj.__constraints__,
obj.__covariant__, obj.__contravariant__,
_get_or_create_tracker_id(obj),
)
def _typevar_reduce(obj):
# TypeVar instances require the module information, which is why we
# are not using _should_pickle_by_reference directly
module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__)
if module_and_name is None:
return (_make_typevar, _decompose_typevar(obj))
elif _is_registered_pickle_by_value(module_and_name[0]):
return (_make_typevar, _decompose_typevar(obj))
return (getattr, module_and_name)
def _get_bases(typ):
if '__orig_bases__' in getattr(typ, '__dict__', {}):
# For generic types (see PEP 560)
# Note that simply checking `hasattr(typ, '__orig_bases__')` is not
# correct. Subclasses of a fully-parameterized generic class do not
# have `__orig_bases__` defined, but `hasattr(typ, '__orig_bases__')`
# will return True because it's defined in the base class.
bases_attr = '__orig_bases__'
else:
# For regular class objects
bases_attr = '__bases__'
return getattr(typ, bases_attr)
def _make_dict_keys(obj, is_ordered=False):
if is_ordered:
return OrderedDict.fromkeys(obj).keys()
else:
return dict.fromkeys(obj).keys()
def _make_dict_values(obj, is_ordered=False):
if is_ordered:
return OrderedDict((i, _) for i, _ in enumerate(obj)).values()
else:
return {i: _ for i, _ in enumerate(obj)}.values()
def _make_dict_items(obj, is_ordered=False):
if is_ordered:
return OrderedDict(obj).items()
else:
return obj.items()
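# Example (illustrative sketch): the view constructors above rebuild
# dict/OrderedDict views from the plain containers stored in the pickle.
#
#   >>> list(_make_dict_keys(["a", "b"]))
#   ['a', 'b']
#   >>> list(_make_dict_items({"a": 1}))
#   [('a', 1)]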
srsly-release-v2.5.1/srsly/cloudpickle/cloudpickle_fast.py 0000664 0000000 0000000 00000102502 14742310675 0024047 0 ustar 00root root 0000000 0000000 """
New, fast version of the CloudPickler.
This new CloudPickler class can now extend the fast C Pickler instead of the
previous Python implementation of the Pickler class. Because this functionality
is only available for Python versions 3.8+, a lot of backward-compatibility
code is also removed.
Note that the C Pickler subclassing API is CPython-specific. Therefore, some
guards present in cloudpickle.py that were written to handle PyPy specificities
are not present in cloudpickle_fast.py
"""
import _collections_abc
import abc
import copyreg
import io
import itertools
import logging
import sys
import struct
import types
import weakref
import typing
from enum import Enum
from collections import ChainMap, OrderedDict
from .compat import pickle, Pickler
from .cloudpickle import (
_extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
_find_imported_submodules, _get_cell_contents, _should_pickle_by_reference,
_builtin_type, _get_or_create_tracker_id, _make_skeleton_class,
_make_skeleton_enum, _extract_class_dict, dynamic_subimport, subimport,
_typevar_reduce, _get_bases, _make_cell, _make_empty_cell, CellType,
_is_parametrized_type_hint, PYPY, cell_set,
parametrized_type_hint_getinitargs, _create_parametrized_type_hint,
builtin_code_type,
_make_dict_keys, _make_dict_values, _make_dict_items, _make_function,
)
if pickle.HIGHEST_PROTOCOL >= 5:
# Shorthands similar to pickle.dump/pickle.dumps
def dump(obj, file, protocol=None, buffer_callback=None):
"""Serialize obj as bytes streamed into file
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
speed between processes running the same Python version.
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
CloudPickler(
file, protocol=protocol, buffer_callback=buffer_callback
).dump(obj)
def dumps(obj, protocol=None, buffer_callback=None):
"""Serialize obj as a string of bytes allocated in memory
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
speed between processes running the same Python version.
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
with io.BytesIO() as file:
cp = CloudPickler(
file, protocol=protocol, buffer_callback=buffer_callback
)
cp.dump(obj)
return file.getvalue()
else:
# Shorthands similar to pickle.dump/pickle.dumps
def dump(obj, file, protocol=None):
"""Serialize obj as bytes streamed into file
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
speed between processes running the same Python version.
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
CloudPickler(file, protocol=protocol).dump(obj)
def dumps(obj, protocol=None):
"""Serialize obj as a string of bytes allocated in memory
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
speed between processes running the same Python version.
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
with io.BytesIO() as file:
cp = CloudPickler(file, protocol=protocol)
cp.dump(obj)
return file.getvalue()
load, loads = pickle.load, pickle.loads
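# Example (illustrative sketch): dynamically defined functions such as lambdas
# round-trip through dumps/loads, which the standard pickle module refuses to
# serialize.
#
#   >>> square = lambda x: x * x
#   >>> loads(dumps(square))(4)
#   16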
# COLLECTION OF OBJECTS __getnewargs__-LIKE METHODS
# -------------------------------------------------
def _class_getnewargs(obj):
type_kwargs = {}
if "__slots__" in obj.__dict__:
type_kwargs["__slots__"] = obj.__slots__
__dict__ = obj.__dict__.get('__dict__', None)
if isinstance(__dict__, property):
type_kwargs['__dict__'] = __dict__
return (type(obj), obj.__name__, _get_bases(obj), type_kwargs,
_get_or_create_tracker_id(obj), None)
def _enum_getnewargs(obj):
members = {e.name: e.value for e in obj}
return (obj.__bases__, obj.__name__, obj.__qualname__, members,
obj.__module__, _get_or_create_tracker_id(obj), None)
# COLLECTION OF OBJECTS RECONSTRUCTORS
# ------------------------------------
def _file_reconstructor(retval):
return retval
# COLLECTION OF OBJECTS STATE GETTERS
# -----------------------------------
def _function_getstate(func):
# - Put func's dynamic attributes (stored in func.__dict__) in state. These
# attributes will be restored at unpickling time using
# f.__dict__.update(state)
# - Put func's members into slotstate. Such attributes will be restored at
# unpickling time by iterating over slotstate and calling setattr(func,
# slotname, slotvalue)
slotstate = {
"__name__": func.__name__,
"__qualname__": func.__qualname__,
"__annotations__": func.__annotations__,
"__kwdefaults__": func.__kwdefaults__,
"__defaults__": func.__defaults__,
"__module__": func.__module__,
"__doc__": func.__doc__,
"__closure__": func.__closure__,
}
f_globals_ref = _extract_code_globals(func.__code__)
f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in
func.__globals__}
closure_values = (
list(map(_get_cell_contents, func.__closure__))
if func.__closure__ is not None else ()
)
# Extract currently-imported submodules used by func. Storing these modules
# in a dummy _cloudpickle_submodules attribute of the object's state will
# trigger the side effect of importing these modules at unpickling time
# (which is necessary for func to work correctly once depickled)
slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
func.__code__, itertools.chain(f_globals.values(), closure_values))
slotstate["__globals__"] = f_globals
state = func.__dict__
return state, slotstate
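# Illustrative shape of the returned value: ``state`` is simply
# ``func.__dict__``, while ``slotstate`` maps attribute names such as
# "__name__", "__defaults__", "__globals__", "__closure__" and
# "_cloudpickle_submodules" to the values that ``_function_setstate`` will
# restore on the reconstructed function.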
def _class_getstate(obj):
clsdict = _extract_class_dict(obj)
clsdict.pop('__weakref__', None)
if issubclass(type(obj), abc.ABCMeta):
# If obj is an instance of an ABCMeta subclass, don't pickle the
# cache/negative caches populated during isinstance/issubclass
# checks, but pickle the list of registered subclasses of obj.
clsdict.pop('_abc_cache', None)
clsdict.pop('_abc_negative_cache', None)
clsdict.pop('_abc_negative_cache_version', None)
registry = clsdict.pop('_abc_registry', None)
if registry is None:
# in Python3.7+, the abc caches and registered subclasses of a
# class are bundled into the single _abc_impl attribute
clsdict.pop('_abc_impl', None)
(registry, _, _, _) = abc._get_dump(obj)
clsdict["_abc_impl"] = [subclass_weakref()
for subclass_weakref in registry]
else:
# In the above if clause, registry is a set of weakrefs -- in
# this case, registry is a WeakSet
clsdict["_abc_impl"] = [type_ for type_ in registry]
if "__slots__" in clsdict:
# pickle string length optimization: member descriptors of obj are
# created automatically from obj's __slots__ attribute, no need to
# save them in obj's state
if isinstance(obj.__slots__, str):
clsdict.pop(obj.__slots__)
else:
for k in obj.__slots__:
clsdict.pop(k, None)
clsdict.pop('__dict__', None) # unpicklable property object
return (clsdict, {})
def _enum_getstate(obj):
clsdict, slotstate = _class_getstate(obj)
members = {e.name: e.value for e in obj}
# Cleanup the clsdict that will be passed to _rehydrate_skeleton_class:
# Those attributes are already handled by the metaclass.
for attrname in ["_generate_next_value_", "_member_names_",
"_member_map_", "_member_type_",
"_value2member_map_"]:
clsdict.pop(attrname, None)
for member in members:
clsdict.pop(member)
# Special handling of Enum subclasses
return clsdict, slotstate
# COLLECTIONS OF OBJECTS REDUCERS
# -------------------------------
# A reducer is a function taking a single argument (obj), and that returns a
# tuple with all the necessary data to re-construct obj. Apart from a few
# exceptions (list, dict, bytes, int, etc.), a reducer is necessary to
# correctly pickle an object.
# While many built-in objects (Exceptions objects, instances of the "object"
# class, etc), are shipped with their own built-in reducer (invoked using
# obj.__reduce__), some do not. The following methods were created to "fill
# these holes".
def _code_reduce(obj):
"""codeobject reducer"""
# If you are not sure about the order of arguments, take a look at help
# of the specific type from types, for example:
# >>> from types import CodeType
# >>> help(CodeType)
if hasattr(obj, "co_exceptiontable"): # pragma: no branch
# Python 3.11 and later: there are some new attributes
# related to the enhanced exceptions.
args = (
obj.co_argcount, obj.co_posonlyargcount,
obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
obj.co_varnames, obj.co_filename, obj.co_name, obj.co_qualname,
obj.co_firstlineno, obj.co_linetable, obj.co_exceptiontable,
obj.co_freevars, obj.co_cellvars,
)
elif hasattr(obj, "co_linetable"): # pragma: no branch
# Python 3.10 and later: obj.co_lnotab is deprecated and constructor
# expects obj.co_linetable instead.
args = (
obj.co_argcount, obj.co_posonlyargcount,
obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
obj.co_varnames, obj.co_filename, obj.co_name,
obj.co_firstlineno, obj.co_linetable, obj.co_freevars,
obj.co_cellvars
)
elif hasattr(obj, "co_nmeta"): # pragma: no cover
# "nogil" Python: modified attributes from 3.9
args = (
obj.co_argcount, obj.co_posonlyargcount,
obj.co_kwonlyargcount, obj.co_nlocals, obj.co_framesize,
obj.co_ndefaultargs, obj.co_nmeta,
obj.co_flags, obj.co_code, obj.co_consts,
obj.co_varnames, obj.co_filename, obj.co_name,
obj.co_firstlineno, obj.co_lnotab, obj.co_exc_handlers,
obj.co_jump_table, obj.co_freevars, obj.co_cellvars,
obj.co_free2reg, obj.co_cell2reg
)
elif hasattr(obj, "co_posonlyargcount"):
# Backward compat for 3.9 and older
args = (
obj.co_argcount, obj.co_posonlyargcount,
obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
obj.co_varnames, obj.co_filename, obj.co_name,
obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
obj.co_cellvars
)
else:
# Backward compat for even older versions of Python
args = (
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals,
obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts,
obj.co_names, obj.co_varnames, obj.co_filename,
obj.co_name, obj.co_firstlineno, obj.co_lnotab,
obj.co_freevars, obj.co_cellvars
)
return types.CodeType, args
def _cell_reduce(obj):
"""Cell (containing values of a function's free variables) reducer"""
try:
obj.cell_contents
except ValueError: # cell is empty
return _make_empty_cell, ()
else:
return _make_cell, (obj.cell_contents, )
def _classmethod_reduce(obj):
orig_func = obj.__func__
return type(obj), (orig_func,)
def _file_reduce(obj):
"""Save a file"""
import io
if not hasattr(obj, "name") or not hasattr(obj, "mode"):
raise pickle.PicklingError(
"Cannot pickle files that do not map to an actual file"
)
if obj is sys.stdout:
return getattr, (sys, "stdout")
if obj is sys.stderr:
return getattr, (sys, "stderr")
if obj is sys.stdin:
raise pickle.PicklingError("Cannot pickle standard input")
if obj.closed:
raise pickle.PicklingError("Cannot pickle closed files")
if hasattr(obj, "isatty") and obj.isatty():
raise pickle.PicklingError(
"Cannot pickle files that map to tty objects"
)
if "r" not in obj.mode and "+" not in obj.mode:
raise pickle.PicklingError(
"Cannot pickle files that are not opened for reading: %s"
% obj.mode
)
name = obj.name
retval = io.StringIO()
try:
# Read the whole file
curloc = obj.tell()
obj.seek(0)
contents = obj.read()
obj.seek(curloc)
except IOError as e:
raise pickle.PicklingError(
"Cannot pickle file %s as it cannot be read" % name
) from e
retval.write(contents)
retval.seek(curloc)
retval.name = name
return _file_reconstructor, (retval,)
def _getset_descriptor_reduce(obj):
return getattr, (obj.__objclass__, obj.__name__)
def _mappingproxy_reduce(obj):
return types.MappingProxyType, (dict(obj),)
def _memoryview_reduce(obj):
return bytes, (obj.tobytes(),)
def _module_reduce(obj):
if _should_pickle_by_reference(obj):
return subimport, (obj.__name__,)
else:
# Some external libraries can populate the "__builtins__" entry of a
# module's `__dict__` with unpicklable objects (see #316). For that
# reason, we do not attempt to pickle the "__builtins__" entry, and
# restore a default value for it at unpickling time.
state = obj.__dict__.copy()
state.pop('__builtins__', None)
return dynamic_subimport, (obj.__name__, state)
def _method_reduce(obj):
return (types.MethodType, (obj.__func__, obj.__self__))
def _logger_reduce(obj):
return logging.getLogger, (obj.name,)
def _root_logger_reduce(obj):
return logging.getLogger, ()
def _property_reduce(obj):
return property, (obj.fget, obj.fset, obj.fdel, obj.__doc__)
def _weakset_reduce(obj):
return weakref.WeakSet, (list(obj),)
def _dynamic_class_reduce(obj):
"""
Save a class that can't be stored as module global.
This method is used to serialize classes that are defined inside
functions, or that otherwise can't be serialized as attribute lookups
from global modules.
"""
if Enum is not None and issubclass(obj, Enum):
return (
_make_skeleton_enum, _enum_getnewargs(obj), _enum_getstate(obj),
None, None, _class_setstate
)
else:
return (
_make_skeleton_class, _class_getnewargs(obj), _class_getstate(obj),
None, None, _class_setstate
)
def _class_reduce(obj):
"""Select the reducer depending on the dynamic nature of the class obj"""
if obj is type(None): # noqa
return type, (None,)
elif obj is type(Ellipsis):
return type, (Ellipsis,)
elif obj is type(NotImplemented):
return type, (NotImplemented,)
elif obj in _BUILTIN_TYPE_NAMES:
return _builtin_type, (_BUILTIN_TYPE_NAMES[obj],)
elif not _should_pickle_by_reference(obj):
return _dynamic_class_reduce(obj)
return NotImplemented
def _dict_keys_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_keys, (list(obj), )
def _dict_values_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_values, (list(obj), )
def _dict_items_reduce(obj):
return _make_dict_items, (dict(obj), )
def _odict_keys_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_keys, (list(obj), True)
def _odict_values_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_values, (list(obj), True)
def _odict_items_reduce(obj):
return _make_dict_items, (dict(obj), True)
# COLLECTIONS OF OBJECTS STATE SETTERS
# ------------------------------------
# state setters are called at unpickling time, once the object is created and
# it has to be updated to how it was at pickling time.
def _function_setstate(obj, state):
"""Update the state of a dynamic function.
As __closure__ and __globals__ are readonly attributes of a function, we
cannot rely on the native setstate routine of pickle.load_build, that calls
setattr on items of the slotstate. Instead, we have to modify them inplace.
"""
state, slotstate = state
obj.__dict__.update(state)
obj_globals = slotstate.pop("__globals__")
obj_closure = slotstate.pop("__closure__")
# _cloudpickle_submodules is a list of submodules that must be loaded for
# the pickled function to work correctly at unpickling time. Now that these
# submodules are depickled (hence imported), they can be removed from the
# object's state (the object state only served as a reference holder to
# these submodules)
slotstate.pop("_cloudpickle_submodules")
obj.__globals__.update(obj_globals)
obj.__globals__["__builtins__"] = __builtins__
if obj_closure is not None:
for i, cell in enumerate(obj_closure):
try:
value = cell.cell_contents
except ValueError: # cell is empty
continue
cell_set(obj.__closure__[i], value)
for k, v in slotstate.items():
setattr(obj, k, v)
def _class_setstate(obj, state):
state, slotstate = state
registry = None
for attrname, attr in state.items():
if attrname == "_abc_impl":
registry = attr
else:
setattr(obj, attrname, attr)
if registry is not None:
for subclass in registry:
obj.register(subclass)
return obj
class CloudPickler(Pickler):
# set of reducers defined and used by cloudpickle (private)
_dispatch_table = {}
_dispatch_table[classmethod] = _classmethod_reduce
_dispatch_table[io.TextIOWrapper] = _file_reduce
_dispatch_table[logging.Logger] = _logger_reduce
_dispatch_table[logging.RootLogger] = _root_logger_reduce
_dispatch_table[memoryview] = _memoryview_reduce
_dispatch_table[property] = _property_reduce
_dispatch_table[staticmethod] = _classmethod_reduce
_dispatch_table[CellType] = _cell_reduce
_dispatch_table[types.CodeType] = _code_reduce
_dispatch_table[types.GetSetDescriptorType] = _getset_descriptor_reduce
_dispatch_table[types.ModuleType] = _module_reduce
_dispatch_table[types.MethodType] = _method_reduce
_dispatch_table[types.MappingProxyType] = _mappingproxy_reduce
_dispatch_table[weakref.WeakSet] = _weakset_reduce
_dispatch_table[typing.TypeVar] = _typevar_reduce
_dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce
_dispatch_table[_collections_abc.dict_values] = _dict_values_reduce
_dispatch_table[_collections_abc.dict_items] = _dict_items_reduce
_dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce
_dispatch_table[type(OrderedDict().values())] = _odict_values_reduce
_dispatch_table[type(OrderedDict().items())] = _odict_items_reduce
_dispatch_table[abc.abstractmethod] = _classmethod_reduce
_dispatch_table[abc.abstractclassmethod] = _classmethod_reduce
_dispatch_table[abc.abstractstaticmethod] = _classmethod_reduce
_dispatch_table[abc.abstractproperty] = _property_reduce
dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table)
# function reducers are defined as instance methods of CloudPickler
# objects, as they rely on a CloudPickler attribute (globals_ref)
def _dynamic_function_reduce(self, func):
"""Reduce a function that is not pickleable via attribute lookup."""
newargs = self._function_getnewargs(func)
state = _function_getstate(func)
return (_make_function, newargs, state, None, None,
_function_setstate)
def _function_reduce(self, obj):
"""Reducer for function objects.
If obj is a top-level attribute of a file-backed module, this
reducer returns NotImplemented, making the CloudPickler fallback to
traditional _pickle.Pickler routines to save obj. Otherwise, it reduces
obj using a custom cloudpickle reducer designed specifically to handle
dynamic functions.
As opposed to cloudpickle.py, there is no special handling for builtin
PyPy functions because cloudpickle_fast is CPython-specific.
"""
if _should_pickle_by_reference(obj):
return NotImplemented
else:
return self._dynamic_function_reduce(obj)
def _function_getnewargs(self, func):
code = func.__code__
# base_globals represents the future global namespace of func at
# unpickling time. Looking it up and storing it in
# CloudPickler.globals_ref allows functions sharing the same globals
# at pickling time to also share them once unpickled, on one condition:
# since globals_ref is an attribute of a CloudPickler instance, and
# that a new CloudPickler is created each time pickle.dump or
# pickle.dumps is called, functions also need to be saved within the
# same invocation of cloudpickle.dump/cloudpickle.dumps (for example:
# cloudpickle.dumps([f1, f2])). There is no such limitation when using
# CloudPickler.dump, as long as the multiple invocations are bound to
# the same CloudPickler.
base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
if base_globals == {}:
# Add module attributes used to resolve relative imports
# instructions inside func.
for k in ["__package__", "__name__", "__path__", "__file__"]:
if k in func.__globals__:
base_globals[k] = func.__globals__[k]
# Do not bind the free variables before the function is created to
# avoid infinite recursion.
if func.__closure__ is None:
closure = None
else:
closure = tuple(
_make_empty_cell() for _ in range(len(code.co_freevars)))
return code, base_globals, None, None, closure
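# Example (illustrative sketch; f1 and f2 are hypothetical dynamic functions
# defined in the same module): because base_globals is shared through
# globals_ref and memoized by the pickler, both functions keep sharing their
# global namespace after a single dumps call.
#
#   >>> g1, g2 = loads(dumps([f1, f2]))
#   >>> g1.__globals__ is g2.__globals__
#   True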
def dump(self, obj):
try:
return Pickler.dump(self, obj)
except RuntimeError as e:
if "recursion" in e.args[0]:
msg = (
"Could not pickle object as excessively deep recursion "
"required."
)
raise pickle.PicklingError(msg) from e
else:
raise
if pickle.HIGHEST_PROTOCOL >= 5:
def __init__(self, file, protocol=None, buffer_callback=None):
if protocol is None:
protocol = DEFAULT_PROTOCOL
Pickler.__init__(
self, file, protocol=protocol, buffer_callback=buffer_callback
)
# map functions __globals__ attribute ids, to ensure that functions
# sharing the same global namespace at pickling time also share
# their global namespace at unpickling time.
self.globals_ref = {}
self.proto = int(protocol)
else:
def __init__(self, file, protocol=None):
if protocol is None:
protocol = DEFAULT_PROTOCOL
Pickler.__init__(self, file, protocol=protocol)
# map functions __globals__ attribute ids, to ensure that functions
# sharing the same global namespace at pickling time also share
# their global namespace at unpickling time.
self.globals_ref = {}
assert hasattr(self, 'proto')
if pickle.HIGHEST_PROTOCOL >= 5 and not PYPY:
# Pickler is the C implementation of the CPython pickler and therefore
# we rely on reduce_override method to customize the pickler behavior.
# `CloudPickler.dispatch` is only left for backward compatibility - note
# that when using protocol 5, `CloudPickler.dispatch` is not an
# extension of `Pickler.dispatch` dictionary, because CloudPickler
# subclasses the C-implemented Pickler, which does not expose a
# `dispatch` attribute. Earlier versions of the protocol 5 CloudPickler
# used `CloudPickler.dispatch` as a class-level attribute storing all
# reducers implemented by cloudpickle, but the attribute name was not a
# great choice given the meaning of `CloudPickler.dispatch` when
# `CloudPickler` extends the pure-python pickler.
dispatch = dispatch_table
# Implementation of the reducer_override callback, in order to
# efficiently serialize dynamic functions and classes by subclassing
# the C-implemented Pickler.
# TODO: decorrelate reducer_override (which is tied to CPython's
# implementation - would it make sense to backport it to pypy?) and
# pickle's protocol 5, which is implementation agnostic. Currently, the
# availability of both notions coincides on CPython's pickle and the
# pickle5 backport, but it may not be the case anymore when pypy
# implements protocol 5.
def reducer_override(self, obj):
"""Type-agnostic reducing callback for function and classes.
For performance reasons, subclasses of the C _pickle.Pickler class
cannot register custom reducers for functions and classes in the
dispatch_table. Reducers for such types must instead be implemented in
the special reducer_override method.
Note that this method will be called for any object except a few
builtin-types (int, lists, dicts etc.), which differs from reducers
in the Pickler's dispatch_table, each of them being invoked for
objects of a specific type only.
This property comes in handy for classes: although most classes are
instances of the ``type`` metaclass, some of them can be instances
of other custom metaclasses (such as enum.EnumMeta for example). In
particular, the metaclass will likely not be known in advance, and
thus cannot be special-cased using an entry in the dispatch_table.
reducer_override, among other things, allows us to register a
reducer that will be called for any class, independently of its
type.
Notes:
* reducer_override has the priority over dispatch_table-registered
reducers.
* reducer_override can be used to fix other limitations of
cloudpickle for other types that suffered from type-specific
reducers, such as Exceptions. See
https://github.com/cloudpipe/cloudpickle/issues/248
"""
if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
return (
_create_parametrized_type_hint,
parametrized_type_hint_getinitargs(obj)
)
t = type(obj)
try:
is_anyclass = issubclass(t, type)
except TypeError: # t is not a class (old Boost; see SF #502085)
is_anyclass = False
if is_anyclass:
return _class_reduce(obj)
elif isinstance(obj, types.FunctionType):
return self._function_reduce(obj)
else:
# fallback to save_global, including the Pickler's
# dispatch_table
return NotImplemented
else:
# When reducer_override is not available, hack the pure-Python
# Pickler's types.FunctionType and type savers. Note: the type saver
# must override Pickler.save_global, because pickle.py contains a
# hard-coded call to save_global when pickling meta-classes.
dispatch = Pickler.dispatch.copy()
def _save_reduce_pickle5(self, func, args, state=None, listitems=None,
dictitems=None, state_setter=None, obj=None):
save = self.save
write = self.write
self.save_reduce(
func, args, state=None, listitems=listitems,
dictitems=dictitems, obj=obj
)
# backport of the Python 3.8 state_setter pickle operations
save(state_setter)
save(obj) # simple BINGET opcode as obj is already memoized.
save(state)
write(pickle.TUPLE2)
# Trigger a state_setter(obj, state) function call.
write(pickle.REDUCE)
# The purpose of state_setter is to carry out an
# in-place modification of obj. We do not care about what the
# method might return, so its output is eventually removed from
# the stack.
write(pickle.POP)
def save_global(self, obj, name=None, pack=struct.pack):
"""
Save a "global".
The name of this method is somewhat misleading: all types get
dispatched here.
"""
if obj is type(None): # noqa
return self.save_reduce(type, (None,), obj=obj)
elif obj is type(Ellipsis):
return self.save_reduce(type, (Ellipsis,), obj=obj)
elif obj is type(NotImplemented):
return self.save_reduce(type, (NotImplemented,), obj=obj)
elif obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
# Parametrized typing constructs in Python < 3.7 are not
# compatible with type checks and ``isinstance`` semantics. For
# this reason, it is easier to detect them using a
# duck-typing-based check (``_is_parametrized_type_hint``) than
# to populate the Pickler's dispatch with type-specific savers.
self.save_reduce(
_create_parametrized_type_hint,
parametrized_type_hint_getinitargs(obj),
obj=obj
)
elif name is not None:
Pickler.save_global(self, obj, name=name)
elif not _should_pickle_by_reference(obj, name=name):
self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj)
else:
Pickler.save_global(self, obj, name=name)
dispatch[type] = save_global
def save_function(self, obj, name=None):
""" Registered with the dispatch to handle all function types.
Determines what kind of function obj is (e.g. lambda, defined at
interactive prompt, etc) and handles the pickling appropriately.
"""
if _should_pickle_by_reference(obj, name=name):
return Pickler.save_global(self, obj, name=name)
elif PYPY and isinstance(obj.__code__, builtin_code_type):
return self.save_pypy_builtin_func(obj)
else:
return self._save_reduce_pickle5(
*self._dynamic_function_reduce(obj), obj=obj
)
def save_pypy_builtin_func(self, obj):
"""Save pypy equivalent of builtin functions.
PyPy does not have the concept of builtin-functions. Instead,
builtin-functions are simple function instances, but with a
builtin-code attribute.
Most of the time, builtin functions should be pickled by attribute.
But PyPy has flaky support for __qualname__, so some builtin
functions such as float.__new__ will be classified as dynamic. For
this reason only, we created this special routine. Because
builtin-functions are not expected to have closure or globals,
there is no additional hack (compared to the one already implemented
in pickle) to protect ourselves from reference cycles. A simple
(reconstructor, newargs, obj.__dict__) tuple is save_reduced. Note
also that PyPy improved their support for __qualname__ in v3.6, so
this routine should be removed when cloudpickle supports only PyPy
3.6 and later.
"""
rv = (types.FunctionType, (obj.__code__, {}, obj.__name__,
obj.__defaults__, obj.__closure__),
obj.__dict__)
self.save_reduce(*rv, obj=obj)
dispatch[types.FunctionType] = save_function
srsly-release-v2.5.1/srsly/cloudpickle/compat.py 0000664 0000000 0000000 00000000774 14742310675 0022027 0 ustar 00root root 0000000 0000000 import sys
if sys.version_info < (3, 8):
try:
import pickle5 as pickle # noqa: F401
from pickle5 import Pickler # noqa: F401
except ImportError:
import pickle # noqa: F401
# Use the Python pickler for old CPython versions
from pickle import _Pickler as Pickler # noqa: F401
else:
import pickle # noqa: F401
# Pickler will be the C implementation in CPython and the Python
# implementation in PyPy
from pickle import Pickler # noqa: F401
srsly-release-v2.5.1/srsly/msgpack/ 0000775 0000000 0000000 00000000000 14742310675 0017311 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/msgpack/__init__.py 0000664 0000000 0000000 00000005324 14742310675 0021426 0 ustar 00root root 0000000 0000000 # coding: utf-8
import functools
import catalogue
# These need to be imported before packer and unpacker
from ._epoch import utc, epoch # noqa
from ._version import version
from .exceptions import *
# In msgpack-python these are put under a _cmsgpack module that textually includes
# them. I dislike this so I refactored it.
from ._packer import Packer as _Packer
from ._unpacker import unpackb as _unpackb
from ._unpacker import Unpacker as _Unpacker
from .ext import ExtType
from ._msgpack_numpy import encode_numpy as _encode_numpy
from ._msgpack_numpy import decode_numpy as _decode_numpy
msgpack_encoders = catalogue.create("srsly", "msgpack_encoders", entry_points=True)
msgpack_decoders = catalogue.create("srsly", "msgpack_decoders", entry_points=True)
msgpack_encoders.register("numpy", func=_encode_numpy)
msgpack_decoders.register("numpy", func=_decode_numpy)
# msgpack_numpy extensions
class Packer(_Packer):
def __init__(self, *args, **kwargs):
default = kwargs.get("default")
for encoder in msgpack_encoders.get_all().values():
default = functools.partial(encoder, chain=default)
kwargs["default"] = default
super(Packer, self).__init__(*args, **kwargs)
class Unpacker(_Unpacker):
def __init__(self, *args, **kwargs):
object_hook = kwargs.get("object_hook")
for decoder in msgpack_decoders.get_all().values():
object_hook = functools.partial(decoder, chain=object_hook)
kwargs["object_hook"] = object_hook
super(Unpacker, self).__init__(*args, **kwargs)
def pack(o, stream, **kwargs):
"""
Pack an object and write it to a stream.
"""
packer = Packer(**kwargs)
stream.write(packer.pack(o))
def packb(o, **kwargs):
"""
Pack an object and return the packed bytes.
"""
return Packer(**kwargs).pack(o)
def unpack(stream, **kwargs):
"""
Unpack a packed object from a stream.
"""
if "object_pairs_hook" not in kwargs:
object_hook = kwargs.get("object_hook")
for decoder in msgpack_decoders.get_all().values():
object_hook = functools.partial(decoder, chain=object_hook)
kwargs["object_hook"] = object_hook
data = stream.read()
return _unpackb(data, **kwargs)
def unpackb(packed, **kwargs):
"""
Unpack a packed object.
"""
if "object_pairs_hook" not in kwargs:
object_hook = kwargs.get("object_hook")
for decoder in msgpack_decoders.get_all().values():
object_hook = functools.partial(decoder, chain=object_hook)
kwargs["object_hook"] = object_hook
return _unpackb(packed, **kwargs)
# alias for compatibility to simplejson/marshal/pickle.
load = unpack
loads = unpackb
dump = pack
dumps = packb
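# Example (illustrative sketch): round-trip of a plain Python object with the
# module-level helpers; with the default options, keys come back as str.
#
#   >>> unpackb(packb({"answer": 42}))
#   {'answer': 42}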
srsly-release-v2.5.1/srsly/msgpack/_epoch.pyx 0000664 0000000 0000000 00000000261 14742310675 0021307 0 ustar 00root root 0000000 0000000 from cpython.datetime cimport import_datetime, datetime_new
import_datetime()
import datetime
utc = datetime.timezone.utc
epoch = datetime_new(1970, 1, 1, 0, 0, 0, 0, tz=utc)
srsly-release-v2.5.1/srsly/msgpack/_msgpack_numpy.py 0000664 0000000 0000000 00000005235 14742310675 0022704 0 ustar 00root root 0000000 0000000 #!/usr/bin/env python
"""
Support for serialization of numpy data types with msgpack.
"""
# Copyright (c) 2013-2018, Lev E. Givon
# All rights reserved.
# Distributed under the terms of the BSD license:
# http://www.opensource.org/licenses/bsd-license
try:
import numpy as np
has_numpy = True
except ImportError:
has_numpy = False
try:
import cupy
has_cupy = True
except ImportError:
has_cupy = False
def encode_numpy(obj, chain=None):
"""
Data encoder for serializing numpy data types.
"""
if not has_numpy:
return obj if chain is None else chain(obj)
if has_cupy and isinstance(obj, cupy.ndarray):
obj = obj.get()
if isinstance(obj, np.ndarray):
# If the dtype is structured, store the interface description;
# otherwise, store the corresponding array protocol type string:
if obj.dtype.kind == "V":
kind = b"V"
descr = obj.dtype.descr
else:
kind = b""
descr = obj.dtype.str
return {
b"nd": True,
b"type": descr,
b"kind": kind,
b"shape": obj.shape,
b"data": obj.data if obj.flags["C_CONTIGUOUS"] else obj.tobytes(),
}
elif isinstance(obj, (np.bool_, np.number)):
return {b"nd": False, b"type": obj.dtype.str, b"data": obj.data}
elif isinstance(obj, complex):
return {b"complex": True, b"data": obj.__repr__()}
else:
return obj if chain is None else chain(obj)
def tostr(x):
if isinstance(x, bytes):
return x.decode()
else:
return str(x)
def decode_numpy(obj, chain=None):
"""
Decoder for deserializing numpy data types.
"""
try:
if b"nd" in obj:
if obj[b"nd"] is True:
# Check if b'kind' is in obj to enable decoding of data
# serialized with older versions (#20):
if b"kind" in obj and obj[b"kind"] == b"V":
descr = [
tuple(tostr(t) if type(t) is bytes else t for t in d)
for d in obj[b"type"]
]
else:
descr = obj[b"type"]
return np.frombuffer(obj[b"data"], dtype=np.dtype(descr)).reshape(
obj[b"shape"]
)
else:
descr = obj[b"type"]
return np.frombuffer(obj[b"data"], dtype=np.dtype(descr))[0]
elif b"complex" in obj:
return complex(tostr(obj[b"data"]))
else:
return obj if chain is None else chain(obj)
except KeyError:
return obj if chain is None else chain(obj)
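# Example (illustrative sketch, requires numpy): once these hooks are
# registered (as srsly.msgpack.__init__ does by default), arrays survive a
# msgpack round-trip.
#
#   >>> import numpy as np
#   >>> from srsly.msgpack import packb, unpackb
#   >>> arr = np.arange(3, dtype="int64")
#   >>> np.array_equal(arr, unpackb(packb(arr)))
#   True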
srsly-release-v2.5.1/srsly/msgpack/_packer.pyx 0000664 0000000 0000000 00000033365 14742310675 0021471 0 ustar 00root root 0000000 0000000 # coding: utf-8
from cpython cimport *
from cpython.bytearray cimport PyByteArray_Check, PyByteArray_CheckExact
from cpython.datetime cimport (
PyDateTime_CheckExact, PyDelta_CheckExact,
datetime_tzinfo, timedelta_days, timedelta_seconds, timedelta_microseconds,
)
from ._epoch import utc, epoch
cdef ExtType
cdef Timestamp
from .ext import ExtType, Timestamp
from .util import ensure_bytes
cdef extern from "Python.h":
int PyMemoryView_Check(object obj)
cdef extern from "pack.h":
struct msgpack_packer:
char* buf
size_t length
size_t buf_size
bint use_bin_type
int msgpack_pack_nil(msgpack_packer* pk) except -1
int msgpack_pack_true(msgpack_packer* pk) except -1
int msgpack_pack_false(msgpack_packer* pk) except -1
int msgpack_pack_long_long(msgpack_packer* pk, long long d) except -1
int msgpack_pack_unsigned_long_long(msgpack_packer* pk, unsigned long long d) except -1
int msgpack_pack_float(msgpack_packer* pk, float d) except -1
int msgpack_pack_double(msgpack_packer* pk, double d) except -1
int msgpack_pack_array(msgpack_packer* pk, size_t l) except -1
int msgpack_pack_map(msgpack_packer* pk, size_t l) except -1
int msgpack_pack_raw(msgpack_packer* pk, size_t l) except -1
int msgpack_pack_bin(msgpack_packer* pk, size_t l) except -1
int msgpack_pack_raw_body(msgpack_packer* pk, char* body, size_t l) except -1
int msgpack_pack_ext(msgpack_packer* pk, char typecode, size_t l) except -1
int msgpack_pack_timestamp(msgpack_packer* x, long long seconds, unsigned long nanoseconds) except -1
cdef int DEFAULT_RECURSE_LIMIT=511
cdef long long ITEM_LIMIT = (2**32)-1
cdef inline int PyBytesLike_Check(object o):
return PyBytes_Check(o) or PyByteArray_Check(o)
cdef inline int PyBytesLike_CheckExact(object o):
return PyBytes_CheckExact(o) or PyByteArray_CheckExact(o)
cdef class Packer:
"""
MessagePack Packer
Usage::
packer = Packer()
astream.write(packer.pack(a))
astream.write(packer.pack(b))
Packer's constructor has some keyword arguments:
:param default:
When specified, it should be callable.
Convert user type to builtin type that Packer supports.
See also simplejson's documentation.
:param bool use_single_float:
Use single precision float type for float. (default: False)
:param bool autoreset:
Reset buffer after each pack and return its content as `bytes`. (default: True).
If set this to false, use `bytes()` to get content and `.reset()` to clear buffer.
:param bool use_bin_type:
Use bin type introduced in msgpack spec 2.0 for bytes.
It also enables str8 type for unicode. (default: True)
:param bool strict_types:
If set to true, types will be checked to be exact. Derived classes
from serializable types will not be serialized and will be
treated as unsupported types and forwarded to default.
Additionally tuples will not be serialized as lists.
This is useful when trying to implement accurate serialization
for python types.
:param bool datetime:
If set to true, datetime with tzinfo is packed into Timestamp type.
Note that the tzinfo is stripped in the timestamp.
You can get UTC datetime with `timestamp=3` option of the Unpacker.
:param str unicode_errors:
The error handler for encoding unicode. (default: 'strict')
DO NOT USE THIS!! This option is kept for very specific usage.
:param int buf_size:
The size of the internal buffer. (default: 256*1024)
Useful if serialisation size can be correctly estimated,
avoiding unnecessary reallocations.
"""
cdef msgpack_packer pk
cdef object _default
cdef size_t exports # number of exported buffers
cdef bint strict_types
cdef bint use_float
cdef bint autoreset
cdef bint datetime
cdef object _bencoding
cdef object _berrors
cdef const char *encoding
cdef const char *unicode_errors
def __cinit__(self, buf_size=256*1024, **_kwargs):
self.pk.buf = <char*> PyMem_Malloc(buf_size)
if self.pk.buf == NULL:
raise MemoryError("Unable to allocate internal buffer.")
self.pk.buf_size = buf_size
self.pk.length = 0
self.exports = 0
def __dealloc__(self):
PyMem_Free(self.pk.buf)
self.pk.buf = NULL
assert self.exports == 0
cdef _check_exports(self):
if self.exports > 0:
raise BufferError("Existing exports of data: Packer cannot be changed")
def __init__(self, *, default=None, encoding=None,
bint use_single_float=False, bint autoreset=True, bint use_bin_type=False,
bint strict_types=False, bint datetime=False, unicode_errors=None,
buf_size=256*1024):
self.use_float = use_single_float
self.strict_types = strict_types
self.autoreset = autoreset
self.datetime = datetime
self.pk.use_bin_type = use_bin_type
if default is not None:
if not PyCallable_Check(default):
raise TypeError("default must be a callable.")
self._default = default
if encoding is None:
if PY_MAJOR_VERSION < 3:
encoding = 'utf-8'
if encoding is None:
self._bencoding = None
self.encoding = NULL
else:
self._bencoding = ensure_bytes(encoding)
self.encoding = self._bencoding
else:
self._bencoding = ensure_bytes(encoding)
self.encoding = self._bencoding
unicode_errors = ensure_bytes(unicode_errors)
self._berrors = unicode_errors
if unicode_errors is None:
self.unicode_errors = NULL
else:
self.unicode_errors = self._berrors
# returns -2 when default(o) should be called
cdef int _pack_inner(self, object o, bint will_default, int nest_limit) except -1:
cdef long long llval
cdef unsigned long long ullval
cdef unsigned long ulval
cdef const char* rawval
cdef Py_ssize_t L
cdef Py_buffer view
cdef bint strict = self.strict_types
if o is None:
msgpack_pack_nil(&self.pk)
elif o is True:
msgpack_pack_true(&self.pk)
elif o is False:
msgpack_pack_false(&self.pk)
elif PyLong_CheckExact(o) if strict else PyLong_Check(o):
try:
if o > 0:
ullval = o
msgpack_pack_unsigned_long_long(&self.pk, ullval)
else:
llval = o
msgpack_pack_long_long(&self.pk, llval)
except OverflowError as oe:
if will_default:
return -2
else:
raise OverflowError("Integer value out of range")
elif PyFloat_CheckExact(o) if strict else PyFloat_Check(o):
if self.use_float:
msgpack_pack_float(&self.pk, o)
else:
msgpack_pack_double(&self.pk, o)
elif PyBytesLike_CheckExact(o) if strict else PyBytesLike_Check(o):
L = Py_SIZE(o)
if L > ITEM_LIMIT:
PyErr_Format(ValueError, b"%.200s object is too large", Py_TYPE(o).tp_name)
rawval = o
msgpack_pack_bin(&self.pk, L)
msgpack_pack_raw_body(&self.pk, rawval, L)
elif PyUnicode_CheckExact(o) if strict else PyUnicode_Check(o):
if self.unicode_errors == NULL and self.encoding == NULL:
rawval = PyUnicode_AsUTF8AndSize(o, &L)
if L > ITEM_LIMIT:
raise ValueError("unicode string is too large")
else:
o = PyUnicode_AsEncodedString(o, self.encoding, self.unicode_errors)
L = Py_SIZE(o)
if L > ITEM_LIMIT:
raise ValueError("unicode string is too large")
rawval = o
msgpack_pack_raw(&self.pk, L)
msgpack_pack_raw_body(&self.pk, rawval, L)
elif PyDict_CheckExact(o) if strict else PyDict_Check(o):
L = len(o)
if L > ITEM_LIMIT:
raise ValueError("dict is too large")
msgpack_pack_map(&self.pk, L)
for k, v in o.items():
self._pack(k, nest_limit)
self._pack(v, nest_limit)
elif type(o) is ExtType if strict else isinstance(o, ExtType):
# This should be before Tuple because ExtType is namedtuple.
rawval = o.data
L = len(o.data)
if L > ITEM_LIMIT:
raise ValueError("EXT data is too large")
msgpack_pack_ext(&self.pk, o.code, L)
msgpack_pack_raw_body(&self.pk, rawval, L)
elif type(o) is Timestamp:
llval = o.seconds
ulval = o.nanoseconds
msgpack_pack_timestamp(&self.pk, llval, ulval)
elif PyList_CheckExact(o) if strict else (PyTuple_Check(o) or PyList_Check(o)):
L = Py_SIZE(o)
if L > ITEM_LIMIT:
raise ValueError("list is too large")
msgpack_pack_array(&self.pk, L)
for v in o:
self._pack(v, nest_limit)
elif PyMemoryView_Check(o):
PyObject_GetBuffer(o, &view, PyBUF_SIMPLE)
L = view.len
if L > ITEM_LIMIT:
PyBuffer_Release(&view);
raise ValueError("memoryview is too large")
try:
msgpack_pack_bin(&self.pk, L)
msgpack_pack_raw_body(&self.pk, view.buf, L)
finally:
PyBuffer_Release(&view);
elif self.datetime and PyDateTime_CheckExact(o) and datetime_tzinfo(o) is not None:
delta = o - epoch
if not PyDelta_CheckExact(delta):
raise ValueError("failed to calculate delta")
llval = timedelta_days(delta) * (24*60*60) + timedelta_seconds(delta)
ulval = timedelta_microseconds(delta) * 1000
msgpack_pack_timestamp(&self.pk, llval, ulval)
elif will_default:
return -2
elif self.datetime and PyDateTime_CheckExact(o):
# this should be later than will_default
PyErr_Format(ValueError, b"can not serialize '%.200s' object where tzinfo=None", Py_TYPE(o).tp_name)
else:
PyErr_Format(TypeError, b"can not serialize '%.200s' object", Py_TYPE(o).tp_name)
cdef int _pack(self, object o, int nest_limit=DEFAULT_RECURSE_LIMIT) except -1:
cdef int ret
if nest_limit < 0:
raise ValueError("recursion limit exceeded.")
nest_limit -= 1
if self._default is not None:
ret = self._pack_inner(o, 1, nest_limit)
if ret == -2:
o = self._default(o)
else:
return ret
return self._pack_inner(o, 0, nest_limit)
def pack(self, object obj):
cdef int ret
self._check_exports()
try:
ret = self._pack(obj, DEFAULT_RECURSE_LIMIT)
except:
self.pk.length = 0
raise
if ret: # should not happen.
raise RuntimeError("internal error")
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0
return buf
def pack_ext_type(self, typecode, data):
self._check_exports()
if len(data) > ITEM_LIMIT:
raise ValueError("ext data too large")
msgpack_pack_ext(&self.pk, typecode, len(data))
msgpack_pack_raw_body(&self.pk, data, len(data))
def pack_array_header(self, long long size):
self._check_exports()
if size > ITEM_LIMIT:
raise ValueError("array too large")
msgpack_pack_array(&self.pk, size)
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0
return buf
def pack_map_header(self, long long size):
self._check_exports()
if size > ITEM_LIMIT:
raise ValueError("map too learge")
msgpack_pack_map(&self.pk, size)
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0
return buf
def pack_map_pairs(self, object pairs):
"""
Pack *pairs* as msgpack map type.
*pairs* should be a sequence of pairs.
(`len(pairs)` and `for k, v in pairs:` should be supported.)
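Example (an illustrative sketch)::
packer = Packer()
payload = packer.pack_map_pairs([("a", 1), ("b", 2)])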
"""
self._check_exports()
size = len(pairs)
if size > ITEM_LIMIT:
raise ValueError("map too large")
msgpack_pack_map(&self.pk, size)
for k, v in pairs:
self._pack(k)
self._pack(v)
if self.autoreset:
buf = PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
self.pk.length = 0
return buf
def reset(self):
"""Reset internal buffer.
This method is useful only when autoreset=False.
"""
self._check_exports()
self.pk.length = 0
def bytes(self):
"""Return internal buffer contents as bytes object"""
return PyBytes_FromStringAndSize(self.pk.buf, self.pk.length)
def getbuffer(self):
"""Return memoryview of internal buffer.
Note: Packer now supports buffer protocol. You can use memoryview(packer).
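Example (an illustrative sketch; ``out_stream`` stands for any
writable binary stream and is not part of this module)::
packer = Packer(autoreset=False)
packer.pack(b"chunk-1")
packer.pack(b"chunk-2")
out_stream.write(packer.getbuffer())  # zero-copy view of the buffer
# while the view is alive, the Packer cannot be modified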
"""
return memoryview(self)
def __getbuffer__(self, Py_buffer *buffer, int flags):
PyBuffer_FillInfo(buffer, self, self.pk.buf, self.pk.length, 1, flags)
self.exports += 1
def __releasebuffer__(self, Py_buffer *buffer):
self.exports -= 1
srsly-release-v2.5.1/srsly/msgpack/_unpacker.pyx 0000664 0000000 0000000 00000045562 14742310675 0022036 0 ustar 00root root 0000000 0000000 # coding: utf-8
from cpython cimport *
cdef extern from "Python.h":
ctypedef struct PyObject
object PyMemoryView_GetContiguous(object obj, int buffertype, char order)
from libc.stdlib cimport *
from libc.string cimport *
from libc.limits cimport *
from libc.stdint cimport uint64_t
from .exceptions import (
BufferFull,
OutOfData,
ExtraData,
FormatError,
StackError,
)
from .ext import ExtType, Timestamp
from .util import ensure_bytes
from ._epoch import utc, epoch
cdef object giga = 1_000_000_000
cdef extern from "unpack.h":
ctypedef struct msgpack_user:
bint use_list
bint raw
bint has_pairs_hook # call object_hook with k-v pairs
bint strict_map_key
int timestamp
PyObject* object_hook
PyObject* list_hook
PyObject* ext_hook
PyObject* timestamp_t
PyObject *giga;
PyObject *utc;
const char *unicode_errors
const char *encoding
Py_ssize_t max_str_len
Py_ssize_t max_bin_len
Py_ssize_t max_array_len
Py_ssize_t max_map_len
Py_ssize_t max_ext_len
ctypedef struct unpack_context:
msgpack_user user
PyObject* obj
Py_ssize_t count
ctypedef int (*execute_fn)(unpack_context* ctx, const char* data,
Py_ssize_t len, Py_ssize_t* off) except? -1
execute_fn unpack_construct
execute_fn unpack_skip
execute_fn read_array_header
execute_fn read_map_header
void unpack_init(unpack_context* ctx)
object unpack_data(unpack_context* ctx)
void unpack_clear(unpack_context* ctx)
cdef inline init_ctx(unpack_context *ctx,
object object_hook, object object_pairs_hook,
object list_hook, object ext_hook,
bint use_list, bint raw, int timestamp,
bint strict_map_key,
const char* encoding,
const char* unicode_errors,
Py_ssize_t max_str_len, Py_ssize_t max_bin_len,
Py_ssize_t max_array_len, Py_ssize_t max_map_len,
Py_ssize_t max_ext_len):
unpack_init(ctx)
ctx.user.use_list = use_list
ctx.user.raw = raw
ctx.user.strict_map_key = strict_map_key
ctx.user.object_hook = ctx.user.list_hook = NULL
ctx.user.max_str_len = max_str_len
ctx.user.max_bin_len = max_bin_len
ctx.user.max_array_len = max_array_len
ctx.user.max_map_len = max_map_len
ctx.user.max_ext_len = max_ext_len
if object_hook is not None and object_pairs_hook is not None:
raise TypeError("object_pairs_hook and object_hook are mutually exclusive.")
if object_hook is not None:
if not PyCallable_Check(object_hook):
raise TypeError("object_hook must be a callable.")
ctx.user.object_hook = object_hook
if object_pairs_hook is None:
ctx.user.has_pairs_hook = False
else:
if not PyCallable_Check(object_pairs_hook):
raise TypeError("object_pairs_hook must be a callable.")
ctx.user.object_hook = object_pairs_hook
ctx.user.has_pairs_hook = True
if list_hook is not None:
if not PyCallable_Check(list_hook):
raise TypeError("list_hook must be a callable.")
ctx.user.list_hook = list_hook
if ext_hook is not None:
if not PyCallable_Check(ext_hook):
raise TypeError("ext_hook must be a callable.")
ctx.user.ext_hook = ext_hook
if timestamp < 0 or 3 < timestamp:
raise ValueError("timestamp must be 0..3")
# Add Timestamp type to the user object so it may be used in unpack.h
ctx.user.timestamp = timestamp
ctx.user.timestamp_t = Timestamp
ctx.user.giga = giga
ctx.user.utc = utc
ctx.user.unicode_errors = unicode_errors
ctx.user.encoding = encoding
def default_read_extended_type(typecode, data):
raise NotImplementedError("Cannot decode extended type with typecode=%d" % typecode)
cdef inline int get_data_from_buffer(object obj,
Py_buffer *view,
char **buf,
Py_ssize_t *buffer_len) except 0:
cdef object contiguous
cdef Py_buffer tmp
if PyObject_GetBuffer(obj, view, PyBUF_FULL_RO) == -1:
raise
if view.itemsize != 1:
PyBuffer_Release(view)
raise BufferError("cannot unpack from multi-byte object")
if PyBuffer_IsContiguous(view, b'A') == 0:
PyBuffer_Release(view)
# create a contiguous copy and get buffer
contiguous = PyMemoryView_GetContiguous(obj, PyBUF_READ, b'C')
PyObject_GetBuffer(contiguous, view, PyBUF_SIMPLE)
# view must hold the only reference to contiguous,
# so memory is freed when view is released
Py_DECREF(contiguous)
buffer_len[0] = view.len
buf[0] = <char*>view.buf
return 1
def unpackb(object packed, *, object object_hook=None, object list_hook=None,
bint use_list=True, bint raw=True, int timestamp=0, bint strict_map_key=False,
encoding=None,
unicode_errors=None,
object_pairs_hook=None, ext_hook=ExtType,
Py_ssize_t max_str_len=-1,
Py_ssize_t max_bin_len=-1,
Py_ssize_t max_array_len=-1,
Py_ssize_t max_map_len=-1,
Py_ssize_t max_ext_len=-1):
"""
Unpack *packed* to a Python object. Returns the unpacked object.
Raises ``ExtraData`` when *packed* contains extra bytes.
Raises ``ValueError`` when *packed* is incomplete.
Raises ``FormatError`` when *packed* is not valid msgpack.
Raises ``StackError`` when *packed* contains data that is nested too deeply.
Other exceptions can be raised during unpacking.
See :class:`Unpacker` for options.
*max_xxx_len* options are configured automatically from ``len(packed)``.
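Example (an illustrative sketch; ``Packer`` here refers to the packer
shipped in this package)::
payload = Packer(use_bin_type=True).pack(["srsly", 2])
assert unpackb(payload, raw=False, use_list=True) == ["srsly", 2]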
"""
cdef unpack_context ctx
cdef Py_ssize_t off = 0
cdef int ret
cdef Py_buffer view
cdef char* buf = NULL
cdef Py_ssize_t buf_len
cdef const char* cenc = NULL
cdef const char* cerr = NULL
if encoding is not None:
encoding = ensure_bytes(encoding)
cenc = encoding
if unicode_errors is not None:
unicode_errors = ensure_bytes(unicode_errors)
cerr = unicode_errors
get_data_from_buffer(packed, &view, &buf, &buf_len)
if max_str_len == -1:
max_str_len = buf_len
if max_bin_len == -1:
max_bin_len = buf_len
if max_array_len == -1:
max_array_len = buf_len
if max_map_len == -1:
max_map_len = buf_len//2
if max_ext_len == -1:
max_ext_len = buf_len
try:
init_ctx(&ctx, object_hook, object_pairs_hook, list_hook, ext_hook,
use_list, raw, timestamp, strict_map_key, cenc, cerr,
max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len)
ret = unpack_construct(&ctx, buf, buf_len, &off)
finally:
PyBuffer_Release(&view);
if ret == 1:
obj = unpack_data(&ctx)
if off < buf_len:
raise ExtraData(obj, PyBytes_FromStringAndSize(buf+off, buf_len-off))
return obj
unpack_clear(&ctx)
if ret == 0:
raise ValueError("Unpack failed: incomplete input")
elif ret == -2:
raise FormatError
elif ret == -3:
raise StackError
raise ValueError("Unpack failed: error = %d" % (ret,))
cdef class Unpacker:
"""Streaming unpacker.
Arguments:
:param file_like:
File-like object having `.read(n)` method.
If specified, unpacker reads serialized data from it and `.feed()` is not usable.
:param int read_size:
Used as `file_like.read(read_size)`. (default: `min(16*1024, max_buffer_size)`)
:param bool use_list:
If true, unpack msgpack array to Python list.
Otherwise, unpack to Python tuple. (default: True)
:param bool raw:
If true, unpack msgpack raw to Python bytes.
Otherwise, unpack to Python str by decoding with UTF-8 encoding (default).
:param int timestamp:
Control how timestamp type is unpacked:
0 - Timestamp
1 - float (Seconds from the EPOCH)
2 - int (Nanoseconds from the EPOCH)
3 - datetime.datetime (UTC).
:param bool strict_map_key:
If true (default), only str or bytes are accepted for map (dict) keys.
:param object_hook:
When specified, it should be callable.
Unpacker calls it with a dict argument after unpacking msgpack map.
(See also simplejson)
:param object_pairs_hook:
When specified, it should be callable.
Unpacker calls it with a list of key-value pairs after unpacking msgpack map.
(See also simplejson)
:param str unicode_errors:
The error handler for decoding unicode. (default: 'strict')
This option should be used only when you have msgpack data which
contains invalid UTF-8 string.
:param int max_buffer_size:
Limits the size of data waiting to be unpacked. 0 means 2**32-1.
The default value is 100*1024*1024 (100MiB).
Raises `BufferFull` exception when it is insufficient.
You should set this parameter when unpacking data from an untrusted source.
:param int max_str_len:
Deprecated, use *max_buffer_size* instead.
Limits max length of str. (default: max_buffer_size)
:param int max_bin_len:
Deprecated, use *max_buffer_size* instead.
Limits max length of bin. (default: max_buffer_size)
:param int max_array_len:
Limits max length of array.
(default: max_buffer_size)
:param int max_map_len:
Limits max length of map.
(default: max_buffer_size//2)
:param int max_ext_len:
Deprecated, use *max_buffer_size* instead.
Limits max size of ext type. (default: max_buffer_size)
Example of streaming deserialize from file-like object::
unpacker = Unpacker(file_like)
for o in unpacker:
process(o)
Example of streaming deserialize from socket::
unpacker = Unpacker()
while True:
buf = sock.recv(1024**2)
if not buf:
break
unpacker.feed(buf)
for o in unpacker:
process(o)
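Example of round-tripping a timezone-aware datetime through the
Timestamp type (an illustrative sketch; assumes ``datetime`` is
imported and ``Packer`` comes from this package)::
packer = Packer(datetime=True)
unpacker = Unpacker(timestamp=3)
unpacker.feed(packer.pack(datetime.datetime.now(datetime.timezone.utc)))
dt = unpacker.unpack()  # datetime.datetime with tzinfo=UTC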
Raises ``ExtraData`` when *packed* contains extra bytes.
Raises ``OutOfData`` when *packed* is incomplete.
Raises ``FormatError`` when *packed* is not valid msgpack.
Raises ``StackError`` when *packed* contains data that is nested too deeply.
Other exceptions can be raised during unpacking.
"""
cdef unpack_context ctx
cdef char* buf
cdef Py_ssize_t buf_size, buf_head, buf_tail
cdef object file_like
cdef object file_like_read
cdef Py_ssize_t read_size
# To maintain refcnt.
cdef object object_hook, object_pairs_hook, list_hook, ext_hook
cdef object unicode_errors
cdef Py_ssize_t max_buffer_size
cdef uint64_t stream_offset
def __cinit__(self):
self.buf = NULL
def __dealloc__(self):
PyMem_Free(self.buf)
self.buf = NULL
def __init__(self, file_like=None, *, Py_ssize_t read_size=0,
bint use_list=True, bint raw=True, int timestamp=0, bint strict_map_key=False,
object object_hook=None, object object_pairs_hook=None, object list_hook=None,
unicode_errors=None, Py_ssize_t max_buffer_size=100*1024*1024,
object ext_hook=ExtType,
Py_ssize_t max_str_len=-1,
Py_ssize_t max_bin_len=-1,
Py_ssize_t max_array_len=-1,
Py_ssize_t max_map_len=-1,
Py_ssize_t max_ext_len=-1):
cdef const char* cerr = NULL
cdef const char* cenc = NULL
self.object_hook = object_hook
self.object_pairs_hook = object_pairs_hook
self.list_hook = list_hook
self.ext_hook = ext_hook
self.file_like = file_like
if file_like:
self.file_like_read = file_like.read
if not PyCallable_Check(self.file_like_read):
raise TypeError("`file_like.read` must be a callable.")
if not max_buffer_size:
max_buffer_size = INT_MAX
if max_str_len == -1:
max_str_len = max_buffer_size
if max_bin_len == -1:
max_bin_len = max_buffer_size
if max_array_len == -1:
max_array_len = max_buffer_size
if max_map_len == -1:
max_map_len = max_buffer_size//2
if max_ext_len == -1:
max_ext_len = max_buffer_size
if read_size > max_buffer_size:
raise ValueError("read_size should be less or equal to max_buffer_size")
if not read_size:
read_size = min(max_buffer_size, 1024**2)
self.max_buffer_size = max_buffer_size
self.read_size = read_size
self.buf = <char*> PyMem_Malloc(read_size)
if self.buf == NULL:
raise MemoryError("Unable to allocate internal buffer.")
self.buf_size = read_size
self.buf_head = 0
self.buf_tail = 0
self.stream_offset = 0
if unicode_errors is not None:
self.unicode_errors = unicode_errors
cerr = unicode_errors
init_ctx(&self.ctx, object_hook, object_pairs_hook, list_hook,
ext_hook, use_list, raw, timestamp, strict_map_key, cenc, cerr,
max_str_len, max_bin_len, max_array_len,
max_map_len, max_ext_len)
def feed(self, object next_bytes):
"""Append `next_bytes` to internal buffer."""
cdef Py_buffer pybuff
cdef char* buf
cdef Py_ssize_t buf_len
if self.file_like is not None:
raise AssertionError(
"unpacker.feed() cannot be used together with `file_like`.")
get_data_from_buffer(next_bytes, &pybuff, &buf, &buf_len)
try:
self.append_buffer(buf, buf_len)
finally:
PyBuffer_Release(&pybuff)
cdef append_buffer(self, void* _buf, Py_ssize_t _buf_len):
cdef:
char* buf = self.buf
char* new_buf
Py_ssize_t head = self.buf_head
Py_ssize_t tail = self.buf_tail
Py_ssize_t buf_size = self.buf_size
Py_ssize_t new_size
if tail + _buf_len > buf_size:
if ((tail - head) + _buf_len) <= buf_size:
# move to front.
memmove(buf, buf + head, tail - head)
tail -= head
head = 0
else:
# expand buffer.
new_size = (tail-head) + _buf_len
if new_size > self.max_buffer_size:
raise BufferFull
new_size = min(new_size*2, self.max_buffer_size)
new_buf = <char*> PyMem_Malloc(new_size)
if new_buf == NULL:
# self.buf still holds old buffer and will be freed during
# obj destruction
raise MemoryError("Unable to enlarge internal buffer.")
memcpy(new_buf, buf + head, tail - head)
PyMem_Free(buf)
buf = new_buf
buf_size = new_size
tail -= head
head = 0
memcpy(buf + tail, <char*>(_buf), _buf_len)
self.buf = buf
self.buf_head = head
self.buf_size = buf_size
self.buf_tail = tail + _buf_len
cdef int read_from_file(self) except -1:
cdef Py_ssize_t remains = self.max_buffer_size - (self.buf_tail - self.buf_head)
if remains <= 0:
raise BufferFull
next_bytes = self.file_like_read(min(self.read_size, remains))
if next_bytes:
self.append_buffer(PyBytes_AsString(next_bytes), PyBytes_Size(next_bytes))
else:
self.file_like = None
return 0
cdef object _unpack(self, execute_fn execute, bint iter=0):
cdef int ret
cdef object obj
cdef Py_ssize_t prev_head
while 1:
prev_head = self.buf_head
if prev_head < self.buf_tail:
ret = execute(&self.ctx, self.buf, self.buf_tail, &self.buf_head)
self.stream_offset += self.buf_head - prev_head
else:
ret = 0
if ret == 1:
obj = unpack_data(&self.ctx)
unpack_init(&self.ctx)
return obj
elif ret == 0:
if self.file_like is not None:
self.read_from_file()
continue
if iter:
raise StopIteration("No more data to unpack.")
else:
raise OutOfData("No more data to unpack.")
elif ret == -2:
raise FormatError
elif ret == -3:
raise StackError
else:
raise ValueError("Unpack failed: error = %d" % (ret,))
def read_bytes(self, Py_ssize_t nbytes):
"""Read a specified number of raw bytes from the stream"""
cdef Py_ssize_t nread
nread = min(self.buf_tail - self.buf_head, nbytes)
ret = PyBytes_FromStringAndSize(self.buf + self.buf_head, nread)
self.buf_head += nread
if nread < nbytes and self.file_like is not None:
ret += self.file_like.read(nbytes - nread)
nread = len(ret)
self.stream_offset += nread
return ret
def unpack(self):
"""Unpack one object
Raises `OutOfData` when there are no more bytes to unpack.
"""
return self._unpack(unpack_construct)
def skip(self):
"""Read and ignore one object, returning None
Raises `OutOfData` when there are no more bytes to unpack.
"""
return self._unpack(unpack_skip)
def read_array_header(self):
"""assuming the next object is an array, return its size n, such that
the next n unpack() calls will iterate over its contents.
Raises `OutOfData` when there are no more bytes to unpack.
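Example (an illustrative sketch; ``Packer`` comes from this package)::
unpacker = Unpacker()
unpacker.feed(Packer().pack([1, 2, 3]))
n = unpacker.read_array_header()              # n == 3
items = [unpacker.unpack() for _ in range(n)]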
"""
return self._unpack(read_array_header)
def read_map_header(self):
"""assuming the next object is a map, return its size n, such that the
next n * 2 unpack() calls will iterate over its key-value pairs.
Raises `OutOfData` when there are no more bytes to unpack.
"""
return self._unpack(read_map_header)
def tell(self):
"""Returns the current position of the Unpacker in bytes, i.e., the
number of bytes that were read from the input, which is also the
starting position of the next object.
"""
return self.stream_offset
def __iter__(self):
return self
def __next__(self):
return self._unpack(unpack_construct, 1)
# for debug.
#def _buf(self):
# return PyString_FromStringAndSize(self.buf, self.buf_tail)
#def _off(self):
# return self.buf_head
srsly-release-v2.5.1/srsly/msgpack/_version.py 0000664 0000000 0000000 00000000024 14742310675 0021503 0 ustar 00root root 0000000 0000000 version = (1, 1, 0)
srsly-release-v2.5.1/srsly/msgpack/exceptions.py 0000664 0000000 0000000 00000002071 14742310675 0022044 0 ustar 00root root 0000000 0000000 class UnpackException(Exception):
"""Base class for some exceptions raised while unpacking.
NOTE: unpack may raise exceptions other than subclasses of
UnpackException. If you want to catch all errors, catch
Exception instead.
"""
class BufferFull(UnpackException):
pass
class OutOfData(UnpackException):
pass
class FormatError(ValueError, UnpackException):
"""Invalid msgpack format"""
class StackError(ValueError, UnpackException):
"""Too nested"""
# Deprecated. Use ValueError instead
UnpackValueError = ValueError
class ExtraData(UnpackValueError):
"""ExtraData is raised when there is trailing data.
This exception is raised only by one-shot (not streaming)
unpacking.
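Example (an illustrative sketch; ``payload`` stands for some already
msgpack-encoded bytes and ``unpackb`` for the one-shot unpacker in
this package)::
try:
obj = unpackb(payload + b"\x01")
except ExtraData as err:
obj = err.unpacked   # the successfully decoded object
rest = err.extra     # the trailing, undecoded bytes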
"""
def __init__(self, unpacked, extra):
self.unpacked = unpacked
self.extra = extra
def __str__(self):
return "unpack(b) received extra data."
# Deprecated. Use Exception instead to catch all exception during packing.
PackException = Exception
PackValueError = ValueError
PackOverflowError = OverflowError
srsly-release-v2.5.1/srsly/msgpack/ext.py 0000664 0000000 0000000 00000013136 14742310675 0020467 0 ustar 00root root 0000000 0000000 import datetime
import struct
from collections import namedtuple
class ExtType(namedtuple("ExtType", "code data")):
"""ExtType represents ext type in msgpack."""
def __new__(cls, code, data):
if not isinstance(code, int):
raise TypeError("code must be int")
if not isinstance(data, bytes):
raise TypeError("data must be bytes")
if not 0 <= code <= 127:
raise ValueError("code must be 0~127")
return super().__new__(cls, code, data)
class Timestamp:
"""Timestamp represents the Timestamp extension type in msgpack.
When built with Cython, msgpack uses C methods to pack and unpack `Timestamp`.
When using pure-Python msgpack, :func:`to_bytes` and :func:`from_bytes` are used to pack and
unpack `Timestamp`.
This class is immutable: Do not override seconds and nanoseconds.
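Example (an illustrative sketch; the numeric results are approximate
where floating point is involved)::
ts = Timestamp(seconds=1700000000, nanoseconds=500)
ts.to_unix()       # ~1700000000.0000005
ts.to_unix_nano()  # 1700000000000000500
ts.to_datetime()   # timezone-aware datetime.datetime in UTC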
"""
__slots__ = ["seconds", "nanoseconds"]
def __init__(self, seconds, nanoseconds=0):
"""Initialize a Timestamp object.
:param int seconds:
Number of seconds since the UNIX epoch (00:00:00 UTC Jan 1 1970, minus leap seconds).
May be negative.
:param int nanoseconds:
Number of nanoseconds to add to `seconds` to get fractional time.
Maximum is 999_999_999. Default is 0.
Note: Negative times (before the UNIX epoch) are represented as neg. seconds + pos. ns.
"""
if not isinstance(seconds, int):
raise TypeError("seconds must be an integer")
if not isinstance(nanoseconds, int):
raise TypeError("nanoseconds must be an integer")
if not (0 <= nanoseconds < 10**9):
raise ValueError("nanoseconds must be a non-negative integer less than 999999999.")
self.seconds = seconds
self.nanoseconds = nanoseconds
def __repr__(self):
"""String representation of Timestamp."""
return f"Timestamp(seconds={self.seconds}, nanoseconds={self.nanoseconds})"
def __eq__(self, other):
"""Check for equality with another Timestamp object"""
if type(other) is self.__class__:
return self.seconds == other.seconds and self.nanoseconds == other.nanoseconds
return False
def __ne__(self, other):
"""not-equals method (see :func:`__eq__()`)"""
return not self.__eq__(other)
def __hash__(self):
return hash((self.seconds, self.nanoseconds))
@staticmethod
def from_bytes(b):
"""Unpack bytes into a `Timestamp` object.
Used for pure-Python msgpack unpacking.
:param b: Payload from msgpack ext message with code -1
:type b: bytes
:returns: Timestamp object unpacked from msgpack ext payload
:rtype: Timestamp
"""
if len(b) == 4:
seconds = struct.unpack("!L", b)[0]
nanoseconds = 0
elif len(b) == 8:
data64 = struct.unpack("!Q", b)[0]
seconds = data64 & 0x00000003FFFFFFFF
nanoseconds = data64 >> 34
elif len(b) == 12:
nanoseconds, seconds = struct.unpack("!Iq", b)
else:
raise ValueError(
"Timestamp type can only be created from 32, 64, or 96-bit byte objects"
)
return Timestamp(seconds, nanoseconds)
def to_bytes(self):
"""Pack this Timestamp object into bytes.
Used for pure-Python msgpack packing.
:returns data: Payload for EXT message with code -1 (timestamp type)
:rtype: bytes
"""
if (self.seconds >> 34) == 0: # seconds is non-negative and fits in 34 bits
data64 = self.nanoseconds << 34 | self.seconds
if data64 & 0xFFFFFFFF00000000 == 0:
# nanoseconds is zero and seconds < 2**32, so timestamp 32
data = struct.pack("!L", data64)
else:
# timestamp 64
data = struct.pack("!Q", data64)
else:
# timestamp 96
data = struct.pack("!Iq", self.nanoseconds, self.seconds)
return data
@staticmethod
def from_unix(unix_sec):
"""Create a Timestamp from posix timestamp in seconds.
:param unix_sec: Posix timestamp in seconds.
:type unix_sec: int or float
"""
seconds = int(unix_sec // 1)
nanoseconds = int((unix_sec % 1) * 10**9)
return Timestamp(seconds, nanoseconds)
def to_unix(self):
"""Get the timestamp as a floating-point value.
:returns: posix timestamp
:rtype: float
"""
return self.seconds + self.nanoseconds / 1e9
@staticmethod
def from_unix_nano(unix_ns):
"""Create a Timestamp from posix timestamp in nanoseconds.
:param int unix_ns: Posix timestamp in nanoseconds.
:rtype: Timestamp
"""
return Timestamp(*divmod(unix_ns, 10**9))
def to_unix_nano(self):
"""Get the timestamp as a unixtime in nanoseconds.
:returns: posix timestamp in nanoseconds
:rtype: int
"""
return self.seconds * 10**9 + self.nanoseconds
def to_datetime(self):
"""Get the timestamp as a UTC datetime.
:rtype: `datetime.datetime`
"""
utc = datetime.timezone.utc
return datetime.datetime.fromtimestamp(0, utc) + datetime.timedelta(
seconds=self.seconds, microseconds=self.nanoseconds // 1000
)
@staticmethod
def from_datetime(dt):
"""Create a Timestamp from datetime with tzinfo.
:rtype: Timestamp
"""
return Timestamp(seconds=int(dt.timestamp()), nanoseconds=dt.microsecond * 1000)
srsly-release-v2.5.1/srsly/msgpack/fallback.py 0000664 0000000 0000000 00000077206 14742310675 0021436 0 ustar 00root root 0000000 0000000 """Fallback pure Python implementation of msgpack"""
import struct
import sys
from datetime import datetime as _DateTime
if hasattr(sys, "pypy_version_info"):
from __pypy__ import newlist_hint
from __pypy__.builders import BytesBuilder
_USING_STRINGBUILDER = True
class BytesIO:
def __init__(self, s=b""):
if s:
self.builder = BytesBuilder(len(s))
self.builder.append(s)
else:
self.builder = BytesBuilder()
def write(self, s):
if isinstance(s, memoryview):
s = s.tobytes()
elif isinstance(s, bytearray):
s = bytes(s)
self.builder.append(s)
def getvalue(self):
return self.builder.build()
else:
from io import BytesIO
_USING_STRINGBUILDER = False
def newlist_hint(size):
return []
from .exceptions import BufferFull, ExtraData, FormatError, OutOfData, StackError
from .ext import ExtType, Timestamp
EX_SKIP = 0
EX_CONSTRUCT = 1
EX_READ_ARRAY_HEADER = 2
EX_READ_MAP_HEADER = 3
TYPE_IMMEDIATE = 0
TYPE_ARRAY = 1
TYPE_MAP = 2
TYPE_RAW = 3
TYPE_BIN = 4
TYPE_EXT = 5
DEFAULT_RECURSE_LIMIT = 511
def _check_type_strict(obj, t, type=type, tuple=tuple):
if type(t) is tuple:
return type(obj) in t
else:
return type(obj) is t
def _get_data_from_buffer(obj):
view = memoryview(obj)
if view.itemsize != 1:
raise ValueError("cannot unpack from multi-byte object")
return view
def unpackb(packed, **kwargs):
"""
Unpack an object from `packed`.
Raises ``ExtraData`` when *packed* contains extra bytes.
Raises ``ValueError`` when *packed* is incomplete.
Raises ``FormatError`` when *packed* is not valid msgpack.
Raises ``StackError`` when *packed* contains data that is nested too deeply.
Other exceptions can be raised during unpacking.
See :class:`Unpacker` for options.
"""
unpacker = Unpacker(None, max_buffer_size=len(packed), **kwargs)
unpacker.feed(packed)
try:
ret = unpacker._unpack()
except OutOfData:
raise ValueError("Unpack failed: incomplete input")
except RecursionError:
raise StackError
if unpacker._got_extradata():
raise ExtraData(ret, unpacker._get_extradata())
return ret
_NO_FORMAT_USED = ""
_MSGPACK_HEADERS = {
0xC4: (1, _NO_FORMAT_USED, TYPE_BIN),
0xC5: (2, ">H", TYPE_BIN),
0xC6: (4, ">I", TYPE_BIN),
0xC7: (2, "Bb", TYPE_EXT),
0xC8: (3, ">Hb", TYPE_EXT),
0xC9: (5, ">Ib", TYPE_EXT),
0xCA: (4, ">f"),
0xCB: (8, ">d"),
0xCC: (1, _NO_FORMAT_USED),
0xCD: (2, ">H"),
0xCE: (4, ">I"),
0xCF: (8, ">Q"),
0xD0: (1, "b"),
0xD1: (2, ">h"),
0xD2: (4, ">i"),
0xD3: (8, ">q"),
0xD4: (1, "b1s", TYPE_EXT),
0xD5: (2, "b2s", TYPE_EXT),
0xD6: (4, "b4s", TYPE_EXT),
0xD7: (8, "b8s", TYPE_EXT),
0xD8: (16, "b16s", TYPE_EXT),
0xD9: (1, _NO_FORMAT_USED, TYPE_RAW),
0xDA: (2, ">H", TYPE_RAW),
0xDB: (4, ">I", TYPE_RAW),
0xDC: (2, ">H", TYPE_ARRAY),
0xDD: (4, ">I", TYPE_ARRAY),
0xDE: (2, ">H", TYPE_MAP),
0xDF: (4, ">I", TYPE_MAP),
}
class Unpacker:
"""Streaming unpacker.
Arguments:
:param file_like:
File-like object having `.read(n)` method.
If specified, unpacker reads serialized data from it and `.feed()` is not usable.
:param int read_size:
Used as `file_like.read(read_size)`. (default: `min(16*1024, max_buffer_size)`)
:param bool use_list:
If true, unpack msgpack array to Python list.
Otherwise, unpack to Python tuple. (default: True)
:param bool raw:
If true, unpack msgpack raw to Python bytes.
Otherwise, unpack to Python str by decoding with UTF-8 encoding (default).
:param int timestamp:
Control how timestamp type is unpacked:
0 - Timestamp
1 - float (Seconds from the EPOCH)
2 - int (Nanoseconds from the EPOCH)
3 - datetime.datetime (UTC).
:param bool strict_map_key:
If true (default), only str or bytes are accepted for map (dict) keys.
:param object_hook:
When specified, it should be callable.
Unpacker calls it with a dict argument after unpacking msgpack map.
(See also simplejson)
:param object_pairs_hook:
When specified, it should be callable.
Unpacker calls it with a list of key-value pairs after unpacking msgpack map.
(See also simplejson)
:param str unicode_errors:
The error handler for decoding unicode. (default: 'strict')
This option should be used only when you have msgpack data which
contains invalid UTF-8 string.
:param int max_buffer_size:
Limits the size of data waiting to be unpacked. 0 means 2**32-1.
The default value is 100*1024*1024 (100MiB).
Raises `BufferFull` exception when it is insufficient.
You should set this parameter when unpacking data from an untrusted source.
:param int max_str_len:
Deprecated, use *max_buffer_size* instead.
Limits max length of str. (default: max_buffer_size)
:param int max_bin_len:
Deprecated, use *max_buffer_size* instead.
Limits max length of bin. (default: max_buffer_size)
:param int max_array_len:
Limits max length of array.
(default: max_buffer_size)
:param int max_map_len:
Limits max length of map.
(default: max_buffer_size//2)
:param int max_ext_len:
Deprecated, use *max_buffer_size* instead.
Limits max size of ext type. (default: max_buffer_size)
Example of streaming deserialize from file-like object::
unpacker = Unpacker(file_like)
for o in unpacker:
process(o)
Example of streaming deserialize from socket::
unpacker = Unpacker()
while True:
buf = sock.recv(1024**2)
if not buf:
break
unpacker.feed(buf)
for o in unpacker:
process(o)
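Example of bounding memory use for untrusted input (an illustrative
sketch; ``untrusted_bytes`` is a placeholder for data received from
elsewhere)::
unpacker = Unpacker(max_buffer_size=1024 * 1024)
try:
unpacker.feed(untrusted_bytes)
except BufferFull:
pass  # input exceeded the 1 MiB limit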
Raises ``ExtraData`` when *packed* contains extra bytes.
Raises ``OutOfData`` when *packed* is incomplete.
Raises ``FormatError`` when *packed* is not valid msgpack.
Raises ``StackError`` when *packed* contains data that is nested too deeply.
Other exceptions can be raised during unpacking.
"""
def __init__(
self,
file_like=None,
*,
read_size=0,
use_list=True,
raw=False,
timestamp=0,
strict_map_key=True,
object_hook=None,
object_pairs_hook=None,
list_hook=None,
unicode_errors=None,
max_buffer_size=100 * 1024 * 1024,
ext_hook=ExtType,
max_str_len=-1,
max_bin_len=-1,
max_array_len=-1,
max_map_len=-1,
max_ext_len=-1,
):
if unicode_errors is None:
unicode_errors = "strict"
if file_like is None:
self._feeding = True
else:
if not callable(file_like.read):
raise TypeError("`file_like.read` must be callable")
self.file_like = file_like
self._feeding = False
#: array of bytes fed.
self._buffer = bytearray()
#: The position we are currently reading from.
self._buff_i = 0
# When Unpacker is used as an iterable, between the calls to next(),
# the buffer is not "consumed" completely, for efficiency sake.
# Instead, it is done sloppily. To make sure we raise BufferFull at
# the correct moments, we have to keep track of how sloppy we were.
# Furthermore, when the buffer is incomplete (that is: in the case
# we raise an OutOfData) we need to rollback the buffer to the correct
# state, which _buf_checkpoint records.
self._buf_checkpoint = 0
if not max_buffer_size:
max_buffer_size = 2**31 - 1
if max_str_len == -1:
max_str_len = max_buffer_size
if max_bin_len == -1:
max_bin_len = max_buffer_size
if max_array_len == -1:
max_array_len = max_buffer_size
if max_map_len == -1:
max_map_len = max_buffer_size // 2
if max_ext_len == -1:
max_ext_len = max_buffer_size
self._max_buffer_size = max_buffer_size
if read_size > self._max_buffer_size:
raise ValueError("read_size must be smaller than max_buffer_size")
self._read_size = read_size or min(self._max_buffer_size, 16 * 1024)
self._raw = bool(raw)
self._strict_map_key = bool(strict_map_key)
self._unicode_errors = unicode_errors
self._use_list = use_list
if not (0 <= timestamp <= 3):
raise ValueError("timestamp must be 0..3")
self._timestamp = timestamp
self._list_hook = list_hook
self._object_hook = object_hook
self._object_pairs_hook = object_pairs_hook
self._ext_hook = ext_hook
self._max_str_len = max_str_len
self._max_bin_len = max_bin_len
self._max_array_len = max_array_len
self._max_map_len = max_map_len
self._max_ext_len = max_ext_len
self._stream_offset = 0
if list_hook is not None and not callable(list_hook):
raise TypeError("`list_hook` is not callable")
if object_hook is not None and not callable(object_hook):
raise TypeError("`object_hook` is not callable")
if object_pairs_hook is not None and not callable(object_pairs_hook):
raise TypeError("`object_pairs_hook` is not callable")
if object_hook is not None and object_pairs_hook is not None:
raise TypeError("object_pairs_hook and object_hook are mutually exclusive")
if not callable(ext_hook):
raise TypeError("`ext_hook` is not callable")
def feed(self, next_bytes):
assert self._feeding
view = _get_data_from_buffer(next_bytes)
if len(self._buffer) - self._buff_i + len(view) > self._max_buffer_size:
raise BufferFull
# Strip buffer before checkpoint before reading file.
if self._buf_checkpoint > 0:
del self._buffer[: self._buf_checkpoint]
self._buff_i -= self._buf_checkpoint
self._buf_checkpoint = 0
# Use extend here: INPLACE_ADD += doesn't reliably typecast memoryview in jython
self._buffer.extend(view)
view.release()
def _consume(self):
"""Gets rid of the used parts of the buffer."""
self._stream_offset += self._buff_i - self._buf_checkpoint
self._buf_checkpoint = self._buff_i
def _got_extradata(self):
return self._buff_i < len(self._buffer)
def _get_extradata(self):
return self._buffer[self._buff_i :]
def read_bytes(self, n):
ret = self._read(n, raise_outofdata=False)
self._consume()
return ret
def _read(self, n, raise_outofdata=True):
# (int) -> bytearray
self._reserve(n, raise_outofdata=raise_outofdata)
i = self._buff_i
ret = self._buffer[i : i + n]
self._buff_i = i + len(ret)
return ret
def _reserve(self, n, raise_outofdata=True):
remain_bytes = len(self._buffer) - self._buff_i - n
# Fast path: buffer has n bytes already
if remain_bytes >= 0:
return
if self._feeding:
self._buff_i = self._buf_checkpoint
raise OutOfData
# Strip buffer before checkpoint before reading file.
if self._buf_checkpoint > 0:
del self._buffer[: self._buf_checkpoint]
self._buff_i -= self._buf_checkpoint
self._buf_checkpoint = 0
# Read from file
remain_bytes = -remain_bytes
if remain_bytes + len(self._buffer) > self._max_buffer_size:
raise BufferFull
while remain_bytes > 0:
to_read_bytes = max(self._read_size, remain_bytes)
read_data = self.file_like.read(to_read_bytes)
if not read_data:
break
assert isinstance(read_data, bytes)
self._buffer += read_data
remain_bytes -= len(read_data)
if len(self._buffer) < n + self._buff_i and raise_outofdata:
self._buff_i = 0 # rollback
raise OutOfData
def _read_header(self):
typ = TYPE_IMMEDIATE
n = 0
obj = None
self._reserve(1)
b = self._buffer[self._buff_i]
self._buff_i += 1
if b & 0b10000000 == 0:
obj = b
elif b & 0b11100000 == 0b11100000:
obj = -1 - (b ^ 0xFF)
elif b & 0b11100000 == 0b10100000:
n = b & 0b00011111
typ = TYPE_RAW
if n > self._max_str_len:
raise ValueError(f"{n} exceeds max_str_len({self._max_str_len})")
obj = self._read(n)
elif b & 0b11110000 == 0b10010000:
n = b & 0b00001111
typ = TYPE_ARRAY
if n > self._max_array_len:
raise ValueError(f"{n} exceeds max_array_len({self._max_array_len})")
elif b & 0b11110000 == 0b10000000:
n = b & 0b00001111
typ = TYPE_MAP
if n > self._max_map_len:
raise ValueError(f"{n} exceeds max_map_len({self._max_map_len})")
elif b == 0xC0:
obj = None
elif b == 0xC2:
obj = False
elif b == 0xC3:
obj = True
elif 0xC4 <= b <= 0xC6:
size, fmt, typ = _MSGPACK_HEADERS[b]
self._reserve(size)
if len(fmt) > 0:
n = struct.unpack_from(fmt, self._buffer, self._buff_i)[0]
else:
n = self._buffer[self._buff_i]
self._buff_i += size
if n > self._max_bin_len:
raise ValueError(f"{n} exceeds max_bin_len({self._max_bin_len})")
obj = self._read(n)
elif 0xC7 <= b <= 0xC9:
size, fmt, typ = _MSGPACK_HEADERS[b]
self._reserve(size)
L, n = struct.unpack_from(fmt, self._buffer, self._buff_i)
self._buff_i += size
if L > self._max_ext_len:
raise ValueError(f"{L} exceeds max_ext_len({self._max_ext_len})")
obj = self._read(L)
elif 0xCA <= b <= 0xD3:
size, fmt = _MSGPACK_HEADERS[b]
self._reserve(size)
if len(fmt) > 0:
obj = struct.unpack_from(fmt, self._buffer, self._buff_i)[0]
else:
obj = self._buffer[self._buff_i]
self._buff_i += size
elif 0xD4 <= b <= 0xD8:
size, fmt, typ = _MSGPACK_HEADERS[b]
if self._max_ext_len < size:
raise ValueError(f"{size} exceeds max_ext_len({self._max_ext_len})")
self._reserve(size + 1)
n, obj = struct.unpack_from(fmt, self._buffer, self._buff_i)
self._buff_i += size + 1
elif 0xD9 <= b <= 0xDB:
size, fmt, typ = _MSGPACK_HEADERS[b]
self._reserve(size)
if len(fmt) > 0:
(n,) = struct.unpack_from(fmt, self._buffer, self._buff_i)
else:
n = self._buffer[self._buff_i]
self._buff_i += size
if n > self._max_str_len:
raise ValueError(f"{n} exceeds max_str_len({self._max_str_len})")
obj = self._read(n)
elif 0xDC <= b <= 0xDD:
size, fmt, typ = _MSGPACK_HEADERS[b]
self._reserve(size)
(n,) = struct.unpack_from(fmt, self._buffer, self._buff_i)
self._buff_i += size
if n > self._max_array_len:
raise ValueError(f"{n} exceeds max_array_len({self._max_array_len})")
elif 0xDE <= b <= 0xDF:
size, fmt, typ = _MSGPACK_HEADERS[b]
self._reserve(size)
(n,) = struct.unpack_from(fmt, self._buffer, self._buff_i)
self._buff_i += size
if n > self._max_map_len:
raise ValueError(f"{n} exceeds max_map_len({self._max_map_len})")
else:
raise FormatError("Unknown header: 0x%x" % b)
return typ, n, obj
def _unpack(self, execute=EX_CONSTRUCT):
typ, n, obj = self._read_header()
if execute == EX_READ_ARRAY_HEADER:
if typ != TYPE_ARRAY:
raise ValueError("Expected array")
return n
if execute == EX_READ_MAP_HEADER:
if typ != TYPE_MAP:
raise ValueError("Expected map")
return n
# TODO should we eliminate the recursion?
if typ == TYPE_ARRAY:
if execute == EX_SKIP:
for i in range(n):
# TODO check whether we need to call `list_hook`
self._unpack(EX_SKIP)
return
ret = newlist_hint(n)
for i in range(n):
ret.append(self._unpack(EX_CONSTRUCT))
if self._list_hook is not None:
ret = self._list_hook(ret)
# TODO is the interaction between `list_hook` and `use_list` ok?
return ret if self._use_list else tuple(ret)
if typ == TYPE_MAP:
if execute == EX_SKIP:
for i in range(n):
# TODO check whether we need to call hooks
self._unpack(EX_SKIP)
self._unpack(EX_SKIP)
return
if self._object_pairs_hook is not None:
ret = self._object_pairs_hook(
(self._unpack(EX_CONSTRUCT), self._unpack(EX_CONSTRUCT)) for _ in range(n)
)
else:
ret = {}
for _ in range(n):
key = self._unpack(EX_CONSTRUCT)
if self._strict_map_key and type(key) not in (str, bytes):
raise ValueError("%s is not allowed for map key" % str(type(key)))
if isinstance(key, str):
key = sys.intern(key)
ret[key] = self._unpack(EX_CONSTRUCT)
if self._object_hook is not None:
ret = self._object_hook(ret)
return ret
if execute == EX_SKIP:
return
if typ == TYPE_RAW:
if self._raw:
obj = bytes(obj)
else:
obj = obj.decode("utf_8", self._unicode_errors)
return obj
if typ == TYPE_BIN:
return bytes(obj)
if typ == TYPE_EXT:
if n == -1: # timestamp
ts = Timestamp.from_bytes(bytes(obj))
if self._timestamp == 1:
return ts.to_unix()
elif self._timestamp == 2:
return ts.to_unix_nano()
elif self._timestamp == 3:
return ts.to_datetime()
else:
return ts
else:
return self._ext_hook(n, bytes(obj))
assert typ == TYPE_IMMEDIATE
return obj
def __iter__(self):
return self
def __next__(self):
try:
ret = self._unpack(EX_CONSTRUCT)
self._consume()
return ret
except OutOfData:
self._consume()
raise StopIteration
except RecursionError:
raise StackError
next = __next__
def skip(self):
self._unpack(EX_SKIP)
self._consume()
def unpack(self):
try:
ret = self._unpack(EX_CONSTRUCT)
except RecursionError:
raise StackError
self._consume()
return ret
def read_array_header(self):
ret = self._unpack(EX_READ_ARRAY_HEADER)
self._consume()
return ret
def read_map_header(self):
ret = self._unpack(EX_READ_MAP_HEADER)
self._consume()
return ret
def tell(self):
return self._stream_offset
class Packer:
"""
MessagePack Packer
Usage::
packer = Packer()
astream.write(packer.pack(a))
astream.write(packer.pack(b))
Packer's constructor has some keyword arguments:
:param default:
When specified, it should be callable.
Convert user type to builtin type that Packer supports.
See also simplejson's document.
:param bool use_single_float:
Use single precision float type for float. (default: False)
:param bool autoreset:
Reset buffer after each pack and return its content as `bytes`. (default: True).
If set this to false, use `bytes()` to get content and `.reset()` to clear buffer.
:param bool use_bin_type:
Use bin type introduced in msgpack spec 2.0 for bytes.
It also enables str8 type for unicode. (default: True)
:param bool strict_types:
If set to true, types will be checked to be exact. Derived classes
of serializable types will not be serialized and will be
treated as an unsupported type and forwarded to default.
Additionally tuples will not be serialized as lists.
This is useful when trying to implement accurate serialization
for python types.
:param bool datetime:
If set to true, datetime with tzinfo is packed into Timestamp type.
Note that the tzinfo is stripped in the timestamp.
You can get UTC datetime with `timestamp=3` option of the Unpacker.
:param str unicode_errors:
The error handler for encoding unicode. (default: 'strict')
DO NOT USE THIS!! This option is kept for very specific usage.
:param int buf_size:
Internal buffer size. This option is used only by the C implementation.
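Example of using ``default`` to serialize otherwise unsupported types
(an illustrative sketch; here a set is converted to a list first)::
packer = Packer(default=lambda o: sorted(o))
payload = packer.pack({"tags": {"NOUN", "VERB"}})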
"""
def __init__(
self,
*,
default=None,
use_single_float=False,
autoreset=True,
use_bin_type=True,
strict_types=False,
datetime=False,
unicode_errors=None,
buf_size=None,
):
self._strict_types = strict_types
self._use_float = use_single_float
self._autoreset = autoreset
self._use_bin_type = use_bin_type
self._buffer = BytesIO()
self._datetime = bool(datetime)
self._unicode_errors = unicode_errors or "strict"
if default is not None and not callable(default):
raise TypeError("default must be callable")
self._default = default
def _pack(
self,
obj,
nest_limit=DEFAULT_RECURSE_LIMIT,
check=isinstance,
check_type_strict=_check_type_strict,
):
default_used = False
if self._strict_types:
check = check_type_strict
list_types = list
else:
list_types = (list, tuple)
while True:
if nest_limit < 0:
raise ValueError("recursion limit exceeded")
if obj is None:
return self._buffer.write(b"\xc0")
if check(obj, bool):
if obj:
return self._buffer.write(b"\xc3")
return self._buffer.write(b"\xc2")
if check(obj, int):
if 0 <= obj < 0x80:
return self._buffer.write(struct.pack("B", obj))
if -0x20 <= obj < 0:
return self._buffer.write(struct.pack("b", obj))
if 0x80 <= obj <= 0xFF:
return self._buffer.write(struct.pack("BB", 0xCC, obj))
if -0x80 <= obj < 0:
return self._buffer.write(struct.pack(">Bb", 0xD0, obj))
if 0xFF < obj <= 0xFFFF:
return self._buffer.write(struct.pack(">BH", 0xCD, obj))
if -0x8000 <= obj < -0x80:
return self._buffer.write(struct.pack(">Bh", 0xD1, obj))
if 0xFFFF < obj <= 0xFFFFFFFF:
return self._buffer.write(struct.pack(">BI", 0xCE, obj))
if -0x80000000 <= obj < -0x8000:
return self._buffer.write(struct.pack(">Bi", 0xD2, obj))
if 0xFFFFFFFF < obj <= 0xFFFFFFFFFFFFFFFF:
return self._buffer.write(struct.pack(">BQ", 0xCF, obj))
if -0x8000000000000000 <= obj < -0x80000000:
return self._buffer.write(struct.pack(">Bq", 0xD3, obj))
if not default_used and self._default is not None:
obj = self._default(obj)
default_used = True
continue
raise OverflowError("Integer value out of range")
if check(obj, (bytes, bytearray)):
n = len(obj)
if n >= 2**32:
raise ValueError("%s is too large" % type(obj).__name__)
self._pack_bin_header(n)
return self._buffer.write(obj)
if check(obj, str):
obj = obj.encode("utf-8", self._unicode_errors)
n = len(obj)
if n >= 2**32:
raise ValueError("String is too large")
self._pack_raw_header(n)
return self._buffer.write(obj)
if check(obj, memoryview):
n = obj.nbytes
if n >= 2**32:
raise ValueError("Memoryview is too large")
self._pack_bin_header(n)
return self._buffer.write(obj)
if check(obj, float):
if self._use_float:
return self._buffer.write(struct.pack(">Bf", 0xCA, obj))
return self._buffer.write(struct.pack(">Bd", 0xCB, obj))
if check(obj, (ExtType, Timestamp)):
if check(obj, Timestamp):
code = -1
data = obj.to_bytes()
else:
code = obj.code
data = obj.data
assert isinstance(code, int)
assert isinstance(data, bytes)
L = len(data)
if L == 1:
self._buffer.write(b"\xd4")
elif L == 2:
self._buffer.write(b"\xd5")
elif L == 4:
self._buffer.write(b"\xd6")
elif L == 8:
self._buffer.write(b"\xd7")
elif L == 16:
self._buffer.write(b"\xd8")
elif L <= 0xFF:
self._buffer.write(struct.pack(">BB", 0xC7, L))
elif L <= 0xFFFF:
self._buffer.write(struct.pack(">BH", 0xC8, L))
else:
self._buffer.write(struct.pack(">BI", 0xC9, L))
self._buffer.write(struct.pack("b", code))
self._buffer.write(data)
return
if check(obj, list_types):
n = len(obj)
self._pack_array_header(n)
for i in range(n):
self._pack(obj[i], nest_limit - 1)
return
if check(obj, dict):
return self._pack_map_pairs(len(obj), obj.items(), nest_limit - 1)
if self._datetime and check(obj, _DateTime) and obj.tzinfo is not None:
obj = Timestamp.from_datetime(obj)
default_used = 1
continue
if not default_used and self._default is not None:
obj = self._default(obj)
default_used = 1
continue
if self._datetime and check(obj, _DateTime):
raise ValueError(f"Cannot serialize {obj!r} where tzinfo=None")
raise TypeError(f"Cannot serialize {obj!r}")
def pack(self, obj):
try:
self._pack(obj)
except:
self._buffer = BytesIO() # force reset
raise
if self._autoreset:
ret = self._buffer.getvalue()
self._buffer = BytesIO()
return ret
def pack_map_pairs(self, pairs):
self._pack_map_pairs(len(pairs), pairs)
if self._autoreset:
ret = self._buffer.getvalue()
self._buffer = BytesIO()
return ret
def pack_array_header(self, n):
if n >= 2**32:
raise ValueError
self._pack_array_header(n)
if self._autoreset:
ret = self._buffer.getvalue()
self._buffer = BytesIO()
return ret
def pack_map_header(self, n):
if n >= 2**32:
raise ValueError
self._pack_map_header(n)
if self._autoreset:
ret = self._buffer.getvalue()
self._buffer = BytesIO()
return ret
def pack_ext_type(self, typecode, data):
if not isinstance(typecode, int):
raise TypeError("typecode must have int type.")
if not 0 <= typecode <= 127:
raise ValueError("typecode should be 0-127")
if not isinstance(data, bytes):
raise TypeError("data must have bytes type")
L = len(data)
if L > 0xFFFFFFFF:
raise ValueError("Too large data")
if L == 1:
self._buffer.write(b"\xd4")
elif L == 2:
self._buffer.write(b"\xd5")
elif L == 4:
self._buffer.write(b"\xd6")
elif L == 8:
self._buffer.write(b"\xd7")
elif L == 16:
self._buffer.write(b"\xd8")
elif L <= 0xFF:
self._buffer.write(b"\xc7" + struct.pack("B", L))
elif L <= 0xFFFF:
self._buffer.write(b"\xc8" + struct.pack(">H", L))
else:
self._buffer.write(b"\xc9" + struct.pack(">I", L))
self._buffer.write(struct.pack("B", typecode))
self._buffer.write(data)
def _pack_array_header(self, n):
if n <= 0x0F:
return self._buffer.write(struct.pack("B", 0x90 + n))
if n <= 0xFFFF:
return self._buffer.write(struct.pack(">BH", 0xDC, n))
if n <= 0xFFFFFFFF:
return self._buffer.write(struct.pack(">BI", 0xDD, n))
raise ValueError("Array is too large")
def _pack_map_header(self, n):
if n <= 0x0F:
return self._buffer.write(struct.pack("B", 0x80 + n))
if n <= 0xFFFF:
return self._buffer.write(struct.pack(">BH", 0xDE, n))
if n <= 0xFFFFFFFF:
return self._buffer.write(struct.pack(">BI", 0xDF, n))
raise ValueError("Dict is too large")
def _pack_map_pairs(self, n, pairs, nest_limit=DEFAULT_RECURSE_LIMIT):
self._pack_map_header(n)
for k, v in pairs:
self._pack(k, nest_limit - 1)
self._pack(v, nest_limit - 1)
def _pack_raw_header(self, n):
if n <= 0x1F:
self._buffer.write(struct.pack("B", 0xA0 + n))
elif self._use_bin_type and n <= 0xFF:
self._buffer.write(struct.pack(">BB", 0xD9, n))
elif n <= 0xFFFF:
self._buffer.write(struct.pack(">BH", 0xDA, n))
elif n <= 0xFFFFFFFF:
self._buffer.write(struct.pack(">BI", 0xDB, n))
else:
raise ValueError("Raw is too large")
def _pack_bin_header(self, n):
if not self._use_bin_type:
return self._pack_raw_header(n)
elif n <= 0xFF:
return self._buffer.write(struct.pack(">BB", 0xC4, n))
elif n <= 0xFFFF:
return self._buffer.write(struct.pack(">BH", 0xC5, n))
elif n <= 0xFFFFFFFF:
return self._buffer.write(struct.pack(">BI", 0xC6, n))
else:
raise ValueError("Bin is too large")
def bytes(self):
"""Return internal buffer contents as bytes object"""
return self._buffer.getvalue()
def reset(self):
"""Reset internal buffer.
This method is useful only when autoreset=False.
"""
self._buffer = BytesIO()
def getbuffer(self):
"""Return view of internal buffer."""
if _USING_STRINGBUILDER:
return memoryview(self.bytes())
else:
return self._buffer.getbuffer()
srsly-release-v2.5.1/srsly/msgpack/pack.h 0000664 0000000 0000000 00000003165 14742310675 0020405 0 ustar 00root root 0000000 0000000 /*
* MessagePack for Python packing routine
*
* Copyright (C) 2009 Naoki INADA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stddef.h>
#include <stdlib.h>
#include "sysdep.h"
#include <limits.h>
#include <string.h>
#include <stdbool.h>
#ifdef __cplusplus
extern "C" {
#endif
typedef struct msgpack_packer {
char *buf;
size_t length;
size_t buf_size;
bool use_bin_type;
} msgpack_packer;
typedef struct Packer Packer;
static inline int msgpack_pack_write(msgpack_packer* pk, const char *data, size_t l)
{
char* buf = pk->buf;
size_t bs = pk->buf_size;
size_t len = pk->length;
if (len + l > bs) {
bs = (len + l) * 2;
buf = (char*)PyMem_Realloc(buf, bs);
if (!buf) {
PyErr_NoMemory();
return -1;
}
}
memcpy(buf + len, data, l);
len += l;
pk->buf = buf;
pk->buf_size = bs;
pk->length = len;
return 0;
}
#define msgpack_pack_append_buffer(user, buf, len) \
return msgpack_pack_write(user, (const char*)buf, len)
#include "pack_template.h"
#ifdef __cplusplus
}
#endif
srsly-release-v2.5.1/srsly/msgpack/pack_template.h 0000664 0000000 0000000 00000040572 14742310675 0022303 0 ustar 00root root 0000000 0000000 /*
* MessagePack packing routine template
*
* Copyright (C) 2008-2010 FURUHASHI Sadayuki
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#if defined(__LITTLE_ENDIAN__)
#define TAKE8_8(d) ((uint8_t*)&d)[0]
#define TAKE8_16(d) ((uint8_t*)&d)[0]
#define TAKE8_32(d) ((uint8_t*)&d)[0]
#define TAKE8_64(d) ((uint8_t*)&d)[0]
#elif defined(__BIG_ENDIAN__)
#define TAKE8_8(d) ((uint8_t*)&d)[0]
#define TAKE8_16(d) ((uint8_t*)&d)[1]
#define TAKE8_32(d) ((uint8_t*)&d)[3]
#define TAKE8_64(d) ((uint8_t*)&d)[7]
#endif
#ifndef msgpack_pack_append_buffer
#error msgpack_pack_append_buffer callback is not defined
#endif
/*
* Integer
*/
#define msgpack_pack_real_uint16(x, d) \
do { \
if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_16(d), 1); \
} else if(d < (1<<8)) { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_16(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} else { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} \
} while(0)
#define msgpack_pack_real_uint32(x, d) \
do { \
if(d < (1<<8)) { \
if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_32(d), 1); \
} else { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_32(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} else { \
if(d < (1<<16)) { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* unsigned 32 */ \
unsigned char buf[5]; \
buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} \
} \
} while(0)
#define msgpack_pack_real_uint64(x, d) \
do { \
if(d < (1ULL<<8)) { \
if(d < (1ULL<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_64(d), 1); \
} else { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_64(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} else { \
if(d < (1ULL<<16)) { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else if(d < (1ULL<<32)) { \
/* unsigned 32 */ \
unsigned char buf[5]; \
buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} else { \
/* unsigned 64 */ \
unsigned char buf[9]; \
buf[0] = 0xcf; _msgpack_store64(&buf[1], d); \
msgpack_pack_append_buffer(x, buf, 9); \
} \
} \
} while(0)
#define msgpack_pack_real_int16(x, d) \
do { \
if(d < -(1<<5)) { \
if(d < -(1<<7)) { \
/* signed 16 */ \
unsigned char buf[3]; \
buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* signed 8 */ \
unsigned char buf[2] = {0xd0, TAKE8_16(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} else if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_16(d), 1); \
} else { \
if(d < (1<<8)) { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_16(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} else { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} \
} \
} while(0)
#define msgpack_pack_real_int32(x, d) \
do { \
if(d < -(1<<5)) { \
if(d < -(1<<15)) { \
/* signed 32 */ \
unsigned char buf[5]; \
buf[0] = 0xd2; _msgpack_store32(&buf[1], (int32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} else if(d < -(1<<7)) { \
/* signed 16 */ \
unsigned char buf[3]; \
buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* signed 8 */ \
unsigned char buf[2] = {0xd0, TAKE8_32(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} else if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_32(d), 1); \
} else { \
if(d < (1<<8)) { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_32(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} else if(d < (1<<16)) { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* unsigned 32 */ \
unsigned char buf[5]; \
buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} \
} \
} while(0)
#define msgpack_pack_real_int64(x, d) \
do { \
if(d < -(1LL<<5)) { \
if(d < -(1LL<<15)) { \
if(d < -(1LL<<31)) { \
/* signed 64 */ \
unsigned char buf[9]; \
buf[0] = 0xd3; _msgpack_store64(&buf[1], d); \
msgpack_pack_append_buffer(x, buf, 9); \
} else { \
/* signed 32 */ \
unsigned char buf[5]; \
buf[0] = 0xd2; _msgpack_store32(&buf[1], (int32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} \
} else { \
if(d < -(1<<7)) { \
/* signed 16 */ \
unsigned char buf[3]; \
buf[0] = 0xd1; _msgpack_store16(&buf[1], (int16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} else { \
/* signed 8 */ \
unsigned char buf[2] = {0xd0, TAKE8_64(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} \
} \
} else if(d < (1<<7)) { \
/* fixnum */ \
msgpack_pack_append_buffer(x, &TAKE8_64(d), 1); \
} else { \
if(d < (1LL<<16)) { \
if(d < (1<<8)) { \
/* unsigned 8 */ \
unsigned char buf[2] = {0xcc, TAKE8_64(d)}; \
msgpack_pack_append_buffer(x, buf, 2); \
} else { \
/* unsigned 16 */ \
unsigned char buf[3]; \
buf[0] = 0xcd; _msgpack_store16(&buf[1], (uint16_t)d); \
msgpack_pack_append_buffer(x, buf, 3); \
} \
} else { \
if(d < (1LL<<32)) { \
/* unsigned 32 */ \
unsigned char buf[5]; \
buf[0] = 0xce; _msgpack_store32(&buf[1], (uint32_t)d); \
msgpack_pack_append_buffer(x, buf, 5); \
} else { \
/* unsigned 64 */ \
unsigned char buf[9]; \
buf[0] = 0xcf; _msgpack_store64(&buf[1], d); \
msgpack_pack_append_buffer(x, buf, 9); \
} \
} \
} \
} while(0)
static inline int msgpack_pack_short(msgpack_packer* x, short d)
{
#if defined(SIZEOF_SHORT)
#if SIZEOF_SHORT == 2
msgpack_pack_real_int16(x, d);
#elif SIZEOF_SHORT == 4
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#elif defined(SHRT_MAX)
#if SHRT_MAX == 0x7fff
msgpack_pack_real_int16(x, d);
#elif SHRT_MAX == 0x7fffffff
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#else
if(sizeof(short) == 2) {
msgpack_pack_real_int16(x, d);
} else if(sizeof(short) == 4) {
msgpack_pack_real_int32(x, d);
} else {
msgpack_pack_real_int64(x, d);
}
#endif
}
static inline int msgpack_pack_int(msgpack_packer* x, int d)
{
#if defined(SIZEOF_INT)
#if SIZEOF_INT == 2
msgpack_pack_real_int16(x, d);
#elif SIZEOF_INT == 4
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#elif defined(INT_MAX)
#if INT_MAX == 0x7fff
msgpack_pack_real_int16(x, d);
#elif INT_MAX == 0x7fffffff
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#else
if(sizeof(int) == 2) {
msgpack_pack_real_int16(x, d);
} else if(sizeof(int) == 4) {
msgpack_pack_real_int32(x, d);
} else {
msgpack_pack_real_int64(x, d);
}
#endif
}
static inline int msgpack_pack_long(msgpack_packer* x, long d)
{
#if defined(SIZEOF_LONG)
#if SIZEOF_LONG == 4
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#elif defined(LONG_MAX)
#if LONG_MAX == 0x7fffffffL
msgpack_pack_real_int32(x, d);
#else
msgpack_pack_real_int64(x, d);
#endif
#else
if (sizeof(long) == 4) {
msgpack_pack_real_int32(x, d);
} else {
msgpack_pack_real_int64(x, d);
}
#endif
}
static inline int msgpack_pack_long_long(msgpack_packer* x, long long d)
{
msgpack_pack_real_int64(x, d);
}
static inline int msgpack_pack_unsigned_long_long(msgpack_packer* x, unsigned long long d)
{
msgpack_pack_real_uint64(x, d);
}
/*
* Float
*/
static inline int msgpack_pack_float(msgpack_packer* x, float d)
{
unsigned char buf[5];
buf[0] = 0xca;
#if PY_VERSION_HEX >= 0x030B00A7
PyFloat_Pack4(d, (char *)&buf[1], 0);
#else
_PyFloat_Pack4(d, &buf[1], 0);
#endif
msgpack_pack_append_buffer(x, buf, 5);
}
static inline int msgpack_pack_double(msgpack_packer* x, double d)
{
unsigned char buf[9];
buf[0] = 0xcb;
#if PY_VERSION_HEX >= 0x030B00A7
PyFloat_Pack8(d, (char *)&buf[1], 0);
#else
_PyFloat_Pack8(d, &buf[1], 0);
#endif
msgpack_pack_append_buffer(x, buf, 9);
}
/*
* Nil
*/
static inline int msgpack_pack_nil(msgpack_packer* x)
{
static const unsigned char d = 0xc0;
msgpack_pack_append_buffer(x, &d, 1);
}
/*
* Boolean
*/
static inline int msgpack_pack_true(msgpack_packer* x)
{
static const unsigned char d = 0xc3;
msgpack_pack_append_buffer(x, &d, 1);
}
static inline int msgpack_pack_false(msgpack_packer* x)
{
static const unsigned char d = 0xc2;
msgpack_pack_append_buffer(x, &d, 1);
}
/*
* Array
*/
static inline int msgpack_pack_array(msgpack_packer* x, unsigned int n)
{
if(n < 16) {
unsigned char d = 0x90 | n;
msgpack_pack_append_buffer(x, &d, 1);
} else if(n < 65536) {
unsigned char buf[3];
buf[0] = 0xdc; _msgpack_store16(&buf[1], (uint16_t)n);
msgpack_pack_append_buffer(x, buf, 3);
} else {
unsigned char buf[5];
buf[0] = 0xdd; _msgpack_store32(&buf[1], (uint32_t)n);
msgpack_pack_append_buffer(x, buf, 5);
}
}
/*
* Map
*/
static inline int msgpack_pack_map(msgpack_packer* x, unsigned int n)
{
if(n < 16) {
unsigned char d = 0x80 | n;
msgpack_pack_append_buffer(x, &TAKE8_8(d), 1);
} else if(n < 65536) {
unsigned char buf[3];
buf[0] = 0xde; _msgpack_store16(&buf[1], (uint16_t)n);
msgpack_pack_append_buffer(x, buf, 3);
} else {
unsigned char buf[5];
buf[0] = 0xdf; _msgpack_store32(&buf[1], (uint32_t)n);
msgpack_pack_append_buffer(x, buf, 5);
}
}
/*
* Raw
*/
static inline int msgpack_pack_raw(msgpack_packer* x, size_t l)
{
if (l < 32) {
unsigned char d = 0xa0 | (uint8_t)l;
msgpack_pack_append_buffer(x, &TAKE8_8(d), 1);
} else if (x->use_bin_type && l < 256) { // str8 is new format introduced with bin.
unsigned char buf[2] = {0xd9, (uint8_t)l};
msgpack_pack_append_buffer(x, buf, 2);
} else if (l < 65536) {
unsigned char buf[3];
buf[0] = 0xda; _msgpack_store16(&buf[1], (uint16_t)l);
msgpack_pack_append_buffer(x, buf, 3);
} else {
unsigned char buf[5];
buf[0] = 0xdb; _msgpack_store32(&buf[1], (uint32_t)l);
msgpack_pack_append_buffer(x, buf, 5);
}
}
/*
* bin
*/
static inline int msgpack_pack_bin(msgpack_packer *x, size_t l)
{
if (!x->use_bin_type) {
return msgpack_pack_raw(x, l);
}
if (l < 256) {
unsigned char buf[2] = {0xc4, (unsigned char)l};
msgpack_pack_append_buffer(x, buf, 2);
} else if (l < 65536) {
unsigned char buf[3] = {0xc5};
_msgpack_store16(&buf[1], (uint16_t)l);
msgpack_pack_append_buffer(x, buf, 3);
} else {
unsigned char buf[5] = {0xc6};
_msgpack_store32(&buf[1], (uint32_t)l);
msgpack_pack_append_buffer(x, buf, 5);
}
}
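/*
 * For reference, the leading type bytes chosen above follow the msgpack spec:
 *   raw/str: fixstr 0xa0|len (len < 32), str8 0xd9 (len < 256, only when
 *            use_bin_type is set), str16 0xda, str32 0xdb
 *   bin:     bin8 0xc4 (len < 256), bin16 0xc5, bin32 0xc6
 * When use_bin_type is false, msgpack_pack_bin falls back to the raw/str
 * family, matching the older msgpack format that had no separate bin type.
 */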
static inline int msgpack_pack_raw_body(msgpack_packer* x, const void* b, size_t l)
{
if (l > 0) msgpack_pack_append_buffer(x, (const unsigned char*)b, l);
return 0;
}
/*
* Ext
*/
static inline int msgpack_pack_ext(msgpack_packer* x, char typecode, size_t l)
{
if (l == 1) {
unsigned char buf[2];
buf[0] = 0xd4;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l == 2) {
unsigned char buf[2];
buf[0] = 0xd5;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l == 4) {
unsigned char buf[2];
buf[0] = 0xd6;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l == 8) {
unsigned char buf[2];
buf[0] = 0xd7;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l == 16) {
unsigned char buf[2];
buf[0] = 0xd8;
buf[1] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 2);
}
else if(l < 256) {
unsigned char buf[3];
buf[0] = 0xc7;
buf[1] = l;
buf[2] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 3);
} else if(l < 65536) {
unsigned char buf[4];
buf[0] = 0xc8;
_msgpack_store16(&buf[1], (uint16_t)l);
buf[3] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 4);
} else {
unsigned char buf[6];
buf[0] = 0xc9;
_msgpack_store32(&buf[1], (uint32_t)l);
buf[5] = (unsigned char)typecode;
msgpack_pack_append_buffer(x, buf, 6);
}
}
/*
* Pack Timestamp extension type. Follows msgpack-c pack_template.h.
*/
static inline int msgpack_pack_timestamp(msgpack_packer* x, int64_t seconds, uint32_t nanoseconds)
{
if ((seconds >> 34) == 0) {
/* seconds is unsigned and fits in 34 bits */
uint64_t data64 = ((uint64_t)nanoseconds << 34) | (uint64_t)seconds;
if ((data64 & 0xffffffff00000000L) == 0) {
/* no nanoseconds and seconds is 32bits or smaller. timestamp32. */
unsigned char buf[4];
uint32_t data32 = (uint32_t)data64;
msgpack_pack_ext(x, -1, 4);
_msgpack_store32(buf, data32);
msgpack_pack_raw_body(x, buf, 4);
} else {
/* timestamp64 */
unsigned char buf[8];
msgpack_pack_ext(x, -1, 8);
_msgpack_store64(buf, data64);
msgpack_pack_raw_body(x, buf, 8);
}
} else {
/* seconds is signed or >34bits */
unsigned char buf[12];
_msgpack_store32(&buf[0], nanoseconds);
_msgpack_store64(&buf[4], seconds);
msgpack_pack_ext(x, -1, 12);
msgpack_pack_raw_body(x, buf, 12);
}
return 0;
}
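/*
 * For reference, the three branches above produce the msgpack timestamp
 * extension (ext type -1):
 *   timestamp32: 4-byte payload, unsigned seconds only (nanoseconds == 0)
 *   timestamp64: 8-byte payload, 30-bit nanoseconds | 34-bit seconds
 *   timestamp96: 12-byte payload, 4-byte nanoseconds + 8-byte signed seconds
 * E.g. seconds=1, nanoseconds=0 packs as the fixext4 sequence d6 ff 00 00 00 01.
 */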
#undef msgpack_pack_append_buffer
#undef TAKE8_8
#undef TAKE8_16
#undef TAKE8_32
#undef TAKE8_64
#undef msgpack_pack_real_uint16
#undef msgpack_pack_real_uint32
#undef msgpack_pack_real_uint64
#undef msgpack_pack_real_int16
#undef msgpack_pack_real_int32
#undef msgpack_pack_real_int64
srsly-release-v2.5.1/srsly/msgpack/sysdep.h 0000664 0000000 0000000 00000014464 14742310675 0021002 0 ustar 00root root 0000000 0000000 /*
* MessagePack system dependencies
*
* Copyright (C) 2008-2010 FURUHASHI Sadayuki
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MSGPACK_SYSDEP_H__
#define MSGPACK_SYSDEP_H__
#include <stddef.h>
#include <string.h>
#if defined(_MSC_VER) && _MSC_VER < 1600
typedef __int8 int8_t;
typedef unsigned __int8 uint8_t;
typedef __int16 int16_t;
typedef unsigned __int16 uint16_t;
typedef __int32 int32_t;
typedef unsigned __int32 uint32_t;
typedef __int64 int64_t;
typedef unsigned __int64 uint64_t;
#elif defined(_MSC_VER) // && _MSC_VER >= 1600
#include <stdint.h>
#else
#include <stdint.h>
#include <stdbool.h>
#endif
#ifdef _WIN32
#define _msgpack_atomic_counter_header <windows.h>
typedef long _msgpack_atomic_counter_t;
#define _msgpack_sync_decr_and_fetch(ptr) InterlockedDecrement(ptr)
#define _msgpack_sync_incr_and_fetch(ptr) InterlockedIncrement(ptr)
#elif defined(__GNUC__) && ((__GNUC__*10 + __GNUC_MINOR__) < 41)
#define _msgpack_atomic_counter_header "gcc_atomic.h"
#else
typedef unsigned int _msgpack_atomic_counter_t;
#define _msgpack_sync_decr_and_fetch(ptr) __sync_sub_and_fetch(ptr, 1)
#define _msgpack_sync_incr_and_fetch(ptr) __sync_add_and_fetch(ptr, 1)
#endif
#ifdef _WIN32
#ifdef __cplusplus
/* numeric_limits::min,max */
#ifdef max
#undef max
#endif
#ifdef min
#undef min
#endif
#endif
#else /* _WIN32 */
#include <arpa/inet.h> /* ntohs, ntohl */
#endif
#if !defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__)
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
#define __LITTLE_ENDIAN__
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define __BIG_ENDIAN__
#elif _WIN32
#define __LITTLE_ENDIAN__
#endif
#endif
#ifdef __LITTLE_ENDIAN__
#ifdef _WIN32
# if defined(ntohs)
# define _msgpack_be16(x) ntohs(x)
# elif defined(_byteswap_ushort) || (defined(_MSC_VER) && _MSC_VER >= 1400)
# define _msgpack_be16(x) ((uint16_t)_byteswap_ushort((unsigned short)x))
# else
# define _msgpack_be16(x) ( \
((((uint16_t)x) << 8) ) | \
((((uint16_t)x) >> 8) ) )
# endif
#else
# define _msgpack_be16(x) ntohs(x)
#endif
#ifdef _WIN32
# if defined(ntohl)
# define _msgpack_be32(x) ntohl(x)
# elif defined(_byteswap_ulong) || defined(_MSC_VER)
# define _msgpack_be32(x) ((uint32_t)_byteswap_ulong((unsigned long)x))
# else
# define _msgpack_be32(x) \
( ((((uint32_t)x) << 24) ) | \
((((uint32_t)x) << 8) & 0x00ff0000U ) | \
((((uint32_t)x) >> 8) & 0x0000ff00U ) | \
((((uint32_t)x) >> 24) ) )
# endif
#else
# define _msgpack_be32(x) ntohl(x)
#endif
#if defined(_byteswap_uint64) || defined(_MSC_VER)
# define _msgpack_be64(x) (_byteswap_uint64(x))
#elif defined(bswap_64)
# define _msgpack_be64(x) bswap_64(x)
#elif defined(__DARWIN_OSSwapInt64)
# define _msgpack_be64(x) __DARWIN_OSSwapInt64(x)
#else
#define _msgpack_be64(x) \
( ((((uint64_t)x) << 56) ) | \
((((uint64_t)x) << 40) & 0x00ff000000000000ULL ) | \
((((uint64_t)x) << 24) & 0x0000ff0000000000ULL ) | \
((((uint64_t)x) << 8) & 0x000000ff00000000ULL ) | \
((((uint64_t)x) >> 8) & 0x00000000ff000000ULL ) | \
((((uint64_t)x) >> 24) & 0x0000000000ff0000ULL ) | \
((((uint64_t)x) >> 40) & 0x000000000000ff00ULL ) | \
((((uint64_t)x) >> 56) ) )
#endif
#define _msgpack_load16(cast, from) ((cast)( \
(((uint16_t)((uint8_t*)(from))[0]) << 8) | \
(((uint16_t)((uint8_t*)(from))[1]) ) ))
#define _msgpack_load32(cast, from) ((cast)( \
(((uint32_t)((uint8_t*)(from))[0]) << 24) | \
(((uint32_t)((uint8_t*)(from))[1]) << 16) | \
(((uint32_t)((uint8_t*)(from))[2]) << 8) | \
(((uint32_t)((uint8_t*)(from))[3]) ) ))
#define _msgpack_load64(cast, from) ((cast)( \
(((uint64_t)((uint8_t*)(from))[0]) << 56) | \
(((uint64_t)((uint8_t*)(from))[1]) << 48) | \
(((uint64_t)((uint8_t*)(from))[2]) << 40) | \
(((uint64_t)((uint8_t*)(from))[3]) << 32) | \
(((uint64_t)((uint8_t*)(from))[4]) << 24) | \
(((uint64_t)((uint8_t*)(from))[5]) << 16) | \
(((uint64_t)((uint8_t*)(from))[6]) << 8) | \
(((uint64_t)((uint8_t*)(from))[7]) ) ))
#else
#define _msgpack_be16(x) (x)
#define _msgpack_be32(x) (x)
#define _msgpack_be64(x) (x)
#define _msgpack_load16(cast, from) ((cast)( \
(((uint16_t)((uint8_t*)from)[0]) << 8) | \
(((uint16_t)((uint8_t*)from)[1]) ) ))
#define _msgpack_load32(cast, from) ((cast)( \
(((uint32_t)((uint8_t*)from)[0]) << 24) | \
(((uint32_t)((uint8_t*)from)[1]) << 16) | \
(((uint32_t)((uint8_t*)from)[2]) << 8) | \
(((uint32_t)((uint8_t*)from)[3]) ) ))
#define _msgpack_load64(cast, from) ((cast)( \
(((uint64_t)((uint8_t*)from)[0]) << 56) | \
(((uint64_t)((uint8_t*)from)[1]) << 48) | \
(((uint64_t)((uint8_t*)from)[2]) << 40) | \
(((uint64_t)((uint8_t*)from)[3]) << 32) | \
(((uint64_t)((uint8_t*)from)[4]) << 24) | \
(((uint64_t)((uint8_t*)from)[5]) << 16) | \
(((uint64_t)((uint8_t*)from)[6]) << 8) | \
(((uint64_t)((uint8_t*)from)[7]) ) ))
#endif
#define _msgpack_store16(to, num) \
do { uint16_t val = _msgpack_be16(num); memcpy(to, &val, 2); } while(0)
#define _msgpack_store32(to, num) \
do { uint32_t val = _msgpack_be32(num); memcpy(to, &val, 4); } while(0)
#define _msgpack_store64(to, num) \
do { uint64_t val = _msgpack_be64(num); memcpy(to, &val, 8); } while(0)
/*
#define _msgpack_load16(cast, from) \
({ cast val; memcpy(&val, (char*)from, 2); _msgpack_be16(val); })
#define _msgpack_load32(cast, from) \
({ cast val; memcpy(&val, (char*)from, 4); _msgpack_be32(val); })
#define _msgpack_load64(cast, from) \
({ cast val; memcpy(&val, (char*)from, 8); _msgpack_be64(val); })
*/
#endif /* msgpack/sysdep.h */
srsly-release-v2.5.1/srsly/msgpack/unpack.h 0000664 0000000 0000000 00000025541 14742310675 0020752 0 ustar 00root root 0000000 0000000 /*
* MessagePack for Python unpacking routine
*
* Copyright (C) 2009 Naoki INADA
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define MSGPACK_EMBED_STACK_SIZE (1024)
#include "unpack_define.h"
typedef struct unpack_user {
bool use_list;
bool raw;
bool has_pairs_hook;
bool strict_map_key;
int timestamp;
PyObject *object_hook;
PyObject *list_hook;
PyObject *ext_hook;
PyObject *timestamp_t;
PyObject *giga;
PyObject *utc;
const char *encoding;
const char *unicode_errors;
Py_ssize_t max_str_len, max_bin_len, max_array_len, max_map_len, max_ext_len;
} unpack_user;
typedef PyObject* msgpack_unpack_object;
struct unpack_context;
typedef struct unpack_context unpack_context;
typedef int (*execute_fn)(unpack_context *ctx, const char* data, Py_ssize_t len, Py_ssize_t* off);
static inline msgpack_unpack_object unpack_callback_root(unpack_user* u)
{
return NULL;
}
static inline int unpack_callback_uint16(unpack_user* u, uint16_t d, msgpack_unpack_object* o)
{
PyObject *p = PyLong_FromLong((long)d);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_uint8(unpack_user* u, uint8_t d, msgpack_unpack_object* o)
{
return unpack_callback_uint16(u, d, o);
}
static inline int unpack_callback_uint32(unpack_user* u, uint32_t d, msgpack_unpack_object* o)
{
PyObject *p = PyLong_FromSize_t((size_t)d);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_uint64(unpack_user* u, uint64_t d, msgpack_unpack_object* o)
{
PyObject *p;
if (d > LONG_MAX) {
p = PyLong_FromUnsignedLongLong((unsigned PY_LONG_LONG)d);
} else {
p = PyLong_FromLong((long)d);
}
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_int32(unpack_user* u, int32_t d, msgpack_unpack_object* o)
{
PyObject *p = PyLong_FromLong(d);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_int16(unpack_user* u, int16_t d, msgpack_unpack_object* o)
{
return unpack_callback_int32(u, d, o);
}
static inline int unpack_callback_int8(unpack_user* u, int8_t d, msgpack_unpack_object* o)
{
return unpack_callback_int32(u, d, o);
}
static inline int unpack_callback_int64(unpack_user* u, int64_t d, msgpack_unpack_object* o)
{
PyObject *p;
if (d > LONG_MAX || d < LONG_MIN) {
p = PyLong_FromLongLong((PY_LONG_LONG)d);
} else {
p = PyLong_FromLong((long)d);
}
*o = p;
return 0;
}
static inline int unpack_callback_double(unpack_user* u, double d, msgpack_unpack_object* o)
{
PyObject *p = PyFloat_FromDouble(d);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_float(unpack_user* u, float d, msgpack_unpack_object* o)
{
return unpack_callback_double(u, d, o);
}
static inline int unpack_callback_nil(unpack_user* u, msgpack_unpack_object* o)
{ Py_INCREF(Py_None); *o = Py_None; return 0; }
static inline int unpack_callback_true(unpack_user* u, msgpack_unpack_object* o)
{ Py_INCREF(Py_True); *o = Py_True; return 0; }
static inline int unpack_callback_false(unpack_user* u, msgpack_unpack_object* o)
{ Py_INCREF(Py_False); *o = Py_False; return 0; }
static inline int unpack_callback_array(unpack_user* u, unsigned int n, msgpack_unpack_object* o)
{
if (n > u->max_array_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_array_len(%zd)", n, u->max_array_len);
return -1;
}
PyObject *p = u->use_list ? PyList_New(n) : PyTuple_New(n);
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_array_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object o)
{
if (u->use_list)
PyList_SET_ITEM(*c, current, o);
else
PyTuple_SET_ITEM(*c, current, o);
return 0;
}
static inline int unpack_callback_array_end(unpack_user* u, msgpack_unpack_object* c)
{
if (u->list_hook) {
PyObject *new_c = PyObject_CallFunctionObjArgs(u->list_hook, *c, NULL);
if (!new_c)
return -1;
Py_DECREF(*c);
*c = new_c;
}
return 0;
}
static inline int unpack_callback_map(unpack_user* u, unsigned int n, msgpack_unpack_object* o)
{
if (n > u->max_map_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_map_len(%zd)", n, u->max_map_len);
return -1;
}
PyObject *p;
if (u->has_pairs_hook) {
p = PyList_New(n); // Or use tuple?
}
else {
p = PyDict_New();
}
if (!p)
return -1;
*o = p;
return 0;
}
static inline int unpack_callback_map_item(unpack_user* u, unsigned int current, msgpack_unpack_object* c, msgpack_unpack_object k, msgpack_unpack_object v)
{
if (u->strict_map_key && !PyUnicode_CheckExact(k) && !PyBytes_CheckExact(k)) {
PyErr_Format(PyExc_ValueError, "%.100s is not allowed for map key when strict_map_key=True", Py_TYPE(k)->tp_name);
return -1;
}
if (PyUnicode_CheckExact(k)) {
PyUnicode_InternInPlace(&k);
}
if (u->has_pairs_hook) {
msgpack_unpack_object item = PyTuple_Pack(2, k, v);
if (!item)
return -1;
Py_DECREF(k);
Py_DECREF(v);
PyList_SET_ITEM(*c, current, item);
return 0;
}
else if (PyDict_SetItem(*c, k, v) == 0) {
Py_DECREF(k);
Py_DECREF(v);
return 0;
}
return -1;
}
static inline int unpack_callback_map_end(unpack_user* u, msgpack_unpack_object* c)
{
if (u->object_hook) {
PyObject *new_c = PyObject_CallFunctionObjArgs(u->object_hook, *c, NULL);
if (!new_c)
return -1;
Py_DECREF(*c);
*c = new_c;
}
return 0;
}
static inline int unpack_callback_raw(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o)
{
if (l > u->max_str_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_str_len(%zd)", l, u->max_str_len);
return -1;
}
PyObject *py;
if (u->encoding) {
py = PyUnicode_Decode(p, l, u->encoding, u->unicode_errors);
} else if (u->raw) {
py = PyBytes_FromStringAndSize(p, l);
} else {
py = PyUnicode_DecodeUTF8(p, l, u->unicode_errors);
}
if (!py)
return -1;
*o = py;
return 0;
}
static inline int unpack_callback_bin(unpack_user* u, const char* b, const char* p, unsigned int l, msgpack_unpack_object* o)
{
if (l > u->max_bin_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_bin_len(%zd)", l, u->max_bin_len);
return -1;
}
PyObject *py = PyBytes_FromStringAndSize(p, l);
if (!py)
return -1;
*o = py;
return 0;
}
typedef struct msgpack_timestamp {
int64_t tv_sec;
uint32_t tv_nsec;
} msgpack_timestamp;
/*
* Unpack ext buffer to a timestamp. Pulled from msgpack-c timestamp.h.
*/
static int unpack_timestamp(const char* buf, unsigned int buflen, msgpack_timestamp* ts) {
switch (buflen) {
case 4:
ts->tv_nsec = 0;
{
uint32_t v = _msgpack_load32(uint32_t, buf);
ts->tv_sec = (int64_t)v;
}
return 0;
case 8: {
uint64_t value =_msgpack_load64(uint64_t, buf);
ts->tv_nsec = (uint32_t)(value >> 34);
ts->tv_sec = value & 0x00000003ffffffffLL;
return 0;
}
case 12:
ts->tv_nsec = _msgpack_load32(uint32_t, buf);
ts->tv_sec = _msgpack_load64(int64_t, buf + 4);
return 0;
default:
return -1;
}
}
#include "datetime.h"
static int unpack_callback_ext(unpack_user* u, const char* base, const char* pos,
unsigned int length, msgpack_unpack_object* o)
{
int8_t typecode = (int8_t)*pos++;
if (!u->ext_hook) {
PyErr_SetString(PyExc_AssertionError, "u->ext_hook cannot be NULL");
return -1;
}
if (length-1 > u->max_ext_len) {
PyErr_Format(PyExc_ValueError, "%u exceeds max_ext_len(%zd)", length, u->max_ext_len);
return -1;
}
PyObject *py = NULL;
// length also includes the typecode, so the actual data is length-1
if (typecode == -1) {
msgpack_timestamp ts;
if (unpack_timestamp(pos, length-1, &ts) < 0) {
return -1;
}
if (u->timestamp == 2) { // int
PyObject *a = PyLong_FromLongLong(ts.tv_sec);
if (a == NULL) return -1;
PyObject *c = PyNumber_Multiply(a, u->giga);
Py_DECREF(a);
if (c == NULL) {
return -1;
}
PyObject *b = PyLong_FromUnsignedLong(ts.tv_nsec);
if (b == NULL) {
Py_DECREF(c);
return -1;
}
py = PyNumber_Add(c, b);
Py_DECREF(c);
Py_DECREF(b);
}
else if (u->timestamp == 0) { // Timestamp
py = PyObject_CallFunction(u->timestamp_t, "(Lk)", ts.tv_sec, ts.tv_nsec);
}
else if (u->timestamp == 3) { // datetime
// Calculate datetime using epoch + delta
// due to limitations PyDateTime_FromTimestamp on Windows with negative timestamps
PyObject *epoch = PyDateTimeAPI->DateTime_FromDateAndTime(1970, 1, 1, 0, 0, 0, 0, u->utc, PyDateTimeAPI->DateTimeType);
if (epoch == NULL) {
return -1;
}
PyObject* d = PyDelta_FromDSU(ts.tv_sec/(24*3600), ts.tv_sec%(24*3600), ts.tv_nsec / 1000);
if (d == NULL) {
Py_DECREF(epoch);
return -1;
}
py = PyNumber_Add(epoch, d);
Py_DECREF(epoch);
Py_DECREF(d);
}
else { // float
PyObject *a = PyFloat_FromDouble((double)ts.tv_nsec);
if (a == NULL) return -1;
PyObject *b = PyNumber_TrueDivide(a, u->giga);
Py_DECREF(a);
if (b == NULL) return -1;
PyObject *c = PyLong_FromLongLong(ts.tv_sec);
if (c == NULL) {
Py_DECREF(b);
return -1;
}
a = PyNumber_Add(b, c);
Py_DECREF(b);
Py_DECREF(c);
py = a;
}
} else {
py = PyObject_CallFunction(u->ext_hook, "(iy#)", (int)typecode, pos, (Py_ssize_t)length-1);
}
if (!py)
return -1;
*o = py;
return 0;
}
#include "unpack_template.h"
srsly-release-v2.5.1/srsly/msgpack/unpack_container_header.h 0000664 0000000 0000000 00000002543 14742310675 0024321 0 ustar 00root root 0000000 0000000 static inline int unpack_container_header(unpack_context* ctx, const char* data, Py_ssize_t len, Py_ssize_t* off)
{
assert(len >= *off);
uint32_t size;
const unsigned char *const p = (unsigned char*)data + *off;
#define inc_offset(inc) \
if (len - *off < inc) \
return 0; \
*off += inc;
switch (*p) {
case var_offset:
inc_offset(3);
size = _msgpack_load16(uint16_t, p + 1);
break;
case var_offset + 1:
inc_offset(5);
size = _msgpack_load32(uint32_t, p + 1);
break;
#ifdef USE_CASE_RANGE
case fixed_offset + 0x0 ... fixed_offset + 0xf:
#else
case fixed_offset + 0x0:
case fixed_offset + 0x1:
case fixed_offset + 0x2:
case fixed_offset + 0x3:
case fixed_offset + 0x4:
case fixed_offset + 0x5:
case fixed_offset + 0x6:
case fixed_offset + 0x7:
case fixed_offset + 0x8:
case fixed_offset + 0x9:
case fixed_offset + 0xa:
case fixed_offset + 0xb:
case fixed_offset + 0xc:
case fixed_offset + 0xd:
case fixed_offset + 0xe:
case fixed_offset + 0xf:
#endif
++*off;
size = ((unsigned int)*p) & 0x0f;
break;
default:
PyErr_SetString(PyExc_ValueError, "Unexpected type header on stream");
return -1;
}
unpack_callback_uint32(&ctx->user, size, &ctx->stack[0].obj);
return 1;
}
srsly-release-v2.5.1/srsly/msgpack/unpack_define.h 0000664 0000000 0000000 00000004476 14742310675 0022270 0 ustar 00root root 0000000 0000000 /*
* MessagePack unpacking routine template
*
* Copyright (C) 2008-2010 FURUHASHI Sadayuki
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MSGPACK_UNPACK_DEFINE_H__
#define MSGPACK_UNPACK_DEFINE_H__
#include "msgpack/sysdep.h"
#include <stdlib.h>
#include <string.h>
#include <assert.h>
#include <stdio.h>
#ifdef __cplusplus
extern "C" {
#endif
#ifndef MSGPACK_EMBED_STACK_SIZE
#define MSGPACK_EMBED_STACK_SIZE 32
#endif
// CS is first byte & 0x1f
typedef enum {
CS_HEADER = 0x00, // nil
//CS_ = 0x01,
//CS_ = 0x02, // false
//CS_ = 0x03, // true
CS_BIN_8 = 0x04,
CS_BIN_16 = 0x05,
CS_BIN_32 = 0x06,
CS_EXT_8 = 0x07,
CS_EXT_16 = 0x08,
CS_EXT_32 = 0x09,
CS_FLOAT = 0x0a,
CS_DOUBLE = 0x0b,
CS_UINT_8 = 0x0c,
CS_UINT_16 = 0x0d,
CS_UINT_32 = 0x0e,
CS_UINT_64 = 0x0f,
CS_INT_8 = 0x10,
CS_INT_16 = 0x11,
CS_INT_32 = 0x12,
CS_INT_64 = 0x13,
//CS_FIXEXT1 = 0x14,
//CS_FIXEXT2 = 0x15,
//CS_FIXEXT4 = 0x16,
//CS_FIXEXT8 = 0x17,
//CS_FIXEXT16 = 0x18,
CS_RAW_8 = 0x19,
CS_RAW_16 = 0x1a,
CS_RAW_32 = 0x1b,
CS_ARRAY_16 = 0x1c,
CS_ARRAY_32 = 0x1d,
CS_MAP_16 = 0x1e,
CS_MAP_32 = 0x1f,
ACS_RAW_VALUE,
ACS_BIN_VALUE,
ACS_EXT_VALUE,
} msgpack_unpack_state;
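/*
 * For illustration: 0xcd (uint 16) & 0x1f == 0x0d == CS_UINT_16, which is why
 * the CS_* values under 0x20 line up with the low five bits of the
 * corresponding msgpack type bytes (cf. the "CS is first byte & 0x1f" note above).
 */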
typedef enum {
CT_ARRAY_ITEM,
CT_MAP_KEY,
CT_MAP_VALUE,
} msgpack_container_type;
#ifdef __cplusplus
}
#endif
#endif /* msgpack/unpack_define.h */
srsly-release-v2.5.1/srsly/msgpack/unpack_template.h 0000664 0000000 0000000 00000032632 14742310675 0022644 0 ustar 00root root 0000000 0000000 /*
* MessagePack unpacking routine template
*
* Copyright (C) 2008-2010 FURUHASHI Sadayuki
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef USE_CASE_RANGE
#if !defined(_MSC_VER)
#define USE_CASE_RANGE
#endif
#endif
typedef struct unpack_stack {
PyObject* obj;
Py_ssize_t size;
Py_ssize_t count;
unsigned int ct;
PyObject* map_key;
} unpack_stack;
struct unpack_context {
unpack_user user;
unsigned int cs;
unsigned int trail;
unsigned int top;
/*
unpack_stack* stack;
unsigned int stack_size;
unpack_stack embed_stack[MSGPACK_EMBED_STACK_SIZE];
*/
unpack_stack stack[MSGPACK_EMBED_STACK_SIZE];
};
static inline void unpack_init(unpack_context* ctx)
{
ctx->cs = CS_HEADER;
ctx->trail = 0;
ctx->top = 0;
/*
ctx->stack = ctx->embed_stack;
ctx->stack_size = MSGPACK_EMBED_STACK_SIZE;
*/
ctx->stack[0].obj = unpack_callback_root(&ctx->user);
}
/*
static inline void unpack_destroy(unpack_context* ctx)
{
if(ctx->stack_size != MSGPACK_EMBED_STACK_SIZE) {
free(ctx->stack);
}
}
*/
static inline PyObject* unpack_data(unpack_context* ctx)
{
return (ctx)->stack[0].obj;
}
static inline void unpack_clear(unpack_context *ctx)
{
Py_CLEAR(ctx->stack[0].obj);
}
static inline int unpack_execute(bool construct, unpack_context* ctx, const char* data, Py_ssize_t len, Py_ssize_t* off)
{
assert(len >= *off);
const unsigned char* p = (unsigned char*)data + *off;
const unsigned char* const pe = (unsigned char*)data + len;
const void* n = p;
unsigned int trail = ctx->trail;
unsigned int cs = ctx->cs;
unsigned int top = ctx->top;
unpack_stack* stack = ctx->stack;
/*
unsigned int stack_size = ctx->stack_size;
*/
unpack_user* user = &ctx->user;
PyObject* obj = NULL;
unpack_stack* c = NULL;
int ret;
#define construct_cb(name) \
construct && unpack_callback ## name
#define push_simple_value(func) \
if(construct_cb(func)(user, &obj) < 0) { goto _failed; } \
goto _push
#define push_fixed_value(func, arg) \
if(construct_cb(func)(user, arg, &obj) < 0) { goto _failed; } \
goto _push
#define push_variable_value(func, base, pos, len) \
if(construct_cb(func)(user, \
(const char*)base, (const char*)pos, len, &obj) < 0) { goto _failed; } \
goto _push
#define again_fixed_trail(_cs, trail_len) \
trail = trail_len; \
cs = _cs; \
goto _fixed_trail_again
#define again_fixed_trail_if_zero(_cs, trail_len, ifzero) \
trail = trail_len; \
if(trail == 0) { goto ifzero; } \
cs = _cs; \
goto _fixed_trail_again
#define start_container(func, count_, ct_) \
if(top >= MSGPACK_EMBED_STACK_SIZE) { ret = -3; goto _end; } \
if(construct_cb(func)(user, count_, &stack[top].obj) < 0) { goto _failed; } \
if((count_) == 0) { obj = stack[top].obj; \
if (construct_cb(func##_end)(user, &obj) < 0) { goto _failed; } \
goto _push; } \
stack[top].ct = ct_; \
stack[top].size = count_; \
stack[top].count = 0; \
++top; \
goto _header_again
#define NEXT_CS(p) ((unsigned int)*p & 0x1f)
#ifdef USE_CASE_RANGE
#define SWITCH_RANGE_BEGIN switch(*p) {
#define SWITCH_RANGE(FROM, TO) case FROM ... TO:
#define SWITCH_RANGE_DEFAULT default:
#define SWITCH_RANGE_END }
#else
#define SWITCH_RANGE_BEGIN { if(0) {
#define SWITCH_RANGE(FROM, TO) } else if(FROM <= *p && *p <= TO) {
#define SWITCH_RANGE_DEFAULT } else {
#define SWITCH_RANGE_END } }
#endif
if(p == pe) { goto _out; }
do {
switch(cs) {
case CS_HEADER:
SWITCH_RANGE_BEGIN
SWITCH_RANGE(0x00, 0x7f) // Positive Fixnum
push_fixed_value(_uint8, *(uint8_t*)p);
SWITCH_RANGE(0xe0, 0xff) // Negative Fixnum
push_fixed_value(_int8, *(int8_t*)p);
SWITCH_RANGE(0xc0, 0xdf) // Variable
switch(*p) {
case 0xc0: // nil
push_simple_value(_nil);
//case 0xc1: // never used
case 0xc2: // false
push_simple_value(_false);
case 0xc3: // true
push_simple_value(_true);
case 0xc4: // bin 8
again_fixed_trail(NEXT_CS(p), 1);
case 0xc5: // bin 16
again_fixed_trail(NEXT_CS(p), 2);
case 0xc6: // bin 32
again_fixed_trail(NEXT_CS(p), 4);
case 0xc7: // ext 8
again_fixed_trail(NEXT_CS(p), 1);
case 0xc8: // ext 16
again_fixed_trail(NEXT_CS(p), 2);
case 0xc9: // ext 32
again_fixed_trail(NEXT_CS(p), 4);
case 0xca: // float
case 0xcb: // double
case 0xcc: // unsigned int 8
case 0xcd: // unsigned int 16
case 0xce: // unsigned int 32
case 0xcf: // unsigned int 64
case 0xd0: // signed int 8
case 0xd1: // signed int 16
case 0xd2: // signed int 32
case 0xd3: // signed int 64
again_fixed_trail(NEXT_CS(p), 1 << (((unsigned int)*p) & 0x03));
case 0xd4: // fixext 1
case 0xd5: // fixext 2
case 0xd6: // fixext 4
case 0xd7: // fixext 8
again_fixed_trail_if_zero(ACS_EXT_VALUE,
(1 << (((unsigned int)*p) & 0x03))+1,
_ext_zero);
case 0xd8: // fixext 16
again_fixed_trail_if_zero(ACS_EXT_VALUE, 16+1, _ext_zero);
case 0xd9: // str 8
again_fixed_trail(NEXT_CS(p), 1);
case 0xda: // raw 16
case 0xdb: // raw 32
case 0xdc: // array 16
case 0xdd: // array 32
case 0xde: // map 16
case 0xdf: // map 32
again_fixed_trail(NEXT_CS(p), 2 << (((unsigned int)*p) & 0x01));
default:
ret = -2;
goto _end;
}
SWITCH_RANGE(0xa0, 0xbf) // FixRaw
again_fixed_trail_if_zero(ACS_RAW_VALUE, ((unsigned int)*p & 0x1f), _raw_zero);
SWITCH_RANGE(0x90, 0x9f) // FixArray
start_container(_array, ((unsigned int)*p) & 0x0f, CT_ARRAY_ITEM);
SWITCH_RANGE(0x80, 0x8f) // FixMap
start_container(_map, ((unsigned int)*p) & 0x0f, CT_MAP_KEY);
SWITCH_RANGE_DEFAULT
ret = -2;
goto _end;
SWITCH_RANGE_END
// end CS_HEADER
_fixed_trail_again:
++p;
default:
if((size_t)(pe - p) < trail) { goto _out; }
n = p; p += trail - 1;
switch(cs) {
case CS_EXT_8:
again_fixed_trail_if_zero(ACS_EXT_VALUE, *(uint8_t*)n+1, _ext_zero);
case CS_EXT_16:
again_fixed_trail_if_zero(ACS_EXT_VALUE,
_msgpack_load16(uint16_t,n)+1,
_ext_zero);
case CS_EXT_32:
again_fixed_trail_if_zero(ACS_EXT_VALUE,
_msgpack_load32(uint32_t,n)+1,
_ext_zero);
case CS_FLOAT: {
double f;
#if PY_VERSION_HEX >= 0x030B00A7
f = PyFloat_Unpack4((const char*)n, 0);
#else
f = _PyFloat_Unpack4((unsigned char*)n, 0);
#endif
push_fixed_value(_float, f); }
case CS_DOUBLE: {
double f;
#if PY_VERSION_HEX >= 0x030B00A7
f = PyFloat_Unpack8((const char*)n, 0);
#else
f = _PyFloat_Unpack8((unsigned char*)n, 0);
#endif
push_fixed_value(_double, f); }
case CS_UINT_8:
push_fixed_value(_uint8, *(uint8_t*)n);
case CS_UINT_16:
push_fixed_value(_uint16, _msgpack_load16(uint16_t,n));
case CS_UINT_32:
push_fixed_value(_uint32, _msgpack_load32(uint32_t,n));
case CS_UINT_64:
push_fixed_value(_uint64, _msgpack_load64(uint64_t,n));
case CS_INT_8:
push_fixed_value(_int8, *(int8_t*)n);
case CS_INT_16:
push_fixed_value(_int16, _msgpack_load16(int16_t,n));
case CS_INT_32:
push_fixed_value(_int32, _msgpack_load32(int32_t,n));
case CS_INT_64:
push_fixed_value(_int64, _msgpack_load64(int64_t,n));
case CS_BIN_8:
again_fixed_trail_if_zero(ACS_BIN_VALUE, *(uint8_t*)n, _bin_zero);
case CS_BIN_16:
again_fixed_trail_if_zero(ACS_BIN_VALUE, _msgpack_load16(uint16_t,n), _bin_zero);
case CS_BIN_32:
again_fixed_trail_if_zero(ACS_BIN_VALUE, _msgpack_load32(uint32_t,n), _bin_zero);
case ACS_BIN_VALUE:
_bin_zero:
push_variable_value(_bin, data, n, trail);
case CS_RAW_8:
again_fixed_trail_if_zero(ACS_RAW_VALUE, *(uint8_t*)n, _raw_zero);
case CS_RAW_16:
again_fixed_trail_if_zero(ACS_RAW_VALUE, _msgpack_load16(uint16_t,n), _raw_zero);
case CS_RAW_32:
again_fixed_trail_if_zero(ACS_RAW_VALUE, _msgpack_load32(uint32_t,n), _raw_zero);
case ACS_RAW_VALUE:
_raw_zero:
push_variable_value(_raw, data, n, trail);
case ACS_EXT_VALUE:
_ext_zero:
push_variable_value(_ext, data, n, trail);
case CS_ARRAY_16:
start_container(_array, _msgpack_load16(uint16_t,n), CT_ARRAY_ITEM);
case CS_ARRAY_32:
/* FIXME security guard */
start_container(_array, _msgpack_load32(uint32_t,n), CT_ARRAY_ITEM);
case CS_MAP_16:
start_container(_map, _msgpack_load16(uint16_t,n), CT_MAP_KEY);
case CS_MAP_32:
/* FIXME security guard */
start_container(_map, _msgpack_load32(uint32_t,n), CT_MAP_KEY);
default:
goto _failed;
}
}
_push:
if(top == 0) { goto _finish; }
c = &stack[top-1];
switch(c->ct) {
case CT_ARRAY_ITEM:
if(construct_cb(_array_item)(user, c->count, &c->obj, obj) < 0) { goto _failed; }
if(++c->count == c->size) {
obj = c->obj;
if (construct_cb(_array_end)(user, &obj) < 0) { goto _failed; }
--top;
/*printf("stack pop %d\n", top);*/
goto _push;
}
goto _header_again;
case CT_MAP_KEY:
c->map_key = obj;
c->ct = CT_MAP_VALUE;
goto _header_again;
case CT_MAP_VALUE:
if(construct_cb(_map_item)(user, c->count, &c->obj, c->map_key, obj) < 0) { goto _failed; }
if(++c->count == c->size) {
obj = c->obj;
if (construct_cb(_map_end)(user, &obj) < 0) { goto _failed; }
--top;
/*printf("stack pop %d\n", top);*/
goto _push;
}
c->ct = CT_MAP_KEY;
goto _header_again;
default:
goto _failed;
}
_header_again:
cs = CS_HEADER;
++p;
} while(p != pe);
goto _out;
_finish:
if (!construct)
unpack_callback_nil(user, &obj);
stack[0].obj = obj;
++p;
ret = 1;
/*printf("-- finish --\n"); */
goto _end;
_failed:
/*printf("** FAILED **\n"); */
ret = -1;
goto _end;
_out:
ret = 0;
goto _end;
_end:
ctx->cs = cs;
ctx->trail = trail;
ctx->top = top;
*off = p - (const unsigned char*)data;
return ret;
#undef construct_cb
}
#undef NEXT_CS
#undef SWITCH_RANGE_BEGIN
#undef SWITCH_RANGE
#undef SWITCH_RANGE_DEFAULT
#undef SWITCH_RANGE_END
#undef push_simple_value
#undef push_fixed_value
#undef push_variable_value
#undef again_fixed_trail
#undef again_fixed_trail_if_zero
#undef start_container
static int unpack_construct(unpack_context *ctx, const char *data, Py_ssize_t len, Py_ssize_t *off) {
return unpack_execute(1, ctx, data, len, off);
}
static int unpack_skip(unpack_context *ctx, const char *data, Py_ssize_t len, Py_ssize_t *off) {
return unpack_execute(0, ctx, data, len, off);
}
#define unpack_container_header read_array_header
#define fixed_offset 0x90
#define var_offset 0xdc
#include "unpack_container_header.h"
#undef unpack_container_header
#undef fixed_offset
#undef var_offset
#define unpack_container_header read_map_header
#define fixed_offset 0x80
#define var_offset 0xde
#include "unpack_container_header.h"
#undef unpack_container_header
#undef fixed_offset
#undef var_offset
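/*
 * The two includes above stamp out read_array_header and read_map_header from
 * the same template: fixed_offset selects the fixarray (0x90..0x9f) or fixmap
 * (0x80..0x8f) range, and var_offset/var_offset+1 select the 16/32-bit headers
 * (0xdc/0xdd for arrays, 0xde/0xdf for maps). Each generated function consumes
 * only the container header and reports its size via unpack_callback_uint32.
 */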
/* vim: set ts=4 sw=4 sts=4 expandtab */
srsly-release-v2.5.1/srsly/msgpack/util.py 0000664 0000000 0000000 00000000455 14742310675 0020644 0 ustar 00root root 0000000 0000000 from __future__ import unicode_literals
try:
unicode
except NameError:
unicode = str
def ensure_bytes(string):
"""Ensure a string is returned as a bytes object, encoded as utf8."""
if isinstance(string, unicode):
return string.encode("utf8")
else:
return string
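# For illustration (not exercised in this module):
#     ensure_bytes("hello")   -> b"hello"
#     ensure_bytes(b"hello")  -> b"hello"  (bytes pass through unchanged)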
srsly-release-v2.5.1/srsly/ruamel_yaml/ 0000775 0000000 0000000 00000000000 14742310675 0020173 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/ruamel_yaml/LICENSE 0000775 0000000 0000000 00000002141 14742310675 0021201 0 ustar 00root root 0000000 0000000 The MIT License (MIT)
Copyright (c) 2014-2020 Anthon van der Neut, Ruamel bvba
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
srsly-release-v2.5.1/srsly/ruamel_yaml/__init__.py 0000775 0000000 0000000 00000000150 14742310675 0022303 0 ustar 00root root 0000000 0000000 __with_libyaml__ = False
from .main import * # NOQA
version_info = (0, 16, 7)
__version__ = "0.16.7"
srsly-release-v2.5.1/srsly/ruamel_yaml/anchor.py 0000775 0000000 0000000 00000000765 14742310675 0022032 0 ustar 00root root 0000000 0000000
if False: # MYPY
from typing import Any, Dict, Optional, List, Union, Optional, Iterator # NOQA
anchor_attrib = '_yaml_anchor'
class Anchor(object):
__slots__ = 'value', 'always_dump'
attrib = anchor_attrib
def __init__(self):
# type: () -> None
self.value = None
self.always_dump = False
def __repr__(self):
# type: () -> Any
ad = ', (always dump)' if self.always_dump else ""
return 'Anchor({!r}{})'.format(self.value, ad)
srsly-release-v2.5.1/srsly/ruamel_yaml/comments.py 0000775 0000000 0000000 00000104276 14742310675 0022407 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import, print_function
"""
stuff to deal with comments and formatting on dict/list/ordereddict/set
these are not really related, formatting could be factored out as
a separate base
"""
import sys
import copy
from .compat import ordereddict # type: ignore
from .compat import PY2, string_types, MutableSliceableSequence
from .scalarstring import ScalarString
from .anchor import Anchor
if PY2:
from collections import MutableSet, Sized, Set, Mapping
else:
from collections.abc import MutableSet, Sized, Set, Mapping
if False: # MYPY
from typing import Any, Dict, Optional, List, Union, Optional, Iterator # NOQA
# fmt: off
__all__ = ['CommentedSeq', 'CommentedKeySeq',
'CommentedMap', 'CommentedOrderedMap',
'CommentedSet', 'comment_attrib', 'merge_attrib']
# fmt: on
comment_attrib = "_yaml_comment"
format_attrib = "_yaml_format"
line_col_attrib = "_yaml_line_col"
merge_attrib = "_yaml_merge"
tag_attrib = "_yaml_tag"
class Comment(object):
# sys.getsize tested the Comment objects, __slots__ makes them bigger
# and adding self.end did not matter
__slots__ = "comment", "_items", "_end", "_start"
attrib = comment_attrib
def __init__(self):
# type: () -> None
self.comment = None # [post, [pre]]
# map key (mapping/omap/dict) or index (sequence/list) to a list of
# dict: post_key, pre_key, post_value, pre_value
# list: pre item, post item
self._items = {} # type: Dict[Any, Any]
# self._start = [] # should not put these on first item
self._end = [] # type: List[Any] # end of document comments
def __str__(self):
# type: () -> str
if bool(self._end):
end = ",\n end=" + str(self._end)
else:
end = ""
return "Comment(comment={0},\n items={1}{2})".format(
self.comment, self._items, end
)
@property
def items(self):
# type: () -> Any
return self._items
@property
def end(self):
# type: () -> Any
return self._end
@end.setter
def end(self, value):
# type: (Any) -> None
self._end = value
@property
def start(self):
# type: () -> Any
return self._start
@start.setter
def start(self, value):
# type: (Any) -> None
self._start = value
# to distinguish key from None
def NoComment():
# type: () -> None
pass
class Format(object):
__slots__ = ("_flow_style",)
attrib = format_attrib
def __init__(self):
# type: () -> None
self._flow_style = None # type: Any
def set_flow_style(self):
# type: () -> None
self._flow_style = True
def set_block_style(self):
# type: () -> None
self._flow_style = False
def flow_style(self, default=None):
# type: (Optional[Any]) -> Any
"""if default (the flow_style) is None, the flow style tacked on to
the object explicitly will be taken. If that is None as well the
default flow style rules the format down the line, or the type
of the constituent values (simple -> flow, map/list -> block)"""
if self._flow_style is None:
return default
return self._flow_style
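# For illustration of flow_style() above (a sketch, not used elsewhere in this
# module): the explicit per-object setting wins over the default passed in.
#     fmt = Format()
#     fmt.flow_style(default=False)   # -> False (nothing set explicitly)
#     fmt.set_flow_style()
#     fmt.flow_style(default=False)   # -> True (explicit flow style wins)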
class LineCol(object):
attrib = line_col_attrib
def __init__(self):
# type: () -> None
self.line = None
self.col = None
self.data = None # type: Optional[Dict[Any, Any]]
def add_kv_line_col(self, key, data):
# type: (Any, Any) -> None
if self.data is None:
self.data = {}
self.data[key] = data
def key(self, k):
# type: (Any) -> Any
return self._kv(k, 0, 1)
def value(self, k):
# type: (Any) -> Any
return self._kv(k, 2, 3)
def _kv(self, k, x0, x1):
# type: (Any, Any, Any) -> Any
if self.data is None:
return None
data = self.data[k]
return data[x0], data[x1]
def item(self, idx):
# type: (Any) -> Any
if self.data is None:
return None
return self.data[idx][0], self.data[idx][1]
def add_idx_line_col(self, key, data):
# type: (Any, Any) -> None
if self.data is None:
self.data = {}
self.data[key] = data
class Tag(object):
"""store tag information for roundtripping"""
__slots__ = ("value",)
attrib = tag_attrib
def __init__(self):
# type: () -> None
self.value = None
def __repr__(self):
# type: () -> Any
return "{0.__class__.__name__}({0.value!r})".format(self)
class CommentedBase(object):
@property
def ca(self):
# type: () -> Any
if not hasattr(self, Comment.attrib):
setattr(self, Comment.attrib, Comment())
return getattr(self, Comment.attrib)
def yaml_end_comment_extend(self, comment, clear=False):
# type: (Any, bool) -> None
if comment is None:
return
if clear or self.ca.end is None:
self.ca.end = []
self.ca.end.extend(comment)
def yaml_key_comment_extend(self, key, comment, clear=False):
# type: (Any, Any, bool) -> None
r = self.ca._items.setdefault(key, [None, None, None, None])
if clear or r[1] is None:
if comment[1] is not None:
assert isinstance(comment[1], list)
r[1] = comment[1]
else:
r[1].extend(comment[0])
r[0] = comment[0]
def yaml_value_comment_extend(self, key, comment, clear=False):
# type: (Any, Any, bool) -> None
r = self.ca._items.setdefault(key, [None, None, None, None])
if clear or r[3] is None:
if comment[1] is not None:
assert isinstance(comment[1], list)
r[3] = comment[1]
else:
r[3].extend(comment[0])
r[2] = comment[0]
def yaml_set_start_comment(self, comment, indent=0):
# type: (Any, Any) -> None
"""overwrites any preceding comment lines on an object
expects comment to be without `#` and possible have multiple lines
"""
from .error import CommentMark
from .tokens import CommentToken
pre_comments = self._yaml_get_pre_comment()
if comment[-1] == "\n":
comment = comment[:-1] # strip final newline if there
start_mark = CommentMark(indent)
for com in comment.split("\n"):
pre_comments.append(CommentToken("# " + com + "\n", start_mark, None))
def yaml_set_comment_before_after_key(
self, key, before=None, indent=0, after=None, after_indent=None
):
# type: (Any, Any, Any, Any, Any) -> None
"""
expects comment (before/after) to be without `#` and possible have multiple lines
"""
from srsly.ruamel_yaml.error import CommentMark
from srsly.ruamel_yaml.tokens import CommentToken
def comment_token(s, mark):
# type: (Any, Any) -> Any
# handle empty lines as having no comment
return CommentToken(("# " if s else "") + s + "\n", mark, None)
if after_indent is None:
after_indent = indent + 2
if before and (len(before) > 1) and before[-1] == "\n":
before = before[:-1] # strip final newline if there
if after and after[-1] == "\n":
after = after[:-1] # strip final newline if there
start_mark = CommentMark(indent)
c = self.ca.items.setdefault(key, [None, [], None, None])
if before == "\n":
c[1].append(comment_token("", start_mark))
elif before:
for com in before.split("\n"):
c[1].append(comment_token(com, start_mark))
if after:
start_mark = CommentMark(after_indent)
if c[3] is None:
c[3] = []
for com in after.split("\n"):
c[3].append(comment_token(com, start_mark)) # type: ignore
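# For illustration, a usage sketch (CommentedMap is defined further below in
# this module):
#     data = CommentedMap()
#     data["a"], data["b"] = 1, 2
#     data.yaml_set_comment_before_after_key("b", before="section two")
# which a round-trip dump renders as a "# section two" line above the "b:" key.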
@property
def fa(self):
# type: () -> Any
"""format attribute
set_flow_style()/set_block_style()"""
if not hasattr(self, Format.attrib):
setattr(self, Format.attrib, Format())
return getattr(self, Format.attrib)
def yaml_add_eol_comment(self, comment, key=NoComment, column=None):
# type: (Any, Optional[Any], Optional[Any]) -> None
"""
there is a problem as eol comments should start with ' #'
(but at the beginning of the line the space doesn't have to be before
the #. The column index is for the # mark
"""
from .tokens import CommentToken
from .error import CommentMark
if column is None:
try:
column = self._yaml_get_column(key)
except AttributeError:
column = 0
if comment[0] != "#":
comment = "# " + comment
if column is None:
if comment[0] == "#":
comment = " " + comment
column = 0
start_mark = CommentMark(column)
ct = [CommentToken(comment, start_mark, None), None]
self._yaml_add_eol_comment(ct, key=key)
@property
def lc(self):
# type: () -> Any
if not hasattr(self, LineCol.attrib):
setattr(self, LineCol.attrib, LineCol())
return getattr(self, LineCol.attrib)
def _yaml_set_line_col(self, line, col):
# type: (Any, Any) -> None
self.lc.line = line
self.lc.col = col
def _yaml_set_kv_line_col(self, key, data):
# type: (Any, Any) -> None
self.lc.add_kv_line_col(key, data)
def _yaml_set_idx_line_col(self, key, data):
# type: (Any, Any) -> None
self.lc.add_idx_line_col(key, data)
@property
def anchor(self):
# type: () -> Any
if not hasattr(self, Anchor.attrib):
setattr(self, Anchor.attrib, Anchor())
return getattr(self, Anchor.attrib)
def yaml_anchor(self):
# type: () -> Any
if not hasattr(self, Anchor.attrib):
return None
return self.anchor
def yaml_set_anchor(self, value, always_dump=False):
# type: (Any, bool) -> None
self.anchor.value = value
self.anchor.always_dump = always_dump
@property
def tag(self):
# type: () -> Any
if not hasattr(self, Tag.attrib):
setattr(self, Tag.attrib, Tag())
return getattr(self, Tag.attrib)
def yaml_set_tag(self, value):
# type: (Any) -> None
self.tag.value = value
def copy_attributes(self, t, memo=None):
# type: (Any, Any) -> None
# fmt: off
for a in [Comment.attrib, Format.attrib, LineCol.attrib, Anchor.attrib,
Tag.attrib, merge_attrib]:
if hasattr(self, a):
if memo is not None:
setattr(t, a, copy.deepcopy(getattr(self, a, memo)))
else:
setattr(t, a, getattr(self, a))
# fmt: on
def _yaml_add_eol_comment(self, comment, key):
# type: (Any, Any) -> None
raise NotImplementedError
def _yaml_get_pre_comment(self):
# type: () -> Any
raise NotImplementedError
def _yaml_get_column(self, key):
# type: (Any) -> Any
raise NotImplementedError
class CommentedSeq(MutableSliceableSequence, list, CommentedBase): # type: ignore
__slots__ = (Comment.attrib, "_lst")
def __init__(self, *args, **kw):
# type: (Any, Any) -> None
list.__init__(self, *args, **kw)
def __getsingleitem__(self, idx):
# type: (Any) -> Any
return list.__getitem__(self, idx)
def __setsingleitem__(self, idx, value):
# type: (Any, Any) -> None
# try to preserve the scalarstring type if setting an existing key to a new value
if idx < len(self):
if (
isinstance(value, string_types)
and not isinstance(value, ScalarString)
and isinstance(self[idx], ScalarString)
):
value = type(self[idx])(value)
list.__setitem__(self, idx, value)
def __delsingleitem__(self, idx=None):
# type: (Any) -> Any
list.__delitem__(self, idx)
self.ca.items.pop(idx, None) # might not be there -> default value
for list_index in sorted(self.ca.items):
if list_index < idx:
continue
self.ca.items[list_index - 1] = self.ca.items.pop(list_index)
def __len__(self):
# type: () -> int
return list.__len__(self)
def insert(self, idx, val):
# type: (Any, Any) -> None
"""the comments after the insertion have to move forward"""
list.insert(self, idx, val)
for list_index in sorted(self.ca.items, reverse=True):
if list_index < idx:
break
self.ca.items[list_index + 1] = self.ca.items.pop(list_index)
def extend(self, val):
# type: (Any) -> None
list.extend(self, val)
def __eq__(self, other):
# type: (Any) -> bool
return list.__eq__(self, other)
def _yaml_add_comment(self, comment, key=NoComment):
# type: (Any, Optional[Any]) -> None
if key is not NoComment:
self.yaml_key_comment_extend(key, comment)
else:
self.ca.comment = comment
def _yaml_add_eol_comment(self, comment, key):
# type: (Any, Any) -> None
self._yaml_add_comment(comment, key=key)
def _yaml_get_columnX(self, key):
# type: (Any) -> Any
return self.ca.items[key][0].start_mark.column
def _yaml_get_column(self, key):
# type: (Any) -> Any
column = None
sel_idx = None
pre, post = key - 1, key + 1
if pre in self.ca.items:
sel_idx = pre
elif post in self.ca.items:
sel_idx = post
else:
# self.ca.items is not ordered
for row_idx, _k1 in enumerate(self):
if row_idx >= key:
break
if row_idx not in self.ca.items:
continue
sel_idx = row_idx
if sel_idx is not None:
column = self._yaml_get_columnX(sel_idx)
return column
def _yaml_get_pre_comment(self):
# type: () -> Any
pre_comments = [] # type: List[Any]
if self.ca.comment is None:
self.ca.comment = [None, pre_comments]
else:
self.ca.comment[1] = pre_comments
return pre_comments
def __deepcopy__(self, memo):
# type: (Any) -> Any
res = self.__class__()
memo[id(self)] = res
for k in self:
res.append(copy.deepcopy(k, memo))
self.copy_attributes(res, memo=memo)
return res
def __add__(self, other):
# type: (Any) -> Any
return list.__add__(self, other)
def sort(self, key=None, reverse=False): # type: ignore
# type: (Any, bool) -> None
if key is None:
tmp_lst = sorted(zip(self, range(len(self))), reverse=reverse)
list.__init__(self, [x[0] for x in tmp_lst])
else:
tmp_lst = sorted(
zip(map(key, list.__iter__(self)), range(len(self))), reverse=reverse
)
list.__init__(self, [list.__getitem__(self, x[1]) for x in tmp_lst])
itm = self.ca.items
self.ca._items = {}
for idx, x in enumerate(tmp_lst):
old_index = x[1]
if old_index in itm:
self.ca.items[idx] = itm[old_index]
def __repr__(self):
# type: () -> Any
return list.__repr__(self)
class CommentedKeySeq(tuple, CommentedBase): # type: ignore
"""This primarily exists to be able to roundtrip keys that are sequences"""
def _yaml_add_comment(self, comment, key=NoComment):
# type: (Any, Optional[Any]) -> None
if key is not NoComment:
self.yaml_key_comment_extend(key, comment)
else:
self.ca.comment = comment
def _yaml_add_eol_comment(self, comment, key):
# type: (Any, Any) -> None
self._yaml_add_comment(comment, key=key)
def _yaml_get_columnX(self, key):
# type: (Any) -> Any
return self.ca.items[key][0].start_mark.column
def _yaml_get_column(self, key):
# type: (Any) -> Any
column = None
sel_idx = None
pre, post = key - 1, key + 1
if pre in self.ca.items:
sel_idx = pre
elif post in self.ca.items:
sel_idx = post
else:
# self.ca.items is not ordered
for row_idx, _k1 in enumerate(self):
if row_idx >= key:
break
if row_idx not in self.ca.items:
continue
sel_idx = row_idx
if sel_idx is not None:
column = self._yaml_get_columnX(sel_idx)
return column
def _yaml_get_pre_comment(self):
# type: () -> Any
pre_comments = [] # type: List[Any]
if self.ca.comment is None:
self.ca.comment = [None, pre_comments]
else:
self.ca.comment[1] = pre_comments
return pre_comments
class CommentedMapView(Sized):
__slots__ = ("_mapping",)
def __init__(self, mapping):
# type: (Any) -> None
self._mapping = mapping
def __len__(self):
# type: () -> int
count = len(self._mapping)
return count
class CommentedMapKeysView(CommentedMapView, Set): # type: ignore
__slots__ = ()
@classmethod
def _from_iterable(self, it):
# type: (Any) -> Any
return set(it)
def __contains__(self, key):
# type: (Any) -> Any
return key in self._mapping
def __iter__(self):
# type: () -> Any # yield from self._mapping # not in py27, pypy
# for x in self._mapping._keys():
for x in self._mapping:
yield x
class CommentedMapItemsView(CommentedMapView, Set): # type: ignore
__slots__ = ()
@classmethod
def _from_iterable(self, it):
# type: (Any) -> Any
return set(it)
def __contains__(self, item):
# type: (Any) -> Any
key, value = item
try:
v = self._mapping[key]
except KeyError:
return False
else:
return v == value
def __iter__(self):
# type: () -> Any
for key in self._mapping._keys():
yield (key, self._mapping[key])
class CommentedMapValuesView(CommentedMapView):
__slots__ = ()
def __contains__(self, value):
# type: (Any) -> Any
for key in self._mapping:
if value == self._mapping[key]:
return True
return False
def __iter__(self):
# type: () -> Any
for key in self._mapping._keys():
yield self._mapping[key]
class CommentedMap(ordereddict, CommentedBase): # type: ignore
__slots__ = (Comment.attrib, "_ok", "_ref")
def __init__(self, *args, **kw):
# type: (Any, Any) -> None
self._ok = set() # type: MutableSet[Any] # own keys
self._ref = [] # type: List[CommentedMap]
ordereddict.__init__(self, *args, **kw)
def _yaml_add_comment(self, comment, key=NoComment, value=NoComment):
# type: (Any, Optional[Any], Optional[Any]) -> None
"""values is set to key to indicate a value attachment of comment"""
if key is not NoComment:
self.yaml_key_comment_extend(key, comment)
return
if value is not NoComment:
self.yaml_value_comment_extend(value, comment)
else:
self.ca.comment = comment
def _yaml_add_eol_comment(self, comment, key):
# type: (Any, Any) -> None
"""add on the value line, with value specified by the key"""
self._yaml_add_comment(comment, value=key)
def _yaml_get_columnX(self, key):
# type: (Any) -> Any
return self.ca.items[key][2].start_mark.column
def _yaml_get_column(self, key):
# type: (Any) -> Any
column = None
sel_idx = None
pre, post, last = None, None, None
for x in self:
if pre is not None and x != key:
post = x
break
if x == key:
pre = last
last = x
if pre in self.ca.items:
sel_idx = pre
elif post in self.ca.items:
sel_idx = post
else:
# self.ca.items is not ordered
for k1 in self:
if k1 >= key:
break
if k1 not in self.ca.items:
continue
sel_idx = k1
if sel_idx is not None:
column = self._yaml_get_columnX(sel_idx)
return column
def _yaml_get_pre_comment(self):
# type: () -> Any
pre_comments = [] # type: List[Any]
if self.ca.comment is None:
self.ca.comment = [None, pre_comments]
else:
self.ca.comment[1] = pre_comments
return pre_comments
def update(self, vals):
# type: (Any) -> None
try:
ordereddict.update(self, vals)
except TypeError:
# vals is probably a dict-like object that ordereddict.update() cannot handle directly
for x in vals:
self[x] = vals[x]
try:
self._ok.update(vals.keys()) # type: ignore
except AttributeError:
# assume a list/tuple of two element lists/tuples
for x in vals:
self._ok.add(x[0])
def insert(self, pos, key, value, comment=None):
# type: (Any, Any, Any, Optional[Any]) -> None
"""insert key value into given position
attach comment if provided
"""
ordereddict.insert(self, pos, key, value)
self._ok.add(key)
if comment is not None:
self.yaml_add_eol_comment(comment, key=key)
def mlget(self, key, default=None, list_ok=False):
# type: (Any, Any, Any) -> Any
"""multi-level get that expects dicts within dicts"""
if not isinstance(key, list):
return self.get(key, default)
# assume that the key is a list of recursively accessible dicts
def get_one_level(key_list, level, d):
# type: (Any, Any, Any) -> Any
if not list_ok:
assert isinstance(d, dict)
if level >= len(key_list):
if level > len(key_list):
raise IndexError
return d[key_list[level - 1]]
return get_one_level(key_list, level + 1, d[key_list[level - 1]])
try:
return get_one_level(key, 1, self)
except KeyError:
return default
except (TypeError, IndexError):
if not list_ok:
raise
return default
def __getitem__(self, key):
# type: (Any) -> Any
try:
return ordereddict.__getitem__(self, key)
except KeyError:
for merged in getattr(self, merge_attrib, []):
if key in merged[1]:
return merged[1][key]
raise
def __setitem__(self, key, value):
# type: (Any, Any) -> None
# try to preserve the scalarstring type if setting an existing key to a new value
if key in self:
if (
isinstance(value, string_types)
and not isinstance(value, ScalarString)
and isinstance(self[key], ScalarString)
):
value = type(self[key])(value)
ordereddict.__setitem__(self, key, value)
self._ok.add(key)
def _unmerged_contains(self, key):
# type: (Any) -> Any
if key in self._ok:
return True
return None
def __contains__(self, key):
# type: (Any) -> bool
return bool(ordereddict.__contains__(self, key))
def get(self, key, default=None):
# type: (Any, Any) -> Any
try:
return self.__getitem__(key)
except: # NOQA
return default
def __repr__(self):
# type: () -> Any
return ordereddict.__repr__(self).replace("CommentedMap", "ordereddict")
def non_merged_items(self):
# type: () -> Any
for x in ordereddict.__iter__(self):
if x in self._ok:
yield x, ordereddict.__getitem__(self, x)
def __delitem__(self, key):
# type: (Any) -> None
# for merged in getattr(self, merge_attrib, []):
# if key in merged[1]:
# value = merged[1][key]
# break
# else:
# # not found in merged in stuff
# ordereddict.__delitem__(self, key)
# for referer in self._ref:
# referer.update_key_value(key)
# return
#
# ordereddict.__setitem__(self, key, value) # merge might have different value
# self._ok.discard(key)
self._ok.discard(key)
ordereddict.__delitem__(self, key)
for referer in self._ref:
referer.update_key_value(key)
def __iter__(self):
# type: () -> Any
for x in ordereddict.__iter__(self):
yield x
def _keys(self):
# type: () -> Any
for x in ordereddict.__iter__(self):
yield x
def __len__(self):
# type: () -> int
return int(ordereddict.__len__(self))
def __eq__(self, other):
# type: (Any) -> bool
return bool(dict(self) == other)
if PY2:
def keys(self):
# type: () -> Any
return list(self._keys())
def iterkeys(self):
# type: () -> Any
return self._keys()
def viewkeys(self):
# type: () -> Any
return CommentedMapKeysView(self)
else:
def keys(self):
# type: () -> Any
return CommentedMapKeysView(self)
if PY2:
def _values(self):
# type: () -> Any
for x in ordereddict.__iter__(self):
yield ordereddict.__getitem__(self, x)
def values(self):
# type: () -> Any
return list(self._values())
def itervalues(self):
# type: () -> Any
return self._values()
def viewvalues(self):
# type: () -> Any
return CommentedMapValuesView(self)
else:
def values(self):
# type: () -> Any
return CommentedMapValuesView(self)
def _items(self):
# type: () -> Any
for x in ordereddict.__iter__(self):
yield x, ordereddict.__getitem__(self, x)
if PY2:
def items(self):
# type: () -> Any
return list(self._items())
def iteritems(self):
# type: () -> Any
return self._items()
def viewitems(self):
# type: () -> Any
return CommentedMapItemsView(self)
else:
def items(self):
# type: () -> Any
return CommentedMapItemsView(self)
@property
def merge(self):
# type: () -> Any
if not hasattr(self, merge_attrib):
setattr(self, merge_attrib, [])
return getattr(self, merge_attrib)
def copy(self):
# type: () -> Any
x = type(self)() # update doesn't work
for k, v in self._items():
x[k] = v
self.copy_attributes(x)
return x
def add_referent(self, cm):
# type: (Any) -> None
if cm not in self._ref:
self._ref.append(cm)
def add_yaml_merge(self, value):
# type: (Any) -> None
for v in value:
v[1].add_referent(self)
for k, v in v[1].items():
if ordereddict.__contains__(self, k):
continue
ordereddict.__setitem__(self, k, v)
self.merge.extend(value)
def update_key_value(self, key):
# type: (Any) -> None
if key in self._ok:
return
for v in self.merge:
if key in v[1]:
ordereddict.__setitem__(self, key, v[1][key])
return
ordereddict.__delitem__(self, key)
def __deepcopy__(self, memo):
# type: (Any) -> Any
res = self.__class__()
memo[id(self)] = res
for k in self:
res[k] = copy.deepcopy(self[k], memo)
self.copy_attributes(res, memo=memo)
return res
# based on brownie mappings
@classmethod # type: ignore
def raise_immutable(cls, *args, **kwargs):
# type: (Any, *Any, **Any) -> None
raise TypeError("{} objects are immutable".format(cls.__name__))
class CommentedKeyMap(CommentedBase, Mapping): # type: ignore
__slots__ = Comment.attrib, "_od"
"""This primarily exists to be able to roundtrip keys that are mappings"""
def __init__(self, *args, **kw):
# type: (Any, Any) -> None
if hasattr(self, "_od"):
raise_immutable(self)
try:
self._od = ordereddict(*args, **kw)
except TypeError:
if PY2:
self._od = ordereddict(args[0].items())
else:
raise
__delitem__ = (
__setitem__
) = clear = pop = popitem = setdefault = update = raise_immutable
# need to implement __getitem__, __iter__ and __len__
def __getitem__(self, index):
# type: (Any) -> Any
return self._od[index]
def __iter__(self):
# type: () -> Iterator[Any]
for x in self._od.__iter__():
yield x
def __len__(self):
# type: () -> int
return len(self._od)
def __hash__(self):
# type: () -> Any
return hash(tuple(self.items()))
def __repr__(self):
# type: () -> Any
if not hasattr(self, merge_attrib):
return self._od.__repr__()
return "ordereddict(" + repr(list(self._od.items())) + ")"
@classmethod
def fromkeys(cls, keys, v=None):
# type: (Any, Any) -> Any
return cls(dict.fromkeys(keys, v))
def _yaml_add_comment(self, comment, key=NoComment):
# type: (Any, Optional[Any]) -> None
if key is not NoComment:
self.yaml_key_comment_extend(key, comment)
else:
self.ca.comment = comment
def _yaml_add_eol_comment(self, comment, key):
# type: (Any, Any) -> None
self._yaml_add_comment(comment, key=key)
def _yaml_get_columnX(self, key):
# type: (Any) -> Any
return self.ca.items[key][0].start_mark.column
def _yaml_get_column(self, key):
# type: (Any) -> Any
column = None
sel_idx = None
pre, post = key - 1, key + 1
if pre in self.ca.items:
sel_idx = pre
elif post in self.ca.items:
sel_idx = post
else:
# self.ca.items is not ordered
for row_idx, _k1 in enumerate(self):
if row_idx >= key:
break
if row_idx not in self.ca.items:
continue
sel_idx = row_idx
if sel_idx is not None:
column = self._yaml_get_columnX(sel_idx)
return column
def _yaml_get_pre_comment(self):
# type: () -> Any
pre_comments = [] # type: List[Any]
if self.ca.comment is None:
self.ca.comment = [None, pre_comments]
else:
self.ca.comment[1] = pre_comments
return pre_comments
class CommentedOrderedMap(CommentedMap):
__slots__ = (Comment.attrib,)
class CommentedSet(MutableSet, CommentedBase): # type: ignore # NOQA
__slots__ = Comment.attrib, "odict"
def __init__(self, values=None):
# type: (Any) -> None
self.odict = ordereddict()
MutableSet.__init__(self)
if values is not None:
self |= values # type: ignore
def _yaml_add_comment(self, comment, key=NoComment, value=NoComment):
# type: (Any, Optional[Any], Optional[Any]) -> None
"""values is set to key to indicate a value attachment of comment"""
if key is not NoComment:
self.yaml_key_comment_extend(key, comment)
return
if value is not NoComment:
self.yaml_value_comment_extend(value, comment)
else:
self.ca.comment = comment
def _yaml_add_eol_comment(self, comment, key):
# type: (Any, Any) -> None
"""add on the value line, with value specified by the key"""
self._yaml_add_comment(comment, value=key)
def add(self, value):
# type: (Any) -> None
"""Add an element."""
self.odict[value] = None
def discard(self, value):
# type: (Any) -> None
"""Remove an element. Do not raise an exception if absent."""
del self.odict[value]
def __contains__(self, x):
# type: (Any) -> Any
return x in self.odict
def __iter__(self):
# type: () -> Any
for x in self.odict:
yield x
def __len__(self):
# type: () -> int
return len(self.odict)
def __repr__(self):
# type: () -> str
return "set({0!r})".format(self.odict.keys())
class TaggedScalar(CommentedBase):
# the value and style attributes are set during roundtrip construction
def __init__(self, value=None, style=None, tag=None):
# type: (Any, Any, Any) -> None
self.value = value
self.style = style
if tag is not None:
self.yaml_set_tag(tag)
def __str__(self):
# type: () -> Any
return self.value
def dump_comments(d, name="", sep=".", out=sys.stdout):
# type: (Any, str, str, Any) -> None
"""
recursively dump comments, all but the toplevel preceded by the path
in dotted form x.0.a
"""
if isinstance(d, dict) and hasattr(d, "ca"):
if name:
out.write("{}\n".format(name))
out.write("{}\n".format(d.ca)) # type: ignore
for k in d:
dump_comments(d[k], name=(name + sep + k) if name else k, sep=sep, out=out)
elif isinstance(d, list) and hasattr(d, "ca"):
if name:
out.write("{}\n".format(name))
out.write("{}\n".format(d.ca)) # type: ignore
for idx, k in enumerate(d):
dump_comments(
k, name=(name + sep + str(idx)) if name else str(idx), sep=sep, out=out
)
srsly-release-v2.5.1/srsly/ruamel_yaml/compat.py 0000775 0000000 0000000 00000020622 14742310675 0022035 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
# partially from package six by Benjamin Peterson
import sys
import os
import types
import traceback
from abc import abstractmethod
from collections import OrderedDict # type: ignore
# fmt: off
if False: # MYPY
from typing import Any, Dict, Optional, List, Union, BinaryIO, IO, Text, Tuple # NOQA
from typing import Optional # NOQA
# fmt: on
_DEFAULT_YAML_VERSION = (1, 2)
class ordereddict(OrderedDict): # type: ignore
if not hasattr(OrderedDict, "insert"):
def insert(self, pos, key, value):
# type: (int, Any, Any) -> None
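# OrderedDict has no insert(); emulate it by snapshotting the current items,
# clearing self and re-adding them, slotting the new key/value pair in at the
# requested position.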
if pos >= len(self):
self[key] = value
return
od = ordereddict()
od.update(self)
for k in od:
del self[k]
for index, old_key in enumerate(od):
if pos == index:
self[key] = value
self[old_key] = od[old_key]
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
if PY3:
def utf8(s):
# type: (str) -> str
return s
def to_str(s):
# type: (str) -> str
return s
def to_unicode(s):
# type: (str) -> str
return s
else:
if False:
unicode = str
def utf8(s):
# type: (unicode) -> str
return s.encode("utf-8")
def to_str(s):
# type: (str) -> str
return str(s)
def to_unicode(s):
# type: (str) -> unicode
return unicode(s) # NOQA
if PY3:
string_types = str
integer_types = int
class_types = type
text_type = str
binary_type = bytes
MAXSIZE = sys.maxsize
unichr = chr
import io
StringIO = io.StringIO
BytesIO = io.BytesIO
# have unlimited precision
no_limit_int = int
from collections.abc import (
Hashable,
MutableSequence,
MutableMapping,
Mapping,
) # NOQA
else:
string_types = basestring # NOQA
integer_types = (int, long) # NOQA
class_types = (type, types.ClassType)
text_type = unicode # NOQA
binary_type = str
# to allow importing
unichr = unichr
from StringIO import StringIO as _StringIO
StringIO = _StringIO
import cStringIO
BytesIO = cStringIO.StringIO
# have unlimited precision
no_limit_int = long # NOQA not available on Python 3
from collections import Hashable, MutableSequence, MutableMapping, Mapping # NOQA
if False: # MYPY
# StreamType = Union[BinaryIO, IO[str], IO[unicode], StringIO]
# StreamType = Union[BinaryIO, IO[str], StringIO] # type: ignore
StreamType = Any
StreamTextType = StreamType # Union[Text, StreamType]
VersionType = Union[List[int], str, Tuple[int, int]]
if PY3:
builtins_module = "builtins"
else:
builtins_module = "__builtin__"
UNICODE_SIZE = 4 if sys.maxunicode > 65535 else 2
def with_metaclass(meta, *bases):
# type: (Any, Any) -> Any
"""Create a base class with a metaclass."""
return meta("NewBase", bases, {})
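# Hypothetical usage (not from the original source; SomeMeta and Base are
# placeholder names):
#   class Foo(with_metaclass(SomeMeta, Base)):
#       pass
# gives Foo the metaclass SomeMeta on both Python 2 and Python 3 without
# version-specific class syntax.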
DBG_TOKEN = 1
DBG_EVENT = 2
DBG_NODE = 4
_debug = None # type: Optional[int]
if "RUAMELDEBUG" in os.environ:
_debugx = os.environ.get("RUAMELDEBUG")
if _debugx is None:
_debug = 0
else:
_debug = int(_debugx)
if bool(_debug):
class ObjectCounter(object):
def __init__(self):
# type: () -> None
self.map = {} # type: Dict[Any, Any]
def __call__(self, k):
# type: (Any) -> None
self.map[k] = self.map.get(k, 0) + 1
def dump(self):
# type: () -> None
for k in sorted(self.map):
sys.stdout.write("{} -> {}".format(k, self.map[k]))
object_counter = ObjectCounter()
# used from yaml util when testing
def dbg(val=None):
# type: (Any) -> Any
global _debug
if _debug is None:
# set to true or false
_debugx = os.environ.get("YAMLDEBUG")
if _debugx is None:
_debug = 0
else:
_debug = int(_debugx)
if val is None:
return _debug
return _debug & val
class Nprint(object):
def __init__(self, file_name=None):
# type: (Any) -> None
self._max_print = None # type: Any
self._count = None # type: Any
self._file_name = file_name
def __call__(self, *args, **kw):
# type: (Any, Any) -> None
if not bool(_debug):
return
out = sys.stdout if self._file_name is None else open(self._file_name, "a")
dbgprint = print # to fool checking for print statements by dv utility
kw1 = kw.copy()
kw1["file"] = out
dbgprint(*args, **kw1)
out.flush()
if self._max_print is not None:
if self._count is None:
self._count = self._max_print
self._count -= 1
if self._count == 0:
dbgprint("forced exit\n")
traceback.print_stack()
out.flush()
sys.exit(0)
if self._file_name:
out.close()
def set_max_print(self, i):
# type: (int) -> None
self._max_print = i
self._count = None
nprint = Nprint()
nprintf = Nprint("/var/tmp/srsly.ruamel_yaml.log")
# char checkers following production rules
def check_namespace_char(ch):
# type: (Any) -> bool
if u"\x21" <= ch <= u"\x7E": # ! to ~
return True
if u"\xA0" <= ch <= u"\uD7FF":
return True
if (u"\uE000" <= ch <= u"\uFFFD") and ch != u"\uFEFF": # excl. byte order mark
return True
if u"\U00010000" <= ch <= u"\U0010FFFF":
return True
return False
def check_anchorname_char(ch):
# type: (Any) -> bool
if ch in u",[]{}":
return False
return check_namespace_char(ch)
def version_tnf(t1, t2=None):
# type: (Any, Any) -> Any
"""
return True if srsly.ruamel_yaml version_info < t1; None if t2 is specified and version_info < t2; else False
"""
from srsly.ruamel_yaml import version_info # NOQA
if version_info < t1:
return True
if t2 is not None and version_info < t2:
return None
return False
class MutableSliceableSequence(MutableSequence): # type: ignore
__slots__ = ()
def __getitem__(self, index):
# type: (Any) -> Any
if not isinstance(index, slice):
return self.__getsingleitem__(index)
return type(self)(
[self[i] for i in range(*index.indices(len(self)))]
) # type: ignore
def __setitem__(self, index, value):
# type: (Any, Any) -> None
if not isinstance(index, slice):
return self.__setsingleitem__(index, value)
assert iter(value)
# nprint(index.start, index.stop, index.step, index.indices(len(self)))
if index.step is None:
del self[index.start : index.stop]
for elem in reversed(value):
self.insert(0 if index.start is None else index.start, elem)
else:
range_parms = index.indices(len(self))
nr_assigned_items = (range_parms[1] - range_parms[0] - 1) // range_parms[
2
] + 1
# need to test before changing, in case TypeError is caught
if nr_assigned_items < len(value):
raise TypeError(
"too many elements in value {} < {}".format(
nr_assigned_items, len(value)
)
)
elif nr_assigned_items > len(value):
raise TypeError(
"not enough elements in value {} > {}".format(
nr_assigned_items, len(value)
)
)
for idx, i in enumerate(range(*range_parms)):
self[i] = value[idx]
def __delitem__(self, index):
# type: (Any) -> None
if not isinstance(index, slice):
return self.__delsingleitem__(index)
# nprint(index.start, index.stop, index.step, index.indices(len(self)))
for i in reversed(range(*index.indices(len(self)))):
del self[i]
@abstractmethod
def __getsingleitem__(self, index):
# type: (Any) -> Any
raise IndexError
@abstractmethod
def __setsingleitem__(self, index, value):
# type: (Any, Any) -> None
raise IndexError
@abstractmethod
def __delsingleitem__(self, index):
# type: (Any) -> None
raise IndexError
srsly-release-v2.5.1/srsly/ruamel_yaml/composer.py 0000775 0000000 0000000 00000020243 14742310675 0022400 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import, print_function
import warnings
from .error import MarkedYAMLError, ReusedAnchorWarning
from .compat import utf8, nprint, nprintf # NOQA
from .events import (
StreamStartEvent,
StreamEndEvent,
MappingStartEvent,
MappingEndEvent,
SequenceStartEvent,
SequenceEndEvent,
AliasEvent,
ScalarEvent,
)
from .nodes import MappingNode, ScalarNode, SequenceNode
if False: # MYPY
from typing import Any, Dict, Optional, List # NOQA
__all__ = ["Composer", "ComposerError"]
class ComposerError(MarkedYAMLError):
pass
class Composer(object):
def __init__(self, loader=None):
# type: (Any) -> None
self.loader = loader
if self.loader is not None and getattr(self.loader, "_composer", None) is None:
self.loader._composer = self
self.anchors = {} # type: Dict[Any, Any]
@property
def parser(self):
# type: () -> Any
if hasattr(self.loader, "typ"):
self.loader.parser
return self.loader._parser
@property
def resolver(self):
# type: () -> Any
# assert self.loader._resolver is not None
if hasattr(self.loader, "typ"):
self.loader.resolver
return self.loader._resolver
def check_node(self):
# type: () -> Any
# Drop the STREAM-START event.
if self.parser.check_event(StreamStartEvent):
self.parser.get_event()
# Return True if there are more documents available.
return not self.parser.check_event(StreamEndEvent)
def get_node(self):
# type: () -> Any
# Get the root node of the next document.
if not self.parser.check_event(StreamEndEvent):
return self.compose_document()
def get_single_node(self):
# type: () -> Any
# Drop the STREAM-START event.
self.parser.get_event()
# Compose a document if the stream is not empty.
document = None # type: Any
if not self.parser.check_event(StreamEndEvent):
document = self.compose_document()
# Ensure that the stream contains no more documents.
if not self.parser.check_event(StreamEndEvent):
event = self.parser.get_event()
raise ComposerError(
"expected a single document in the stream",
document.start_mark,
"but found another document",
event.start_mark,
)
# Drop the STREAM-END event.
self.parser.get_event()
return document
def compose_document(self):
# type: (Any) -> Any
# Drop the DOCUMENT-START event.
self.parser.get_event()
# Compose the root node.
node = self.compose_node(None, None)
# Drop the DOCUMENT-END event.
self.parser.get_event()
self.anchors = {}
return node
def compose_node(self, parent, index):
# type: (Any, Any) -> Any
if self.parser.check_event(AliasEvent):
event = self.parser.get_event()
alias = event.anchor
if alias not in self.anchors:
raise ComposerError(
None,
None,
"found undefined alias %r" % utf8(alias),
event.start_mark,
)
return self.anchors[alias]
event = self.parser.peek_event()
anchor = event.anchor
if anchor is not None: # have an anchor
if anchor in self.anchors:
# raise ComposerError(
# "found duplicate anchor %r; first occurrence"
# % utf8(anchor), self.anchors[anchor].start_mark,
# "second occurrence", event.start_mark)
ws = (
"\nfound duplicate anchor {!r}\nfirst occurrence {}\nsecond occurrence "
"{}".format(
(anchor), self.anchors[anchor].start_mark, event.start_mark
)
)
warnings.warn(ws, ReusedAnchorWarning)
self.resolver.descend_resolver(parent, index)
if self.parser.check_event(ScalarEvent):
node = self.compose_scalar_node(anchor)
elif self.parser.check_event(SequenceStartEvent):
node = self.compose_sequence_node(anchor)
elif self.parser.check_event(MappingStartEvent):
node = self.compose_mapping_node(anchor)
self.resolver.ascend_resolver()
return node
def compose_scalar_node(self, anchor):
# type: (Any) -> Any
event = self.parser.get_event()
tag = event.tag
if tag is None or tag == u"!":
tag = self.resolver.resolve(ScalarNode, event.value, event.implicit)
node = ScalarNode(
tag,
event.value,
event.start_mark,
event.end_mark,
style=event.style,
comment=event.comment,
anchor=anchor,
)
if anchor is not None:
self.anchors[anchor] = node
return node
def compose_sequence_node(self, anchor):
# type: (Any) -> Any
start_event = self.parser.get_event()
tag = start_event.tag
if tag is None or tag == u"!":
tag = self.resolver.resolve(SequenceNode, None, start_event.implicit)
node = SequenceNode(
tag,
[],
start_event.start_mark,
None,
flow_style=start_event.flow_style,
comment=start_event.comment,
anchor=anchor,
)
if anchor is not None:
self.anchors[anchor] = node
index = 0
while not self.parser.check_event(SequenceEndEvent):
node.value.append(self.compose_node(node, index))
index += 1
end_event = self.parser.get_event()
if node.flow_style is True and end_event.comment is not None:
if node.comment is not None:
nprint(
"Warning: unexpected end_event commment in sequence "
"node {}".format(node.flow_style)
)
node.comment = end_event.comment
node.end_mark = end_event.end_mark
self.check_end_doc_comment(end_event, node)
return node
def compose_mapping_node(self, anchor):
# type: (Any) -> Any
start_event = self.parser.get_event()
tag = start_event.tag
if tag is None or tag == u"!":
tag = self.resolver.resolve(MappingNode, None, start_event.implicit)
node = MappingNode(
tag,
[],
start_event.start_mark,
None,
flow_style=start_event.flow_style,
comment=start_event.comment,
anchor=anchor,
)
if anchor is not None:
self.anchors[anchor] = node
while not self.parser.check_event(MappingEndEvent):
# key_event = self.parser.peek_event()
item_key = self.compose_node(node, None)
# if item_key in node.value:
# raise ComposerError("while composing a mapping",
# start_event.start_mark,
# "found duplicate key", key_event.start_mark)
item_value = self.compose_node(node, item_key)
# node.value[item_key] = item_value
node.value.append((item_key, item_value))
end_event = self.parser.get_event()
if node.flow_style is True and end_event.comment is not None:
node.comment = end_event.comment
node.end_mark = end_event.end_mark
self.check_end_doc_comment(end_event, node)
return node
def check_end_doc_comment(self, end_event, node):
# type: (Any, Any) -> None
if end_event.comment and end_event.comment[1]:
# pre comments on an end_event, no following to move to
if node.comment is None:
node.comment = [None, None]
assert not isinstance(node, ScalarEvent)
# this is a post comment on a mapping node, add as third element
# in the list
node.comment.append(end_event.comment[1])
end_event.comment[1] = None
srsly-release-v2.5.1/srsly/ruamel_yaml/configobjwalker.py 0000775 0000000 0000000 00000000537 14742310675 0023723 0 ustar 00root root 0000000 0000000 # coding: utf-8
import warnings
from .util import configobj_walker as new_configobj_walker
if False: # MYPY
from typing import Any # NOQA
def configobj_walker(cfg):
# type: (Any) -> Any
warnings.warn(
"configobj_walker has moved to srsly.ruamel_yaml.util, please update your code"
)
return new_configobj_walker(cfg)
srsly-release-v2.5.1/srsly/ruamel_yaml/constructor.py 0000775 0000000 0000000 00000176442 14742310675 0023153 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division
import datetime
import base64
import binascii
import re
import sys
import types
import warnings
# fmt: off
from .error import (MarkedYAMLError, MarkedYAMLFutureWarning,
MantissaNoDotYAML1_1Warning)
from .nodes import * # NOQA
from .nodes import (SequenceNode, MappingNode, ScalarNode)
from .compat import (utf8, builtins_module, to_str, PY2, PY3, # NOQA
text_type, nprint, nprintf, version_tnf)
from .compat import ordereddict, Hashable, MutableSequence # type: ignore
from .compat import MutableMapping # type: ignore
from .comments import * # NOQA
from .comments import (CommentedMap, CommentedOrderedMap, CommentedSet,
CommentedKeySeq, CommentedSeq, TaggedScalar,
CommentedKeyMap)
from .scalarstring import (SingleQuotedScalarString, DoubleQuotedScalarString,
LiteralScalarString, FoldedScalarString,
PlainScalarString, ScalarString,)
from .scalarint import ScalarInt, BinaryInt, OctalInt, HexInt, HexCapsInt
from .scalarfloat import ScalarFloat
from .scalarbool import ScalarBoolean
from .timestamp import TimeStamp
from .util import RegExp
if False: # MYPY
from typing import Any, Dict, List, Set, Generator, Union, Optional # NOQA
__all__ = ['BaseConstructor', 'SafeConstructor', 'Constructor',
'ConstructorError', 'RoundTripConstructor']
# fmt: on
class ConstructorError(MarkedYAMLError):
pass
class DuplicateKeyFutureWarning(MarkedYAMLFutureWarning):
pass
class DuplicateKeyError(MarkedYAMLFutureWarning):
pass
class BaseConstructor(object):
yaml_constructors = {} # type: Dict[Any, Any]
yaml_multi_constructors = {} # type: Dict[Any, Any]
def __init__(self, preserve_quotes=None, loader=None):
# type: (Optional[bool], Any) -> None
self.loader = loader
if (
self.loader is not None
and getattr(self.loader, "_constructor", None) is None
):
self.loader._constructor = self
self.loader = loader
self.yaml_base_dict_type = dict
self.yaml_base_list_type = list
self.constructed_objects = {} # type: Dict[Any, Any]
self.recursive_objects = {} # type: Dict[Any, Any]
self.state_generators = [] # type: List[Any]
self.deep_construct = False
self._preserve_quotes = preserve_quotes
self.allow_duplicate_keys = version_tnf((0, 15, 1), (0, 16))
@property
def composer(self):
# type: () -> Any
if hasattr(self.loader, "typ"):
return self.loader.composer
try:
return self.loader._composer
except AttributeError:
sys.stdout.write("slt {}\n".format(type(self)))
sys.stdout.write("slc {}\n".format(self.loader._composer))
sys.stdout.write("{}\n".format(dir(self)))
raise
@property
def resolver(self):
# type: () -> Any
if hasattr(self.loader, "typ"):
return self.loader.resolver
return self.loader._resolver
def check_data(self):
# type: () -> Any
# Return True if there are more documents available.
return self.composer.check_node()
def get_data(self):
# type: () -> Any
# Construct and return the next document.
if self.composer.check_node():
return self.construct_document(self.composer.get_node())
def get_single_data(self):
# type: () -> Any
# Ensure that the stream contains a single document and construct it.
node = self.composer.get_single_node()
if node is not None:
return self.construct_document(node)
return None
def construct_document(self, node):
# type: (Any) -> Any
data = self.construct_object(node)
while bool(self.state_generators):
state_generators = self.state_generators
self.state_generators = []
for generator in state_generators:
for _dummy in generator:
pass
self.constructed_objects = {}
self.recursive_objects = {}
self.deep_construct = False
return data
def construct_object(self, node, deep=False):
# type: (Any, bool) -> Any
"""deep is True when creating an object/mapping recursively,
in that case want the underlying elements available during construction
"""
if node in self.constructed_objects:
return self.constructed_objects[node]
if deep:
old_deep = self.deep_construct
self.deep_construct = True
if node in self.recursive_objects:
return self.recursive_objects[node]
# raise ConstructorError(
# None, None, 'found unconstructable recursive node', node.start_mark
# )
self.recursive_objects[node] = None
data = self.construct_non_recursive_object(node)
self.constructed_objects[node] = data
del self.recursive_objects[node]
if deep:
self.deep_construct = old_deep
return data
def construct_non_recursive_object(self, node, tag=None):
# type: (Any, Optional[str]) -> Any
constructor = None # type: Any
tag_suffix = None
if tag is None:
tag = node.tag
if tag in self.yaml_constructors:
constructor = self.yaml_constructors[tag]
else:
for tag_prefix in self.yaml_multi_constructors:
if tag.startswith(tag_prefix):
tag_suffix = tag[len(tag_prefix) :]
constructor = self.yaml_multi_constructors[tag_prefix]
break
else:
if None in self.yaml_multi_constructors:
tag_suffix = tag
constructor = self.yaml_multi_constructors[None]
elif None in self.yaml_constructors:
constructor = self.yaml_constructors[None]
elif isinstance(node, ScalarNode):
constructor = self.__class__.construct_scalar
elif isinstance(node, SequenceNode):
constructor = self.__class__.construct_sequence
elif isinstance(node, MappingNode):
constructor = self.__class__.construct_mapping
if tag_suffix is None:
data = constructor(self, node)
else:
data = constructor(self, tag_suffix, node)
if isinstance(data, types.GeneratorType):
generator = data
data = next(generator)
if self.deep_construct:
for _dummy in generator:
pass
else:
self.state_generators.append(generator)
return data
def construct_scalar(self, node):
# type: (Any) -> Any
if not isinstance(node, ScalarNode):
raise ConstructorError(
None,
None,
"expected a scalar node, but found %s" % node.id,
node.start_mark,
)
return node.value
def construct_sequence(self, node, deep=False):
# type: (Any, bool) -> Any
"""deep is True when creating an object/mapping recursively,
in that case want the underlying elements available during construction
"""
if not isinstance(node, SequenceNode):
raise ConstructorError(
None,
None,
"expected a sequence node, but found %s" % node.id,
node.start_mark,
)
return [self.construct_object(child, deep=deep) for child in node.value]
def construct_mapping(self, node, deep=False):
# type: (Any, bool) -> Any
"""deep is True when creating an object/mapping recursively,
in that case want the underlying elements available during construction
"""
if not isinstance(node, MappingNode):
raise ConstructorError(
None,
None,
"expected a mapping node, but found %s" % node.id,
node.start_mark,
)
total_mapping = self.yaml_base_dict_type()
if getattr(node, "merge", None) is not None:
todo = [(node.merge, False), (node.value, False)]
else:
todo = [(node.value, True)]
for values, check in todo:
mapping = self.yaml_base_dict_type() # type: Dict[Any, Any]
for key_node, value_node in values:
# keys can be list -> deep
key = self.construct_object(key_node, deep=True)
# lists are not hashable, but tuples are
if not isinstance(key, Hashable):
if isinstance(key, list):
key = tuple(key)
if PY2:
try:
hash(key)
except TypeError as exc:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unacceptable key (%s)" % exc,
key_node.start_mark,
)
else:
if not isinstance(key, Hashable):
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unhashable key",
key_node.start_mark,
)
value = self.construct_object(value_node, deep=deep)
if check:
if self.check_mapping_key(node, key_node, mapping, key, value):
mapping[key] = value
else:
mapping[key] = value
total_mapping.update(mapping)
return total_mapping
def check_mapping_key(self, node, key_node, mapping, key, value):
# type: (Any, Any, Any, Any, Any) -> bool
"""return True if key is unique"""
if key in mapping:
if not self.allow_duplicate_keys:
mk = mapping.get(key)
if PY2:
if isinstance(key, unicode):
key = key.encode("utf-8")
if isinstance(value, unicode):
value = value.encode("utf-8")
if isinstance(mk, unicode):
mk = mk.encode("utf-8")
args = [
"while constructing a mapping",
node.start_mark,
'found duplicate key "{}" with value "{}" '
'(original value: "{}")'.format(key, value, mk),
key_node.start_mark,
"""
To suppress this check see:
http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys
""",
"""\
Duplicate keys will become an error in future releases, and are errors
by default when using the new API.
""",
]
if self.allow_duplicate_keys is None:
warnings.warn(DuplicateKeyFutureWarning(*args))
else:
raise DuplicateKeyError(*args)
return False
return True
def check_set_key(self, node, key_node, setting, key):
# type: (Any, Any, Any, Any) -> None
if key in setting:
if not self.allow_duplicate_keys:
if PY2:
if isinstance(key, unicode):
key = key.encode("utf-8")
args = [
"while constructing a set",
node.start_mark,
'found duplicate key "{}"'.format(key),
key_node.start_mark,
"""
To suppress this check see:
http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys
""",
"""\
Duplicate keys will become an error in future releases, and are errors
by default when using the new API.
""",
]
if self.allow_duplicate_keys is None:
warnings.warn(DuplicateKeyFutureWarning(*args))
else:
raise DuplicateKeyError(*args)
def construct_pairs(self, node, deep=False):
# type: (Any, bool) -> Any
if not isinstance(node, MappingNode):
raise ConstructorError(
None,
None,
"expected a mapping node, but found %s" % node.id,
node.start_mark,
)
pairs = []
for key_node, value_node in node.value:
key = self.construct_object(key_node, deep=deep)
value = self.construct_object(value_node, deep=deep)
pairs.append((key, value))
return pairs
@classmethod
def add_constructor(cls, tag, constructor):
# type: (Any, Any) -> None
if "yaml_constructors" not in cls.__dict__:
cls.yaml_constructors = cls.yaml_constructors.copy()
cls.yaml_constructors[tag] = constructor
@classmethod
def add_multi_constructor(cls, tag_prefix, multi_constructor):
# type: (Any, Any) -> None
if "yaml_multi_constructors" not in cls.__dict__:
cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy()
cls.yaml_multi_constructors[tag_prefix] = multi_constructor
class SafeConstructor(BaseConstructor):
def construct_scalar(self, node):
# type: (Any) -> Any
if isinstance(node, MappingNode):
for key_node, value_node in node.value:
if key_node.tag == u"tag:yaml.org,2002:value":
return self.construct_scalar(value_node)
return BaseConstructor.construct_scalar(self, node)
def flatten_mapping(self, node):
# type: (Any) -> Any
"""
This implements the merge key feature http://yaml.org/type/merge.html
by inserting keys from the merge dict/list of dicts if not yet
available in this node
"""
merge = [] # type: List[Any]
index = 0
while index < len(node.value):
key_node, value_node = node.value[index]
if key_node.tag == u"tag:yaml.org,2002:merge":
if merge: # double << key
if self.allow_duplicate_keys:
del node.value[index]
index += 1
continue
args = [
"while constructing a mapping",
node.start_mark,
'found duplicate key "{}"'.format(key_node.value),
key_node.start_mark,
"""
To suppress this check see:
http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys
""",
"""\
Duplicate keys will become an error in future releases, and are errors
by default when using the new API.
""",
]
if self.allow_duplicate_keys is None:
warnings.warn(DuplicateKeyFutureWarning(*args))
else:
raise DuplicateKeyError(*args)
del node.value[index]
if isinstance(value_node, MappingNode):
self.flatten_mapping(value_node)
merge.extend(value_node.value)
elif isinstance(value_node, SequenceNode):
submerge = []
for subnode in value_node.value:
if not isinstance(subnode, MappingNode):
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"expected a mapping for merging, but found %s"
% subnode.id,
subnode.start_mark,
)
self.flatten_mapping(subnode)
submerge.append(subnode.value)
submerge.reverse()
for value in submerge:
merge.extend(value)
else:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"expected a mapping or list of mappings for merging, "
"but found %s" % value_node.id,
value_node.start_mark,
)
elif key_node.tag == u"tag:yaml.org,2002:value":
key_node.tag = u"tag:yaml.org,2002:str"
index += 1
else:
index += 1
if bool(merge):
node.merge = (
merge
) # separate merge keys to be able to update without duplicate
node.value = merge + node.value
def construct_mapping(self, node, deep=False):
# type: (Any, bool) -> Any
"""deep is True when creating an object/mapping recursively,
in that case want the underlying elements available during construction
"""
if isinstance(node, MappingNode):
self.flatten_mapping(node)
return BaseConstructor.construct_mapping(self, node, deep=deep)
def construct_yaml_null(self, node):
# type: (Any) -> Any
self.construct_scalar(node)
return None
# YAML 1.2 spec doesn't mention yes/no etc any more, 1.1 does
bool_values = {
u"yes": True,
u"no": False,
u"y": True,
u"n": False,
u"true": True,
u"false": False,
u"on": True,
u"off": False,
}
def construct_yaml_bool(self, node):
# type: (Any) -> bool
value = self.construct_scalar(node)
return self.bool_values[value.lower()]
def construct_yaml_int(self, node):
# type: (Any) -> int
value_s = to_str(self.construct_scalar(node))
value_s = value_s.replace("_", "")
sign = +1
if value_s[0] == "-":
sign = -1
if value_s[0] in "+-":
value_s = value_s[1:]
if value_s == "0":
return 0
elif value_s.startswith("0b"):
return sign * int(value_s[2:], 2)
elif value_s.startswith("0x"):
return sign * int(value_s[2:], 16)
elif value_s.startswith("0o"):
return sign * int(value_s[2:], 8)
elif self.resolver.processing_version == (1, 1) and value_s[0] == "0":
return sign * int(value_s, 8)
elif self.resolver.processing_version == (1, 1) and ":" in value_s:
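# YAML 1.1 sexagesimal (base 60) integers, e.g. "1:20:30" is read as
# 1 * 3600 + 20 * 60 + 30 = 4830.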
digits = [int(part) for part in value_s.split(":")]
digits.reverse()
base = 1
value = 0
for digit in digits:
value += digit * base
base *= 60
return sign * value
else:
return sign * int(value_s)
inf_value = 1e300
while inf_value != inf_value * inf_value:
inf_value *= inf_value
nan_value = -inf_value / inf_value # Trying to make a quiet NaN (like C99).
def construct_yaml_float(self, node):
# type: (Any) -> float
value_so = to_str(self.construct_scalar(node))
value_s = value_so.replace("_", "").lower()
sign = +1
if value_s[0] == "-":
sign = -1
if value_s[0] in "+-":
value_s = value_s[1:]
if value_s == ".inf":
return sign * self.inf_value
elif value_s == ".nan":
return self.nan_value
elif self.resolver.processing_version != (1, 2) and ":" in value_s:
digits = [float(part) for part in value_s.split(":")]
digits.reverse()
base = 1
value = 0.0
for digit in digits:
value += digit * base
base *= 60
return sign * value
else:
if self.resolver.processing_version != (1, 2) and "e" in value_s:
# value_s is lower case independent of input
mantissa, exponent = value_s.split("e")
if "." not in mantissa:
warnings.warn(MantissaNoDotYAML1_1Warning(node, value_so))
return sign * float(value_s)
if PY3:
def construct_yaml_binary(self, node):
# type: (Any) -> Any
try:
value = self.construct_scalar(node).encode("ascii")
except UnicodeEncodeError as exc:
raise ConstructorError(
None,
None,
"failed to convert base64 data into ascii: %s" % exc,
node.start_mark,
)
try:
if hasattr(base64, "decodebytes"):
return base64.decodebytes(value)
else:
return base64.decodestring(value)
except binascii.Error as exc:
raise ConstructorError(
None,
None,
"failed to decode base64 data: %s" % exc,
node.start_mark,
)
else:
def construct_yaml_binary(self, node):
# type: (Any) -> Any
value = self.construct_scalar(node)
try:
return to_str(value).decode("base64")
except (binascii.Error, UnicodeEncodeError) as exc:
raise ConstructorError(
None,
None,
"failed to decode base64 data: %s" % exc,
node.start_mark,
)
timestamp_regexp = RegExp(
u"""^(?P[0-9][0-9][0-9][0-9])
-(?P[0-9][0-9]?)
-(?P[0-9][0-9]?)
(?:((?P[Tt])|[ \\t]+) # explictly not retaining extra spaces
(?P[0-9][0-9]?)
:(?P[0-9][0-9])
:(?P[0-9][0-9])
(?:\\.(?P[0-9]*))?
(?:[ \\t]*(?PZ|(?P[-+])(?P[0-9][0-9]?)
(?::(?P[0-9][0-9]))?))?)?$""",
re.X,
)
def construct_yaml_timestamp(self, node, values=None):
# type: (Any, Any) -> Any
if values is None:
try:
match = self.timestamp_regexp.match(node.value)
except TypeError:
match = None
if match is None:
raise ConstructorError(
None,
None,
'failed to construct timestamp from "{}"'.format(node.value),
node.start_mark,
)
values = match.groupdict()
year = int(values["year"])
month = int(values["month"])
day = int(values["day"])
if not values["hour"]:
return datetime.date(year, month, day)
hour = int(values["hour"])
minute = int(values["minute"])
second = int(values["second"])
fraction = 0
if values["fraction"]:
fraction_s = values["fraction"][:6]
while len(fraction_s) < 6:
fraction_s += "0"
fraction = int(fraction_s)
if len(values["fraction"]) > 6 and int(values["fraction"][6]) > 4:
fraction += 1
delta = None
if values["tz_sign"]:
tz_hour = int(values["tz_hour"])
minutes = values["tz_minute"]
tz_minute = int(minutes) if minutes else 0
delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute)
if values["tz_sign"] == "-":
delta = -delta
# should do something else instead (or hook this up to the preceding if statement
# in reverse)
# if delta is None:
# return datetime.datetime(year, month, day, hour, minute, second, fraction)
# return datetime.datetime(year, month, day, hour, minute, second, fraction,
# datetime.timezone.utc)
# the above is not good enough though, should provide tzinfo. In Python3 that is easily
# doable drop that kind of support for Python2 as it has not native tzinfo
data = datetime.datetime(year, month, day, hour, minute, second, fraction)
if delta:
data -= delta
return data
def construct_yaml_omap(self, node):
# type: (Any) -> Any
# Note: we do now check for duplicate keys
omap = ordereddict()
yield omap
if not isinstance(node, SequenceNode):
raise ConstructorError(
"while constructing an ordered map",
node.start_mark,
"expected a sequence, but found %s" % node.id,
node.start_mark,
)
for subnode in node.value:
if not isinstance(subnode, MappingNode):
raise ConstructorError(
"while constructing an ordered map",
node.start_mark,
"expected a mapping of length 1, but found %s" % subnode.id,
subnode.start_mark,
)
if len(subnode.value) != 1:
raise ConstructorError(
"while constructing an ordered map",
node.start_mark,
"expected a single mapping item, but found %d items"
% len(subnode.value),
subnode.start_mark,
)
key_node, value_node = subnode.value[0]
key = self.construct_object(key_node)
assert key not in omap
value = self.construct_object(value_node)
omap[key] = value
def construct_yaml_pairs(self, node):
# type: (Any) -> Any
# Note: the same code as `construct_yaml_omap`.
pairs = [] # type: List[Any]
yield pairs
if not isinstance(node, SequenceNode):
raise ConstructorError(
"while constructing pairs",
node.start_mark,
"expected a sequence, but found %s" % node.id,
node.start_mark,
)
for subnode in node.value:
if not isinstance(subnode, MappingNode):
raise ConstructorError(
"while constructing pairs",
node.start_mark,
"expected a mapping of length 1, but found %s" % subnode.id,
subnode.start_mark,
)
if len(subnode.value) != 1:
raise ConstructorError(
"while constructing pairs",
node.start_mark,
"expected a single mapping item, but found %d items"
% len(subnode.value),
subnode.start_mark,
)
key_node, value_node = subnode.value[0]
key = self.construct_object(key_node)
value = self.construct_object(value_node)
pairs.append((key, value))
def construct_yaml_set(self, node):
# type: (Any) -> Any
data = set() # type: Set[Any]
yield data
value = self.construct_mapping(node)
data.update(value)
def construct_yaml_str(self, node):
# type: (Any) -> Any
value = self.construct_scalar(node)
if PY3:
return value
try:
return value.encode("ascii")
except UnicodeEncodeError:
return value
def construct_yaml_seq(self, node):
# type: (Any) -> Any
data = self.yaml_base_list_type() # type: List[Any]
yield data
data.extend(self.construct_sequence(node))
def construct_yaml_map(self, node):
# type: (Any) -> Any
data = self.yaml_base_dict_type() # type: Dict[Any, Any]
yield data
value = self.construct_mapping(node)
data.update(value)
def construct_yaml_object(self, node, cls):
# type: (Any, Any) -> Any
data = cls.__new__(cls)
yield data
if hasattr(data, "__setstate__"):
state = self.construct_mapping(node, deep=True)
data.__setstate__(state)
else:
state = self.construct_mapping(node)
data.__dict__.update(state)
def construct_undefined(self, node):
# type: (Any) -> None
raise ConstructorError(
None,
None,
"could not determine a constructor for the tag %r" % utf8(node.tag),
node.start_mark,
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:null", SafeConstructor.construct_yaml_null
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:bool", SafeConstructor.construct_yaml_bool
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:int", SafeConstructor.construct_yaml_int
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:float", SafeConstructor.construct_yaml_float
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:binary", SafeConstructor.construct_yaml_binary
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:timestamp", SafeConstructor.construct_yaml_timestamp
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:omap", SafeConstructor.construct_yaml_omap
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:pairs", SafeConstructor.construct_yaml_pairs
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:set", SafeConstructor.construct_yaml_set
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:str", SafeConstructor.construct_yaml_str
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:seq", SafeConstructor.construct_yaml_seq
)
SafeConstructor.add_constructor(
u"tag:yaml.org,2002:map", SafeConstructor.construct_yaml_map
)
SafeConstructor.add_constructor(None, SafeConstructor.construct_undefined)
if PY2:
class classobj:
pass
class Constructor(SafeConstructor):
def construct_python_str(self, node):
raise ValueError("Unsafe constructor not implemented in this library")
def construct_python_unicode(self, node):
raise ValueError("Unsafe constructor not implemented in this library")
if PY3:
def construct_python_bytes(self, node):
raise ValueError("Unsafe constructor not implemented in this library")
def construct_python_long(self, node):
raise ValueError("Unsafe constructor not implemented in this library")
def construct_python_complex(self, node):
raise ValueError("Unsafe constructor not implemented in this library")
def construct_python_tuple(self, node):
raise ValueError("Unsafe constructor not implemented in this library")
def find_python_module(self, name, mark):
raise ValueError("Unsafe constructor not implemented in this library")
def find_python_name(self, name, mark):
raise ValueError("Unsafe constructor not implemented in this library")
def construct_python_name(self, suffix, node):
raise ValueError("Unsafe constructor not implemented in this library")
def construct_python_module(self, suffix, node):
raise ValueError("Unsafe constructor not implemented in this library")
def make_python_instance(self, suffix, node, args=None, kwds=None, newobj=False):
raise ValueError("Unsafe constructor not implemented in this library")
def set_python_instance_state(self, instance, state):
raise ValueError("Unsafe constructor not implemented in this library")
def construct_python_object(self, suffix, node):
raise ValueError("Unsafe constructor not implemented in this library")
def construct_python_object_apply(self, suffix, node, newobj=False):
raise ValueError("Unsafe constructor not implemented in this library")
def construct_python_object_new(self, suffix, node):
raise ValueError("Unsafe constructor not implemented in this library")
Constructor.add_constructor(
u"tag:yaml.org,2002:python/none", Constructor.construct_yaml_null
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/bool", Constructor.construct_yaml_bool
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/str", Constructor.construct_python_str
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/unicode", Constructor.construct_python_unicode
)
if PY3:
Constructor.add_constructor(
u"tag:yaml.org,2002:python/bytes", Constructor.construct_python_bytes
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/int", Constructor.construct_yaml_int
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/long", Constructor.construct_python_long
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/float", Constructor.construct_yaml_float
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/complex", Constructor.construct_python_complex
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/list", Constructor.construct_yaml_seq
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/tuple", Constructor.construct_python_tuple
)
Constructor.add_constructor(
u"tag:yaml.org,2002:python/dict", Constructor.construct_yaml_map
)
Constructor.add_multi_constructor(
u"tag:yaml.org,2002:python/name:", Constructor.construct_python_name
)
Constructor.add_multi_constructor(
u"tag:yaml.org,2002:python/module:", Constructor.construct_python_module
)
Constructor.add_multi_constructor(
u"tag:yaml.org,2002:python/object:", Constructor.construct_python_object
)
Constructor.add_multi_constructor(
u"tag:yaml.org,2002:python/object/apply:", Constructor.construct_python_object_apply
)
Constructor.add_multi_constructor(
u"tag:yaml.org,2002:python/object/new:", Constructor.construct_python_object_new
)
class RoundTripConstructor(SafeConstructor):
"""need to store the comments on the node itself,
as well as on the items
"""
def construct_scalar(self, node):
# type: (Any) -> Any
if not isinstance(node, ScalarNode):
raise ConstructorError(
None,
None,
"expected a scalar node, but found %s" % node.id,
node.start_mark,
)
if node.style == "|" and isinstance(node.value, text_type):
lss = LiteralScalarString(node.value, anchor=node.anchor)
if node.comment and node.comment[1]:
lss.comment = node.comment[1][0] # type: ignore
return lss
if node.style == ">" and isinstance(node.value, text_type):
fold_positions = [] # type: List[int]
idx = -1
while True:
idx = node.value.find("\a", idx + 1)
if idx < 0:
break
fold_positions.append(idx - len(fold_positions))
fss = FoldedScalarString(node.value.replace("\a", ""), anchor=node.anchor)
if node.comment and node.comment[1]:
fss.comment = node.comment[1][0] # type: ignore
if fold_positions:
fss.fold_pos = fold_positions # type: ignore
return fss
elif bool(self._preserve_quotes) and isinstance(node.value, text_type):
if node.style == "'":
return SingleQuotedScalarString(node.value, anchor=node.anchor)
if node.style == '"':
return DoubleQuotedScalarString(node.value, anchor=node.anchor)
if node.anchor:
return PlainScalarString(node.value, anchor=node.anchor)
return node.value
def construct_yaml_int(self, node):
# type: (Any) -> Any
width = None # type: Any
value_su = to_str(self.construct_scalar(node))
try:
sx = value_su.rstrip("_")
underscore = [len(sx) - sx.rindex("_") - 1, False, False] # type: Any
except ValueError:
underscore = None
except IndexError:
underscore = None
value_s = value_su.replace("_", "")
sign = +1
if value_s[0] == "-":
sign = -1
if value_s[0] in "+-":
value_s = value_s[1:]
if value_s == "0":
return 0
elif value_s.startswith("0b"):
if self.resolver.processing_version > (1, 1) and value_s[2] == "0":
width = len(value_s[2:])
if underscore is not None:
underscore[1] = value_su[2] == "_"
underscore[2] = len(value_su[2:]) > 1 and value_su[-1] == "_"
return BinaryInt(
sign * int(value_s[2:], 2),
width=width,
underscore=underscore,
anchor=node.anchor,
)
elif value_s.startswith("0x"):
# default to lower-case if no a-fA-F in string
if self.resolver.processing_version > (1, 1) and value_s[2] == "0":
width = len(value_s[2:])
hex_fun = HexInt # type: Any
for ch in value_s[2:]:
if ch in "ABCDEF": # first non-digit is capital
hex_fun = HexCapsInt
break
if ch in "abcdef":
break
if underscore is not None:
underscore[1] = value_su[2] == "_"
underscore[2] = len(value_su[2:]) > 1 and value_su[-1] == "_"
return hex_fun(
sign * int(value_s[2:], 16),
width=width,
underscore=underscore,
anchor=node.anchor,
)
elif value_s.startswith("0o"):
if self.resolver.processing_version > (1, 1) and value_s[2] == "0":
width = len(value_s[2:])
if underscore is not None:
underscore[1] = value_su[2] == "_"
underscore[2] = len(value_su[2:]) > 1 and value_su[-1] == "_"
return OctalInt(
sign * int(value_s[2:], 8),
width=width,
underscore=underscore,
anchor=node.anchor,
)
elif self.resolver.processing_version != (1, 2) and value_s[0] == "0":
return sign * int(value_s, 8)
elif self.resolver.processing_version != (1, 2) and ":" in value_s:
digits = [int(part) for part in value_s.split(":")]
digits.reverse()
base = 1
value = 0
for digit in digits:
value += digit * base
base *= 60
return sign * value
elif self.resolver.processing_version > (1, 1) and value_s[0] == "0":
# not an octal, an integer with leading zero(s)
if underscore is not None:
# cannot have a leading underscore
underscore[2] = len(value_su) > 1 and value_su[-1] == "_"
return ScalarInt(
sign * int(value_s), width=len(value_s), underscore=underscore
)
elif underscore:
# cannot have a leading underscore
underscore[2] = len(value_su) > 1 and value_su[-1] == "_"
return ScalarInt(
sign * int(value_s),
width=None,
underscore=underscore,
anchor=node.anchor,
)
elif node.anchor:
return ScalarInt(sign * int(value_s), width=None, anchor=node.anchor)
else:
return sign * int(value_s)
def construct_yaml_float(self, node):
# type: (Any) -> Any
def leading_zeros(v):
# type: (Any) -> int
lead0 = 0
idx = 0
while idx < len(v) and v[idx] in "0.":
if v[idx] == "0":
lead0 += 1
idx += 1
return lead0
# underscore = None
m_sign = False # type: Any
value_so = to_str(self.construct_scalar(node))
value_s = value_so.replace("_", "").lower()
sign = +1
if value_s[0] == "-":
sign = -1
if value_s[0] in "+-":
m_sign = value_s[0]
value_s = value_s[1:]
if value_s == ".inf":
return sign * self.inf_value
if value_s == ".nan":
return self.nan_value
if self.resolver.processing_version != (1, 2) and ":" in value_s:
digits = [float(part) for part in value_s.split(":")]
digits.reverse()
base = 1
value = 0.0
for digit in digits:
value += digit * base
base *= 60
return sign * value
if "e" in value_s:
try:
mantissa, exponent = value_so.split("e")
exp = "e"
except ValueError:
mantissa, exponent = value_so.split("E")
exp = "E"
if self.resolver.processing_version != (1, 2):
# value_s is lower case independent of input
if "." not in mantissa:
warnings.warn(MantissaNoDotYAML1_1Warning(node, value_so))
lead0 = leading_zeros(mantissa)
width = len(mantissa)
prec = mantissa.find(".")
if m_sign:
width -= 1
e_width = len(exponent)
e_sign = exponent[0] in "+-"
# nprint('sf', width, prec, m_sign, exp, e_width, e_sign)
return ScalarFloat(
sign * float(value_s),
width=width,
prec=prec,
m_sign=m_sign,
m_lead0=lead0,
exp=exp,
e_width=e_width,
e_sign=e_sign,
anchor=node.anchor,
)
width = len(value_so)
prec = value_so.index(".")  # index() is safe: this would not be a float without a dot
lead0 = leading_zeros(value_so)
return ScalarFloat(
sign * float(value_s),
width=width,
prec=prec,
m_sign=m_sign,
m_lead0=lead0,
anchor=node.anchor,
)
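# Illustrative sketch (comments only, not executed): construct_yaml_float
# returns ScalarFloat instances that remember formatting details, e.g.
#   1.50      -> ScalarFloat(1.5, width=4, prec=1)
#   1.0e+3    -> ScalarFloat(1000.0, width=3, prec=1, exp='e', e_width=2, e_sign=True)
#   .inf/.nan -> the resolver's infinity / NaN sentinel values
# The keyword arguments shown are derived from the code above and given only
# as an example of what round-trip dumping relies on.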
def construct_yaml_str(self, node):
# type: (Any) -> Any
value = self.construct_scalar(node)
if isinstance(value, ScalarString):
return value
if PY3:
return value
try:
return value.encode("ascii")
except AttributeError:
# in case you replace the node dynamically e.g. with a dict
return value
except UnicodeEncodeError:
return value
def construct_rt_sequence(self, node, seqtyp, deep=False):
# type: (Any, Any, bool) -> Any
if not isinstance(node, SequenceNode):
raise ConstructorError(
None,
None,
"expected a sequence node, but found %s" % node.id,
node.start_mark,
)
ret_val = []
if node.comment:
seqtyp._yaml_add_comment(node.comment[:2])
if len(node.comment) > 2:
seqtyp.yaml_end_comment_extend(node.comment[2], clear=True)
if node.anchor:
from .serializer import templated_id
if not templated_id(node.anchor):
seqtyp.yaml_set_anchor(node.anchor)
for idx, child in enumerate(node.value):
if child.comment:
seqtyp._yaml_add_comment(child.comment, key=idx)
child.comment = None # if moved to sequence remove from child
ret_val.append(self.construct_object(child, deep=deep))
seqtyp._yaml_set_idx_line_col(
idx, [child.start_mark.line, child.start_mark.column]
)
return ret_val
def flatten_mapping(self, node):
# type: (Any) -> Any
"""
This implements the merge key feature (http://yaml.org/type/merge.html)
by inserting keys from the merge dict/list of dicts if they are not yet
available in this node.
"""
def constructed(value_node):
# type: (Any) -> Any
# If the contents of a merge are defined within the
# merge marker, then they won't have been constructed
# yet. But if they were already constructed, we need to use
# the existing object.
if value_node in self.constructed_objects:
value = self.constructed_objects[value_node]
else:
value = self.construct_object(value_node, deep=False)
return value
# merge = []
merge_map_list = [] # type: List[Any]
index = 0
while index < len(node.value):
key_node, value_node = node.value[index]
if key_node.tag == u"tag:yaml.org,2002:merge":
if merge_map_list: # double << key
if self.allow_duplicate_keys:
del node.value[index]
index += 1
continue
args = [
"while constructing a mapping",
node.start_mark,
'found duplicate key "{}"'.format(key_node.value),
key_node.start_mark,
"""
To suppress this check see:
http://yaml.readthedocs.io/en/latest/api.html#duplicate-keys
""",
"""\
Duplicate keys will become an error in future releases, and are errors
by default when using the new API.
""",
]
if self.allow_duplicate_keys is None:
warnings.warn(DuplicateKeyFutureWarning(*args))
else:
raise DuplicateKeyError(*args)
del node.value[index]
if isinstance(value_node, MappingNode):
merge_map_list.append((index, constructed(value_node)))
# self.flatten_mapping(value_node)
# merge.extend(value_node.value)
elif isinstance(value_node, SequenceNode):
# submerge = []
for subnode in value_node.value:
if not isinstance(subnode, MappingNode):
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"expected a mapping for merging, but found %s"
% subnode.id,
subnode.start_mark,
)
merge_map_list.append((index, constructed(subnode)))
# self.flatten_mapping(subnode)
# submerge.append(subnode.value)
# submerge.reverse()
# for value in submerge:
# merge.extend(value)
else:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"expected a mapping or list of mappings for merging, "
"but found %s" % value_node.id,
value_node.start_mark,
)
elif key_node.tag == u"tag:yaml.org,2002:value":
key_node.tag = u"tag:yaml.org,2002:str"
index += 1
else:
index += 1
return merge_map_list
# if merge:
# node.value = merge + node.value
def _sentinel(self):
# type: () -> None
pass
def construct_mapping(self, node, maptyp, deep=False): # type: ignore
# type: (Any, Any, bool) -> Any
if not isinstance(node, MappingNode):
raise ConstructorError(
None,
None,
"expected a mapping node, but found %s" % node.id,
node.start_mark,
)
merge_map = self.flatten_mapping(node)
# mapping = {}
if node.comment:
maptyp._yaml_add_comment(node.comment[:2])
if len(node.comment) > 2:
maptyp.yaml_end_comment_extend(node.comment[2], clear=True)
if node.anchor:
from .serializer import templated_id
if not templated_id(node.anchor):
maptyp.yaml_set_anchor(node.anchor)
last_key, last_value = None, self._sentinel
for key_node, value_node in node.value:
# keys can be list -> deep
key = self.construct_object(key_node, deep=True)
# lists are not hashable, but tuples are
if not isinstance(key, Hashable):
if isinstance(key, MutableSequence):
key_s = CommentedKeySeq(key)
if key_node.flow_style is True:
key_s.fa.set_flow_style()
elif key_node.flow_style is False:
key_s.fa.set_block_style()
key = key_s
elif isinstance(key, MutableMapping):
key_m = CommentedKeyMap(key)
if key_node.flow_style is True:
key_m.fa.set_flow_style()
elif key_node.flow_style is False:
key_m.fa.set_block_style()
key = key_m
if PY2:
try:
hash(key)
except TypeError as exc:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unacceptable key (%s)" % exc,
key_node.start_mark,
)
else:
if not isinstance(key, Hashable):
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unhashable key",
key_node.start_mark,
)
value = self.construct_object(value_node, deep=deep)
if self.check_mapping_key(node, key_node, maptyp, key, value):
if (
key_node.comment
and len(key_node.comment) > 4
and key_node.comment[4]
):
if last_value is None:
key_node.comment[0] = key_node.comment.pop(4)
maptyp._yaml_add_comment(key_node.comment, value=last_key)
else:
key_node.comment[2] = key_node.comment.pop(4)
maptyp._yaml_add_comment(key_node.comment, key=key)
key_node.comment = None
if key_node.comment:
maptyp._yaml_add_comment(key_node.comment, key=key)
if value_node.comment:
maptyp._yaml_add_comment(value_node.comment, value=key)
maptyp._yaml_set_kv_line_col(
key,
[
key_node.start_mark.line,
key_node.start_mark.column,
value_node.start_mark.line,
value_node.start_mark.column,
],
)
maptyp[key] = value
last_key, last_value = key, value # could use indexing
# do this last, or <<: before a key will prevent insertion in instances
# of collections.OrderedDict (as they have no __contains__)
if merge_map:
maptyp.add_yaml_merge(merge_map)
def construct_setting(self, node, typ, deep=False):
# type: (Any, Any, bool) -> Any
if not isinstance(node, MappingNode):
raise ConstructorError(
None,
None,
"expected a mapping node, but found %s" % node.id,
node.start_mark,
)
if node.comment:
typ._yaml_add_comment(node.comment[:2])
if len(node.comment) > 2:
typ.yaml_end_comment_extend(node.comment[2], clear=True)
if node.anchor:
from .serializer import templated_id
if not templated_id(node.anchor):
typ.yaml_set_anchor(node.anchor)
for key_node, value_node in node.value:
# keys can be list -> deep
key = self.construct_object(key_node, deep=True)
# lists are not hashable, but tuples are
if not isinstance(key, Hashable):
if isinstance(key, list):
key = tuple(key)
if PY2:
try:
hash(key)
except TypeError as exc:
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unacceptable key (%s)" % exc,
key_node.start_mark,
)
else:
if not isinstance(key, Hashable):
raise ConstructorError(
"while constructing a mapping",
node.start_mark,
"found unhashable key",
key_node.start_mark,
)
# construct the value, though for a set entry it should be null
value = self.construct_object(value_node, deep=deep) # NOQA
self.check_set_key(node, key_node, typ, key)
if key_node.comment:
typ._yaml_add_comment(key_node.comment, key=key)
if value_node.comment:
typ._yaml_add_comment(value_node.comment, value=key)
typ.add(key)
def construct_yaml_seq(self, node):
# type: (Any) -> Any
data = CommentedSeq()
data._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
if node.comment:
data._yaml_add_comment(node.comment)
yield data
data.extend(self.construct_rt_sequence(node, data))
self.set_collection_style(data, node)
def construct_yaml_map(self, node):
# type: (Any) -> Any
data = CommentedMap()
data._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
yield data
self.construct_mapping(node, data, deep=True)
self.set_collection_style(data, node)
def set_collection_style(self, data, node):
# type: (Any, Any) -> None
if len(data) == 0:
return
if node.flow_style is True:
data.fa.set_flow_style()
elif node.flow_style is False:
data.fa.set_block_style()
def construct_yaml_object(self, node, cls):
# type: (Any, Any) -> Any
data = cls.__new__(cls)
yield data
if hasattr(data, "__setstate__"):
state = SafeConstructor.construct_mapping(self, node, deep=True)
data.__setstate__(state)
else:
state = SafeConstructor.construct_mapping(self, node)
data.__dict__.update(state)
def construct_yaml_omap(self, node):
# type: (Any) -> Any
# Note: we do now check for duplicate keys (see the assert below)
omap = CommentedOrderedMap()
omap._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
if node.flow_style is True:
omap.fa.set_flow_style()
elif node.flow_style is False:
omap.fa.set_block_style()
yield omap
if node.comment:
omap._yaml_add_comment(node.comment[:2])
if len(node.comment) > 2:
omap.yaml_end_comment_extend(node.comment[2], clear=True)
if not isinstance(node, SequenceNode):
raise ConstructorError(
"while constructing an ordered map",
node.start_mark,
"expected a sequence, but found %s" % node.id,
node.start_mark,
)
for subnode in node.value:
if not isinstance(subnode, MappingNode):
raise ConstructorError(
"while constructing an ordered map",
node.start_mark,
"expected a mapping of length 1, but found %s" % subnode.id,
subnode.start_mark,
)
if len(subnode.value) != 1:
raise ConstructorError(
"while constructing an ordered map",
node.start_mark,
"expected a single mapping item, but found %d items"
% len(subnode.value),
subnode.start_mark,
)
key_node, value_node = subnode.value[0]
key = self.construct_object(key_node)
assert key not in omap
value = self.construct_object(value_node)
if key_node.comment:
omap._yaml_add_comment(key_node.comment, key=key)
if subnode.comment:
omap._yaml_add_comment(subnode.comment, key=key)
if value_node.comment:
omap._yaml_add_comment(value_node.comment, value=key)
omap[key] = value
def construct_yaml_set(self, node):
# type: (Any) -> Any
data = CommentedSet()
data._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
yield data
self.construct_setting(node, data)
def construct_undefined(self, node):
# type: (Any) -> Any
try:
if isinstance(node, MappingNode):
data = CommentedMap()
data._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
if node.flow_style is True:
data.fa.set_flow_style()
elif node.flow_style is False:
data.fa.set_block_style()
data.yaml_set_tag(node.tag)
yield data
if node.anchor:
data.yaml_set_anchor(node.anchor)
self.construct_mapping(node, data)
return
elif isinstance(node, ScalarNode):
data2 = TaggedScalar()
data2.value = self.construct_scalar(node)
data2.style = node.style
data2.yaml_set_tag(node.tag)
yield data2
if node.anchor:
data2.yaml_set_anchor(node.anchor, always_dump=True)
return
elif isinstance(node, SequenceNode):
data3 = CommentedSeq()
data3._yaml_set_line_col(node.start_mark.line, node.start_mark.column)
if node.flow_style is True:
data3.fa.set_flow_style()
elif node.flow_style is False:
data3.fa.set_block_style()
data3.yaml_set_tag(node.tag)
yield data3
if node.anchor:
data3.yaml_set_anchor(node.anchor)
data3.extend(self.construct_sequence(node))
return
except: # NOQA
pass
raise ConstructorError(
None,
None,
"could not determine a constructor for the tag %r" % utf8(node.tag),
node.start_mark,
)
def construct_yaml_timestamp(self, node, values=None):
# type: (Any, Any) -> Any
try:
match = self.timestamp_regexp.match(node.value)
except TypeError:
match = None
if match is None:
raise ConstructorError(
None,
None,
'failed to construct timestamp from "{}"'.format(node.value),
node.start_mark,
)
values = match.groupdict()
if not values["hour"]:
return SafeConstructor.construct_yaml_timestamp(self, node, values)
for part in ["t", "tz_sign", "tz_hour", "tz_minute"]:
if values[part]:
break
else:
return SafeConstructor.construct_yaml_timestamp(self, node, values)
year = int(values["year"])
month = int(values["month"])
day = int(values["day"])
hour = int(values["hour"])
minute = int(values["minute"])
second = int(values["second"])
fraction = 0
if values["fraction"]:
fraction_s = values["fraction"][:6]
while len(fraction_s) < 6:
fraction_s += "0"
fraction = int(fraction_s)
if len(values["fraction"]) > 6 and int(values["fraction"][6]) > 4:
fraction += 1
delta = None
if values["tz_sign"]:
tz_hour = int(values["tz_hour"])
minutes = values["tz_minute"]
tz_minute = int(minutes) if minutes else 0
delta = datetime.timedelta(hours=tz_hour, minutes=tz_minute)
if values["tz_sign"] == "-":
delta = -delta
if delta:
dt = datetime.datetime(year, month, day, hour, minute)
dt -= delta
data = TimeStamp(
dt.year, dt.month, dt.day, dt.hour, dt.minute, second, fraction
)
data._yaml["delta"] = delta
tz = values["tz_sign"] + values["tz_hour"]
if values["tz_minute"]:
tz += ":" + values["tz_minute"]
data._yaml["tz"] = tz
else:
data = TimeStamp(year, month, day, hour, minute, second, fraction)
if values["tz"]: # no delta
data._yaml["tz"] = values["tz"]
if values["t"]:
data._yaml["t"] = True
return data
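# Illustrative sketch (comments only): values with a time part and an explicit
# timezone are normalized and keep round-trip metadata, e.g. (assumed inputs)
#   2019-01-01                  -> handled by SafeConstructor (plain date)
#   2019-01-01T12:30:00+02:00   -> TimeStamp for 10:30 UTC with _yaml['delta']
#                                  set, _yaml['tz'] = '+02:00' and _yaml['t'] = True
# The group names come from SafeConstructor.timestamp_regexp.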
def construct_yaml_bool(self, node):
# type: (Any) -> Any
b = SafeConstructor.construct_yaml_bool(self, node)
if node.anchor:
return ScalarBoolean(b, anchor=node.anchor)
return b
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:null", RoundTripConstructor.construct_yaml_null
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:bool", RoundTripConstructor.construct_yaml_bool
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:int", RoundTripConstructor.construct_yaml_int
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:float", RoundTripConstructor.construct_yaml_float
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:binary", RoundTripConstructor.construct_yaml_binary
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:timestamp", RoundTripConstructor.construct_yaml_timestamp
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:omap", RoundTripConstructor.construct_yaml_omap
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:pairs", RoundTripConstructor.construct_yaml_pairs
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:set", RoundTripConstructor.construct_yaml_set
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:str", RoundTripConstructor.construct_yaml_str
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:seq", RoundTripConstructor.construct_yaml_seq
)
RoundTripConstructor.add_constructor(
u"tag:yaml.org,2002:map", RoundTripConstructor.construct_yaml_map
)
RoundTripConstructor.add_constructor(None, RoundTripConstructor.construct_undefined)
srsly-release-v2.5.1/srsly/ruamel_yaml/cyaml.py 0000775 0000000 0000000 00000014647 14742310675 0021671 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
from _ruamel_yaml import CParser, CEmitter # type: ignore
from .constructor import Constructor, BaseConstructor, SafeConstructor
from .representer import Representer, SafeRepresenter, BaseRepresenter
from .resolver import Resolver, BaseResolver
if False: # MYPY
from typing import Any, Union, Optional # NOQA
from .compat import StreamTextType, StreamType, VersionType # NOQA
__all__ = [
"CBaseLoader",
"CSafeLoader",
"CLoader",
"CBaseDumper",
"CSafeDumper",
"CDumper",
]
# this includes some hacks to handle the use of the resolver by lower-level
# parts of the parser
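# Minimal usage sketch (assumption: the _ruamel_yaml C extension is importable
# and exposes get_single_data()/dispose() like the pure-Python loaders):
#   loader = CSafeLoader("a: 1\n")
#   try:
#       data = loader.get_single_data()   # -> {'a': 1}
#   finally:
#       loader.dispose()
# In normal use these classes are passed as Loader=/Dumper= arguments to the
# top-level helpers rather than instantiated directly.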
class CBaseLoader(CParser, BaseConstructor, BaseResolver): # type: ignore
def __init__(self, stream, version=None, preserve_quotes=None):
# type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
CParser.__init__(self, stream)
self._parser = self._composer = self
BaseConstructor.__init__(self, loader=self)
BaseResolver.__init__(self, loadumper=self)
# self.descend_resolver = self._resolver.descend_resolver
# self.ascend_resolver = self._resolver.ascend_resolver
# self.resolve = self._resolver.resolve
class CSafeLoader(CParser, SafeConstructor, Resolver): # type: ignore
def __init__(self, stream, version=None, preserve_quotes=None):
# type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
CParser.__init__(self, stream)
self._parser = self._composer = self
SafeConstructor.__init__(self, loader=self)
Resolver.__init__(self, loadumper=self)
# self.descend_resolver = self._resolver.descend_resolver
# self.ascend_resolver = self._resolver.ascend_resolver
# self.resolve = self._resolver.resolve
class CLoader(CParser, Constructor, Resolver): # type: ignore
def __init__(self, stream, version=None, preserve_quotes=None):
# type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
CParser.__init__(self, stream)
self._parser = self._composer = self
Constructor.__init__(self, loader=self)
Resolver.__init__(self, loadumper=self)
# self.descend_resolver = self._resolver.descend_resolver
# self.ascend_resolver = self._resolver.ascend_resolver
# self.resolve = self._resolver.resolve
class CBaseDumper(CEmitter, BaseRepresenter, BaseResolver): # type: ignore
def __init__(
self,
stream,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=None,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
CEmitter.__init__(
self,
stream,
canonical=canonical,
indent=indent,
width=width,
encoding=encoding,
allow_unicode=allow_unicode,
line_break=line_break,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
)
self._emitter = self._serializer = self._representer = self
BaseRepresenter.__init__(
self,
default_style=default_style,
default_flow_style=default_flow_style,
dumper=self,
)
BaseResolver.__init__(self, loadumper=self)
class CSafeDumper(CEmitter, SafeRepresenter, Resolver): # type: ignore
def __init__(
self,
stream,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=None,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
self._emitter = self._serializer = self._representer = self
CEmitter.__init__(
self,
stream,
canonical=canonical,
indent=indent,
width=width,
encoding=encoding,
allow_unicode=allow_unicode,
line_break=line_break,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
)
self._emitter = self._serializer = self._representer = self
SafeRepresenter.__init__(
self, default_style=default_style, default_flow_style=default_flow_style
)
Resolver.__init__(self)
class CDumper(CEmitter, Representer, Resolver): # type: ignore
def __init__(
self,
stream,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=None,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
CEmitter.__init__(
self,
stream,
canonical=canonical,
indent=indent,
width=width,
encoding=encoding,
allow_unicode=allow_unicode,
line_break=line_break,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
)
self._emitter = self._serializer = self._representer = self
Representer.__init__(
self, default_style=default_style, default_flow_style=default_flow_style
)
Resolver.__init__(self)
srsly-release-v2.5.1/srsly/ruamel_yaml/dumper.py 0000775 0000000 0000000 00000014652 14742310675 0022054 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
from .emitter import Emitter
from .serializer import Serializer
from .representer import (
Representer,
SafeRepresenter,
BaseRepresenter,
RoundTripRepresenter,
)
from .resolver import Resolver, BaseResolver, VersionedResolver
if False: # MYPY
from typing import Any, Dict, List, Union, Optional # NOQA
from .compat import StreamType, VersionType # NOQA
__all__ = ["BaseDumper", "SafeDumper", "Dumper", "RoundTripDumper"]
class BaseDumper(Emitter, Serializer, BaseRepresenter, BaseResolver):
def __init__(
self,
stream,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=None,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (Any, StreamType, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
Emitter.__init__(
self,
stream,
canonical=canonical,
indent=indent,
width=width,
allow_unicode=allow_unicode,
line_break=line_break,
block_seq_indent=block_seq_indent,
dumper=self,
)
Serializer.__init__(
self,
encoding=encoding,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
dumper=self,
)
BaseRepresenter.__init__(
self,
default_style=default_style,
default_flow_style=default_flow_style,
dumper=self,
)
BaseResolver.__init__(self, loadumper=self)
class SafeDumper(Emitter, Serializer, SafeRepresenter, Resolver):
def __init__(
self,
stream,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=None,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (StreamType, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
Emitter.__init__(
self,
stream,
canonical=canonical,
indent=indent,
width=width,
allow_unicode=allow_unicode,
line_break=line_break,
block_seq_indent=block_seq_indent,
dumper=self,
)
Serializer.__init__(
self,
encoding=encoding,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
dumper=self,
)
SafeRepresenter.__init__(
self,
default_style=default_style,
default_flow_style=default_flow_style,
dumper=self,
)
Resolver.__init__(self, loadumper=self)
class Dumper(Emitter, Serializer, Representer, Resolver):
def __init__(
self,
stream,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=None,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (StreamType, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
Emitter.__init__(
self,
stream,
canonical=canonical,
indent=indent,
width=width,
allow_unicode=allow_unicode,
line_break=line_break,
block_seq_indent=block_seq_indent,
dumper=self,
)
Serializer.__init__(
self,
encoding=encoding,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
dumper=self,
)
Representer.__init__(
self,
default_style=default_style,
default_flow_style=default_flow_style,
dumper=self,
)
Resolver.__init__(self, loadumper=self)
class RoundTripDumper(Emitter, Serializer, RoundTripRepresenter, VersionedResolver):
def __init__(
self,
stream,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=None,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (StreamType, Any, Optional[bool], Optional[int], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
Emitter.__init__(
self,
stream,
canonical=canonical,
indent=indent,
width=width,
allow_unicode=allow_unicode,
line_break=line_break,
block_seq_indent=block_seq_indent,
top_level_colon_align=top_level_colon_align,
prefix_colon=prefix_colon,
dumper=self,
)
Serializer.__init__(
self,
encoding=encoding,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
dumper=self,
)
RoundTripRepresenter.__init__(
self,
default_style=default_style,
default_flow_style=default_flow_style,
dumper=self,
)
VersionedResolver.__init__(self, loader=self)
srsly-release-v2.5.1/srsly/ruamel_yaml/emitter.py 0000775 0000000 0000000 00000176244 14742310675 0022237 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
from __future__ import print_function
# Emitter expects events obeying the following grammar:
# stream ::= STREAM-START document* STREAM-END
# document ::= DOCUMENT-START node DOCUMENT-END
# node ::= SCALAR | sequence | mapping
# sequence ::= SEQUENCE-START node* SEQUENCE-END
# mapping ::= MAPPING-START (node node)* MAPPING-END
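# For example, the single document "{a: 1}" roughly corresponds to the event
# stream (illustrative; real events also carry anchors, tags, implicit flags
# and comments):
#   StreamStartEvent, DocumentStartEvent,
#   MappingStartEvent, ScalarEvent('a'), ScalarEvent('1'), MappingEndEvent,
#   DocumentEndEvent, StreamEndEvent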
import sys
from .error import YAMLError, YAMLStreamError
from .events import * # NOQA
# fmt: off
from .compat import utf8, text_type, PY2, nprint, dbg, DBG_EVENT, \
check_anchorname_char
# fmt: on
if False: # MYPY
from typing import Any, Dict, List, Union, Text, Tuple, Optional # NOQA
from .compat import StreamType # NOQA
__all__ = ["Emitter", "EmitterError"]
class EmitterError(YAMLError):
pass
class ScalarAnalysis(object):
def __init__(
self,
scalar,
empty,
multiline,
allow_flow_plain,
allow_block_plain,
allow_single_quoted,
allow_double_quoted,
allow_block,
):
# type: (Any, Any, Any, bool, bool, bool, bool, bool) -> None
self.scalar = scalar
self.empty = empty
self.multiline = multiline
self.allow_flow_plain = allow_flow_plain
self.allow_block_plain = allow_block_plain
self.allow_single_quoted = allow_single_quoted
self.allow_double_quoted = allow_double_quoted
self.allow_block = allow_block
class Indents(object):
# replacement for the list based stack of None/int
def __init__(self):
# type: () -> None
self.values = [] # type: List[Tuple[int, bool]]
def append(self, val, seq):
# type: (Any, Any) -> None
self.values.append((val, seq))
def pop(self):
# type: () -> Any
return self.values.pop()[0]
def last_seq(self):
# type: () -> bool
# return the seq(uence) value for the element added before the last one
# in increase_indent()
try:
return self.values[-2][1]
except IndexError:
return False
def seq_flow_align(self, seq_indent, column):
# type: (int, int) -> int
# extra spaces because of dash
if len(self.values) < 2 or not self.values[-1][1]:
return 0
# -1 for the dash
base = self.values[-1][0] if self.values[-1][0] is not None else 0
return base + seq_indent - column - 1
def __len__(self):
# type: () -> int
return len(self.values)
class Emitter(object):
# fmt: off
DEFAULT_TAG_PREFIXES = {
u'!': u'!',
u'tag:yaml.org,2002:': u'!!',
}
# fmt: on
MAX_SIMPLE_KEY_LENGTH = 128
def __init__(
self,
stream,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
brace_single_entry_mapping_in_flow_sequence=None,
dumper=None,
):
# type: (StreamType, Any, Optional[int], Optional[int], Optional[bool], Any, Optional[int], Optional[bool], Any, Optional[bool], Any) -> None # NOQA
self.dumper = dumper
if self.dumper is not None and getattr(self.dumper, "_emitter", None) is None:
self.dumper._emitter = self
self.stream = stream
# Encoding can be overridden by STREAM-START.
self.encoding = None # type: Optional[Text]
self.allow_space_break = None
# Emitter is a state machine with a stack of states to handle nested
# structures.
self.states = [] # type: List[Any]
self.state = self.expect_stream_start # type: Any
# Current event and the event queue.
self.events = [] # type: List[Any]
self.event = None # type: Any
# The current indentation level and the stack of previous indents.
self.indents = Indents()
self.indent = None # type: Optional[int]
# flow_context is an expanding/shrinking list consisting of '{' and '['
# for each unclosed flow context. If empty list that means block context
self.flow_context = [] # type: List[Text]
# Contexts.
self.root_context = False
self.sequence_context = False
self.mapping_context = False
self.simple_key_context = False
# Characteristics of the last emitted character:
# - current position.
# - is it a whitespace?
# - is it an indention character
# (indentation space, '-', '?', or ':')?
self.line = 0
self.column = 0
self.whitespace = True
self.indention = True
self.compact_seq_seq = True # dash after dash
self.compact_seq_map = True # key after dash
# self.compact_ms = False # dash after key, only when explicit key with ?
self.no_newline = None # type: Optional[bool] # set if directly after `- `
# Whether the document requires an explicit document end indicator
self.open_ended = False
# colon handling
self.colon = u":"
self.prefixed_colon = (
self.colon if prefix_colon is None else prefix_colon + self.colon
)
# single entry mappings in flow sequence
self.brace_single_entry_mapping_in_flow_sequence = (
brace_single_entry_mapping_in_flow_sequence
) # NOQA
# Formatting details.
self.canonical = canonical
self.allow_unicode = allow_unicode
# set to False to get "\Uxxxxxxxx" for non-basic unicode like emojis
self.unicode_supplementary = sys.maxunicode > 0xFFFF
self.sequence_dash_offset = block_seq_indent if block_seq_indent else 0
self.top_level_colon_align = top_level_colon_align
self.best_sequence_indent = 2
self.requested_indent = indent # specific for literal zero indent
if indent and 1 < indent < 10:
self.best_sequence_indent = indent
self.best_map_indent = self.best_sequence_indent
# if self.best_sequence_indent < self.sequence_dash_offset + 1:
# self.best_sequence_indent = self.sequence_dash_offset + 1
self.best_width = 80
if width and width > self.best_sequence_indent * 2:
self.best_width = width
self.best_line_break = u"\n" # type: Any
if line_break in [u"\r", u"\n", u"\r\n"]:
self.best_line_break = line_break
# Tag prefixes.
self.tag_prefixes = None # type: Any
# Prepared anchor and tag.
self.prepared_anchor = None # type: Any
self.prepared_tag = None # type: Any
# Scalar analysis and style.
self.analysis = None # type: Any
self.style = None # type: Any
self.scalar_after_indicator = True # write a scalar on the same line as `---`
@property
def stream(self):
# type: () -> Any
try:
return self._stream
except AttributeError:
raise YAMLStreamError("output stream needs to be specified")
@stream.setter
def stream(self, val):
# type: (Any) -> None
if val is None:
return
if not hasattr(val, "write"):
raise YAMLStreamError("stream argument needs to have a write() method")
self._stream = val
@property
def serializer(self):
# type: () -> Any
try:
if hasattr(self.dumper, "typ"):
return self.dumper.serializer
return self.dumper._serializer
except AttributeError:
return self # cyaml
@property
def flow_level(self):
# type: () -> int
return len(self.flow_context)
def dispose(self):
# type: () -> None
# Reset the state attributes (to clear self-references)
self.states = []
self.state = None
def emit(self, event):
# type: (Any) -> None
if dbg(DBG_EVENT):
nprint(event)
self.events.append(event)
while not self.need_more_events():
self.event = self.events.pop(0)
self.state()
self.event = None
# In some cases, we wait for a few more events before emitting.
def need_more_events(self):
# type: () -> bool
if not self.events:
return True
event = self.events[0]
if isinstance(event, DocumentStartEvent):
return self.need_events(1)
elif isinstance(event, SequenceStartEvent):
return self.need_events(2)
elif isinstance(event, MappingStartEvent):
return self.need_events(3)
else:
return False
def need_events(self, count):
# type: (int) -> bool
level = 0
for event in self.events[1:]:
if isinstance(event, (DocumentStartEvent, CollectionStartEvent)):
level += 1
elif isinstance(event, (DocumentEndEvent, CollectionEndEvent)):
level -= 1
elif isinstance(event, StreamEndEvent):
level = -1
if level < 0:
return False
return len(self.events) < count + 1
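# Example: after a MappingStartEvent the emitter buffers up to three more
# events (or fewer if the surrounding collection/document closes first) so
# that checks like check_empty_mapping() and check_simple_key() can peek at
# self.events[0] before anything is written.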
def increase_indent(self, flow=False, sequence=None, indentless=False):
# type: (bool, Optional[bool], bool) -> None
self.indents.append(self.indent, sequence)
if self.indent is None: # top level
if flow:
# self.indent = self.best_sequence_indent if self.indents.last_seq() else \
# self.best_map_indent
# self.indent = self.best_sequence_indent
self.indent = self.requested_indent
else:
self.indent = 0
elif not indentless:
self.indent += (
self.best_sequence_indent
if self.indents.last_seq()
else self.best_map_indent
)
# if self.indents.last_seq():
# if self.indent == 0: # top level block sequence
# self.indent = self.best_sequence_indent - self.sequence_dash_offset
# else:
# self.indent += self.best_sequence_indent
# else:
# self.indent += self.best_map_indent
# States.
# Stream handlers.
def expect_stream_start(self):
# type: () -> None
if isinstance(self.event, StreamStartEvent):
if PY2:
if self.event.encoding and not getattr(self.stream, "encoding", None):
self.encoding = self.event.encoding
else:
if self.event.encoding and not hasattr(self.stream, "encoding"):
self.encoding = self.event.encoding
self.write_stream_start()
self.state = self.expect_first_document_start
else:
raise EmitterError("expected StreamStartEvent, but got %s" % (self.event,))
def expect_nothing(self):
# type: () -> None
raise EmitterError("expected nothing, but got %s" % (self.event,))
# Document handlers.
def expect_first_document_start(self):
# type: () -> Any
return self.expect_document_start(first=True)
def expect_document_start(self, first=False):
# type: (bool) -> None
if isinstance(self.event, DocumentStartEvent):
if (self.event.version or self.event.tags) and self.open_ended:
self.write_indicator(u"...", True)
self.write_indent()
if self.event.version:
version_text = self.prepare_version(self.event.version)
self.write_version_directive(version_text)
self.tag_prefixes = self.DEFAULT_TAG_PREFIXES.copy()
if self.event.tags:
handles = sorted(self.event.tags.keys())
for handle in handles:
prefix = self.event.tags[handle]
self.tag_prefixes[prefix] = handle
handle_text = self.prepare_tag_handle(handle)
prefix_text = self.prepare_tag_prefix(prefix)
self.write_tag_directive(handle_text, prefix_text)
implicit = (
first
and not self.event.explicit
and not self.canonical
and not self.event.version
and not self.event.tags
and not self.check_empty_document()
)
if not implicit:
self.write_indent()
self.write_indicator(u"---", True)
if self.canonical:
self.write_indent()
self.state = self.expect_document_root
elif isinstance(self.event, StreamEndEvent):
if self.open_ended:
self.write_indicator(u"...", True)
self.write_indent()
self.write_stream_end()
self.state = self.expect_nothing
else:
raise EmitterError(
"expected DocumentStartEvent, but got %s" % (self.event,)
)
def expect_document_end(self):
# type: () -> None
if isinstance(self.event, DocumentEndEvent):
self.write_indent()
if self.event.explicit:
self.write_indicator(u"...", True)
self.write_indent()
self.flush_stream()
self.state = self.expect_document_start
else:
raise EmitterError("expected DocumentEndEvent, but got %s" % (self.event,))
def expect_document_root(self):
# type: () -> None
self.states.append(self.expect_document_end)
self.expect_node(root=True)
# Node handlers.
def expect_node(self, root=False, sequence=False, mapping=False, simple_key=False):
# type: (bool, bool, bool, bool) -> None
self.root_context = root
self.sequence_context = sequence # not used in PyYAML
self.mapping_context = mapping
self.simple_key_context = simple_key
if isinstance(self.event, AliasEvent):
self.expect_alias()
elif isinstance(self.event, (ScalarEvent, CollectionStartEvent)):
if (
self.process_anchor(u"&")
and isinstance(self.event, ScalarEvent)
and self.sequence_context
):
self.sequence_context = False
if (
root
and isinstance(self.event, ScalarEvent)
and not self.scalar_after_indicator
):
self.write_indent()
self.process_tag()
if isinstance(self.event, ScalarEvent):
# nprint('@', self.indention, self.no_newline, self.column)
self.expect_scalar()
elif isinstance(self.event, SequenceStartEvent):
# nprint('@', self.indention, self.no_newline, self.column)
i2, n2 = self.indention, self.no_newline # NOQA
if self.event.comment:
if self.event.flow_style is False and self.event.comment:
if self.write_post_comment(self.event):
self.indention = False
self.no_newline = True
if self.write_pre_comment(self.event):
self.indention = i2
self.no_newline = not self.indention
if (
self.flow_level
or self.canonical
or self.event.flow_style
or self.check_empty_sequence()
):
self.expect_flow_sequence()
else:
self.expect_block_sequence()
elif isinstance(self.event, MappingStartEvent):
if self.event.flow_style is False and self.event.comment:
self.write_post_comment(self.event)
if self.event.comment and self.event.comment[1]:
self.write_pre_comment(self.event)
if (
self.flow_level
or self.canonical
or self.event.flow_style
or self.check_empty_mapping()
):
self.expect_flow_mapping(single=self.event.nr_items == 1)
else:
self.expect_block_mapping()
else:
raise EmitterError("expected NodeEvent, but got %s" % (self.event,))
def expect_alias(self):
# type: () -> None
if self.event.anchor is None:
raise EmitterError("anchor is not specified for alias")
self.process_anchor(u"*")
self.state = self.states.pop()
def expect_scalar(self):
# type: () -> None
self.increase_indent(flow=True)
self.process_scalar()
self.indent = self.indents.pop()
self.state = self.states.pop()
# Flow sequence handlers.
def expect_flow_sequence(self):
# type: () -> None
ind = self.indents.seq_flow_align(self.best_sequence_indent, self.column)
self.write_indicator(u" " * ind + u"[", True, whitespace=True)
self.increase_indent(flow=True, sequence=True)
self.flow_context.append("[")
self.state = self.expect_first_flow_sequence_item
def expect_first_flow_sequence_item(self):
# type: () -> None
if isinstance(self.event, SequenceEndEvent):
self.indent = self.indents.pop()
popped = self.flow_context.pop()
assert popped == "["
self.write_indicator(u"]", False)
if self.event.comment and self.event.comment[0]:
# eol comment on empty flow sequence
self.write_post_comment(self.event)
elif self.flow_level == 0:
self.write_line_break()
self.state = self.states.pop()
else:
if self.canonical or self.column > self.best_width:
self.write_indent()
self.states.append(self.expect_flow_sequence_item)
self.expect_node(sequence=True)
def expect_flow_sequence_item(self):
# type: () -> None
if isinstance(self.event, SequenceEndEvent):
self.indent = self.indents.pop()
popped = self.flow_context.pop()
assert popped == "["
if self.canonical:
self.write_indicator(u",", False)
self.write_indent()
self.write_indicator(u"]", False)
if self.event.comment and self.event.comment[0]:
# eol comment on flow sequence
self.write_post_comment(self.event)
else:
self.no_newline = False
self.state = self.states.pop()
else:
self.write_indicator(u",", False)
if self.canonical or self.column > self.best_width:
self.write_indent()
self.states.append(self.expect_flow_sequence_item)
self.expect_node(sequence=True)
# Flow mapping handlers.
def expect_flow_mapping(self, single=False):
# type: (Optional[bool]) -> None
ind = self.indents.seq_flow_align(self.best_sequence_indent, self.column)
map_init = u"{"
if (
single
and self.flow_level
and self.flow_context[-1] == "["
and not self.canonical
and not self.brace_single_entry_mapping_in_flow_sequence
):
# single map item with flow context, no curly braces necessary
map_init = u""
self.write_indicator(u" " * ind + map_init, True, whitespace=True)
self.flow_context.append(map_init)
self.increase_indent(flow=True, sequence=False)
self.state = self.expect_first_flow_mapping_key
def expect_first_flow_mapping_key(self):
# type: () -> None
if isinstance(self.event, MappingEndEvent):
self.indent = self.indents.pop()
popped = self.flow_context.pop()
assert popped == "{" # empty flow mapping
self.write_indicator(u"}", False)
if self.event.comment and self.event.comment[0]:
# eol comment on empty mapping
self.write_post_comment(self.event)
elif self.flow_level == 0:
self.write_line_break()
self.state = self.states.pop()
else:
if self.canonical or self.column > self.best_width:
self.write_indent()
if not self.canonical and self.check_simple_key():
self.states.append(self.expect_flow_mapping_simple_value)
self.expect_node(mapping=True, simple_key=True)
else:
self.write_indicator(u"?", True)
self.states.append(self.expect_flow_mapping_value)
self.expect_node(mapping=True)
def expect_flow_mapping_key(self):
# type: () -> None
if isinstance(self.event, MappingEndEvent):
# if self.event.comment and self.event.comment[1]:
# self.write_pre_comment(self.event)
self.indent = self.indents.pop()
popped = self.flow_context.pop()
assert popped in [u"{", u""]
if self.canonical:
self.write_indicator(u",", False)
self.write_indent()
if popped != u"":
self.write_indicator(u"}", False)
if self.event.comment and self.event.comment[0]:
# eol comment on flow mapping, never reached on empty mappings
self.write_post_comment(self.event)
else:
self.no_newline = False
self.state = self.states.pop()
else:
self.write_indicator(u",", False)
if self.canonical or self.column > self.best_width:
self.write_indent()
if not self.canonical and self.check_simple_key():
self.states.append(self.expect_flow_mapping_simple_value)
self.expect_node(mapping=True, simple_key=True)
else:
self.write_indicator(u"?", True)
self.states.append(self.expect_flow_mapping_value)
self.expect_node(mapping=True)
def expect_flow_mapping_simple_value(self):
# type: () -> None
self.write_indicator(self.prefixed_colon, False)
self.states.append(self.expect_flow_mapping_key)
self.expect_node(mapping=True)
def expect_flow_mapping_value(self):
# type: () -> None
if self.canonical or self.column > self.best_width:
self.write_indent()
self.write_indicator(self.prefixed_colon, True)
self.states.append(self.expect_flow_mapping_key)
self.expect_node(mapping=True)
# Block sequence handlers.
def expect_block_sequence(self):
# type: () -> None
if self.mapping_context:
indentless = not self.indention
else:
indentless = False
if not self.compact_seq_seq and self.column != 0:
self.write_line_break()
self.increase_indent(flow=False, sequence=True, indentless=indentless)
self.state = self.expect_first_block_sequence_item
def expect_first_block_sequence_item(self):
# type: () -> Any
return self.expect_block_sequence_item(first=True)
def expect_block_sequence_item(self, first=False):
# type: (bool) -> None
if not first and isinstance(self.event, SequenceEndEvent):
if self.event.comment and self.event.comment[1]:
# final comments on a block list e.g. empty line
self.write_pre_comment(self.event)
self.indent = self.indents.pop()
self.state = self.states.pop()
self.no_newline = False
else:
if self.event.comment and self.event.comment[1]:
self.write_pre_comment(self.event)
nonl = self.no_newline if self.column == 0 else False
self.write_indent()
ind = self.sequence_dash_offset # if len(self.indents) > 1 else 0
self.write_indicator(u" " * ind + u"-", True, indention=True)
if nonl or self.sequence_dash_offset + 2 > self.best_sequence_indent:
self.no_newline = True
self.states.append(self.expect_block_sequence_item)
self.expect_node(sequence=True)
# Block mapping handlers.
def expect_block_mapping(self):
# type: () -> None
if not self.mapping_context and not (self.compact_seq_map or self.column == 0):
self.write_line_break()
self.increase_indent(flow=False, sequence=False)
self.state = self.expect_first_block_mapping_key
def expect_first_block_mapping_key(self):
# type: () -> None
return self.expect_block_mapping_key(first=True)
def expect_block_mapping_key(self, first=False):
# type: (Any) -> None
if not first and isinstance(self.event, MappingEndEvent):
if self.event.comment and self.event.comment[1]:
# final comments from a doc
self.write_pre_comment(self.event)
self.indent = self.indents.pop()
self.state = self.states.pop()
else:
if self.event.comment and self.event.comment[1]:
# final comments from a doc
self.write_pre_comment(self.event)
self.write_indent()
if self.check_simple_key():
if not isinstance(
self.event, (SequenceStartEvent, MappingStartEvent)
): # sequence keys
try:
if self.event.style == "?":
self.write_indicator(u"?", True, indention=True)
except AttributeError: # aliases have no style
pass
self.states.append(self.expect_block_mapping_simple_value)
self.expect_node(mapping=True, simple_key=True)
if isinstance(self.event, AliasEvent):
self.stream.write(u" ")
else:
self.write_indicator(u"?", True, indention=True)
self.states.append(self.expect_block_mapping_value)
self.expect_node(mapping=True)
def expect_block_mapping_simple_value(self):
# type: () -> None
if getattr(self.event, "style", None) != "?":
# prefix = u''
if self.indent == 0 and self.top_level_colon_align is not None:
# write non-prefixed colon
c = u" " * (self.top_level_colon_align - self.column) + self.colon
else:
c = self.prefixed_colon
self.write_indicator(c, False)
self.states.append(self.expect_block_mapping_key)
self.expect_node(mapping=True)
def expect_block_mapping_value(self):
# type: () -> None
self.write_indent()
self.write_indicator(self.prefixed_colon, True, indention=True)
self.states.append(self.expect_block_mapping_key)
self.expect_node(mapping=True)
# Checkers.
def check_empty_sequence(self):
# type: () -> bool
return (
isinstance(self.event, SequenceStartEvent)
and bool(self.events)
and isinstance(self.events[0], SequenceEndEvent)
)
def check_empty_mapping(self):
# type: () -> bool
return (
isinstance(self.event, MappingStartEvent)
and bool(self.events)
and isinstance(self.events[0], MappingEndEvent)
)
def check_empty_document(self):
# type: () -> bool
if not isinstance(self.event, DocumentStartEvent) or not self.events:
return False
event = self.events[0]
return (
isinstance(event, ScalarEvent)
and event.anchor is None
and event.tag is None
and event.implicit
and event.value == ""
)
def check_simple_key(self):
# type: () -> bool
length = 0
if isinstance(self.event, NodeEvent) and self.event.anchor is not None:
if self.prepared_anchor is None:
self.prepared_anchor = self.prepare_anchor(self.event.anchor)
length += len(self.prepared_anchor)
if (
isinstance(self.event, (ScalarEvent, CollectionStartEvent))
and self.event.tag is not None
):
if self.prepared_tag is None:
self.prepared_tag = self.prepare_tag(self.event.tag)
length += len(self.prepared_tag)
if isinstance(self.event, ScalarEvent):
if self.analysis is None:
self.analysis = self.analyze_scalar(self.event.value)
length += len(self.analysis.scalar)
return length < self.MAX_SIMPLE_KEY_LENGTH and (
isinstance(self.event, AliasEvent)
or (
isinstance(self.event, SequenceStartEvent)
and self.event.flow_style is True
)
or (
isinstance(self.event, MappingStartEvent)
and self.event.flow_style is True
)
or (
isinstance(self.event, ScalarEvent)
# if there is an explicit style for an empty string, it is a simple key
and not (self.analysis.empty and self.style and self.style not in "'\"")
and not self.analysis.multiline
)
or self.check_empty_sequence()
or self.check_empty_mapping()
)
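# Example: a key whose prepared anchor + tag + scalar text reach
# MAX_SIMPLE_KEY_LENGTH (128) characters, or whose scalar is multiline, fails
# this check and is emitted in the explicit "? key" / ": value" block form by
# expect_block_mapping_key() instead of as a simple "key: value" pair.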
# Anchor, Tag, and Scalar processors.
def process_anchor(self, indicator):
# type: (Any) -> bool
if self.event.anchor is None:
self.prepared_anchor = None
return False
if self.prepared_anchor is None:
self.prepared_anchor = self.prepare_anchor(self.event.anchor)
if self.prepared_anchor:
self.write_indicator(indicator + self.prepared_anchor, True)
# issue 288
self.no_newline = False
self.prepared_anchor = None
return True
def process_tag(self):
# type: () -> None
tag = self.event.tag
if isinstance(self.event, ScalarEvent):
if self.style is None:
self.style = self.choose_scalar_style()
if (not self.canonical or tag is None) and (
(self.style == "" and self.event.implicit[0])
or (self.style != "" and self.event.implicit[1])
):
self.prepared_tag = None
return
if self.event.implicit[0] and tag is None:
tag = u"!"
self.prepared_tag = None
else:
if (not self.canonical or tag is None) and self.event.implicit:
self.prepared_tag = None
return
if tag is None:
raise EmitterError("tag is not specified")
if self.prepared_tag is None:
self.prepared_tag = self.prepare_tag(tag)
if self.prepared_tag:
self.write_indicator(self.prepared_tag, True)
if (
self.sequence_context
and not self.flow_level
and isinstance(self.event, ScalarEvent)
):
self.no_newline = True
self.prepared_tag = None
def choose_scalar_style(self):
# type: () -> Any
if self.analysis is None:
self.analysis = self.analyze_scalar(self.event.value)
if self.event.style == '"' or self.canonical:
return '"'
if (not self.event.style or self.event.style == "?") and (
self.event.implicit[0] or not self.event.implicit[2]
):
if not (
self.simple_key_context
and (self.analysis.empty or self.analysis.multiline)
) and (
self.flow_level
and self.analysis.allow_flow_plain
or (not self.flow_level and self.analysis.allow_block_plain)
):
return ""
self.analysis.allow_block = True
if self.event.style and self.event.style in "|>":
if (
not self.flow_level
and not self.simple_key_context
and self.analysis.allow_block
):
return self.event.style
if not self.event.style and self.analysis.allow_double_quoted:
if "'" in self.event.value or "\n" in self.event.value:
return '"'
if not self.event.style or self.event.style == "'":
if self.analysis.allow_single_quoted and not (
self.simple_key_context and self.analysis.multiline
):
return "'"
return '"'
def process_scalar(self):
# type: () -> None
if self.analysis is None:
self.analysis = self.analyze_scalar(self.event.value)
if self.style is None:
self.style = self.choose_scalar_style()
split = not self.simple_key_context
# if self.analysis.multiline and split \
# and (not self.style or self.style in '\'\"'):
# self.write_indent()
# nprint('xx', self.sequence_context, self.flow_level)
if self.sequence_context and not self.flow_level:
self.write_indent()
if self.style == '"':
self.write_double_quoted(self.analysis.scalar, split)
elif self.style == "'":
self.write_single_quoted(self.analysis.scalar, split)
elif self.style == ">":
self.write_folded(self.analysis.scalar)
elif self.style == "|":
self.write_literal(self.analysis.scalar, self.event.comment)
else:
self.write_plain(self.analysis.scalar, split)
self.analysis = None
self.style = None
if self.event.comment:
self.write_post_comment(self.event)
# Analyzers.
def prepare_version(self, version):
# type: (Any) -> Any
major, minor = version
if major != 1:
raise EmitterError("unsupported YAML version: %d.%d" % (major, minor))
return u"%d.%d" % (major, minor)
def prepare_tag_handle(self, handle):
# type: (Any) -> Any
if not handle:
raise EmitterError("tag handle must not be empty")
if handle[0] != u"!" or handle[-1] != u"!":
raise EmitterError(
"tag handle must start and end with '!': %r" % (utf8(handle))
)
for ch in handle[1:-1]:
if not (
u"0" <= ch <= u"9"
or u"A" <= ch <= u"Z"
or u"a" <= ch <= u"z"
or ch in u"-_"
):
raise EmitterError(
"invalid character %r in the tag handle: %r"
% (utf8(ch), utf8(handle))
)
return handle
def prepare_tag_prefix(self, prefix):
# type: (Any) -> Any
if not prefix:
raise EmitterError("tag prefix must not be empty")
chunks = [] # type: List[Any]
start = end = 0
if prefix[0] == u"!":
end = 1
ch_set = u"-;/?:@&=+$,_.~*'()[]"
if self.dumper:
version = getattr(self.dumper, "version", (1, 2))
if version is None or version >= (1, 2):
ch_set += u"#"
while end < len(prefix):
ch = prefix[end]
if (
u"0" <= ch <= u"9"
or u"A" <= ch <= u"Z"
or u"a" <= ch <= u"z"
or ch in ch_set
):
end += 1
else:
if start < end:
chunks.append(prefix[start:end])
start = end = end + 1
data = utf8(ch)
for ch in data:
chunks.append(u"%%%02X" % ord(ch))
if start < end:
chunks.append(prefix[start:end])
return "".join(chunks)
def prepare_tag(self, tag):
# type: (Any) -> Any
if not tag:
raise EmitterError("tag must not be empty")
if tag == u"!":
return tag
handle = None
suffix = tag
prefixes = sorted(self.tag_prefixes.keys())
for prefix in prefixes:
if tag.startswith(prefix) and (prefix == u"!" or len(prefix) < len(tag)):
handle = self.tag_prefixes[prefix]
suffix = tag[len(prefix) :]
chunks = [] # type: List[Any]
start = end = 0
ch_set = u"-;/?:@&=+$,_.~*'()[]"
if self.dumper:
version = getattr(self.dumper, "version", (1, 2))
if version is None or version >= (1, 2):
ch_set += u"#"
while end < len(suffix):
ch = suffix[end]
if (
u"0" <= ch <= u"9"
or u"A" <= ch <= u"Z"
or u"a" <= ch <= u"z"
or ch in ch_set
or (ch == u"!" and handle != u"!")
):
end += 1
else:
if start < end:
chunks.append(suffix[start:end])
start = end = end + 1
data = utf8(ch)
for ch in data:
chunks.append(u"%%%02X" % ord(ch))
if start < end:
chunks.append(suffix[start:end])
suffix_text = "".join(chunks)
if handle:
return u"%s%s" % (handle, suffix_text)
else:
return u"!<%s>" % suffix_text
def prepare_anchor(self, anchor):
# type: (Any) -> Any
if not anchor:
raise EmitterError("anchor must not be empty")
for ch in anchor:
if not check_anchorname_char(ch):
raise EmitterError(
"invalid character %r in the anchor: %r" % (utf8(ch), utf8(anchor))
)
return anchor
def analyze_scalar(self, scalar):
# type: (Any) -> Any
# Empty scalar is a special case.
if not scalar:
return ScalarAnalysis(
scalar=scalar,
empty=True,
multiline=False,
allow_flow_plain=False,
allow_block_plain=True,
allow_single_quoted=True,
allow_double_quoted=True,
allow_block=False,
)
# Indicators and special characters.
block_indicators = False
flow_indicators = False
line_breaks = False
special_characters = False
# Important whitespace combinations.
leading_space = False
leading_break = False
trailing_space = False
trailing_break = False
break_space = False
space_break = False
# Check document indicators.
if scalar.startswith(u"---") or scalar.startswith(u"..."):
block_indicators = True
flow_indicators = True
# First character or preceded by a whitespace.
preceded_by_whitespace = True
# Last character or followed by a whitespace.
followed_by_whitespace = (
len(scalar) == 1 or scalar[1] in u"\0 \t\r\n\x85\u2028\u2029"
)
# The previous character is a space.
previous_space = False
# The previous character is a break.
previous_break = False
index = 0
while index < len(scalar):
ch = scalar[index]
# Check for indicators.
if index == 0:
# Leading indicators are special characters.
if ch in u"#,[]{}&*!|>'\"%@`":
flow_indicators = True
block_indicators = True
if ch in u"?:": # ToDo
if self.serializer.use_version == (1, 1):
flow_indicators = True
elif len(scalar) == 1: # single character
flow_indicators = True
if followed_by_whitespace:
block_indicators = True
if ch == u"-" and followed_by_whitespace:
flow_indicators = True
block_indicators = True
else:
# Some indicators cannot appear within a scalar either.
if ch in u",[]{}": # http://yaml.org/spec/1.2/spec.html#id2788859
flow_indicators = True
if ch == u"?" and self.serializer.use_version == (1, 1):
flow_indicators = True
if ch == u":":
if followed_by_whitespace:
flow_indicators = True
block_indicators = True
if ch == u"#" and preceeded_by_whitespace:
flow_indicators = True
block_indicators = True
# Check for line breaks, special, and unicode characters.
if ch in u"\n\x85\u2028\u2029":
line_breaks = True
if not (ch == u"\n" or u"\x20" <= ch <= u"\x7E"):
if (
ch == u"\x85"
or u"\xA0" <= ch <= u"\uD7FF"
or u"\uE000" <= ch <= u"\uFFFD"
or (
self.unicode_supplementary
and (u"\U00010000" <= ch <= u"\U0010FFFF")
)
) and ch != u"\uFEFF":
# unicode_characters = True
if not self.allow_unicode:
special_characters = True
else:
special_characters = True
# Detect important whitespace combinations.
if ch == u" ":
if index == 0:
leading_space = True
if index == len(scalar) - 1:
trailing_space = True
if previous_break:
break_space = True
previous_space = True
previous_break = False
elif ch in u"\n\x85\u2028\u2029":
if index == 0:
leading_break = True
if index == len(scalar) - 1:
trailing_break = True
if previous_space:
space_break = True
previous_space = False
previous_break = True
else:
previous_space = False
previous_break = False
# Prepare for the next character.
index += 1
preceeded_by_whitespace = ch in u"\0 \t\r\n\x85\u2028\u2029"
followed_by_whitespace = (
index + 1 >= len(scalar)
or scalar[index + 1] in u"\0 \t\r\n\x85\u2028\u2029"
)
# Let's decide what styles are allowed.
allow_flow_plain = True
allow_block_plain = True
allow_single_quoted = True
allow_double_quoted = True
allow_block = True
# Leading and trailing whitespaces are bad for plain scalars.
if leading_space or leading_break or trailing_space or trailing_break:
allow_flow_plain = allow_block_plain = False
# We do not permit trailing spaces for block scalars.
if trailing_space:
allow_block = False
# Spaces at the beginning of a new line are only acceptable for block
# scalars.
if break_space:
allow_flow_plain = allow_block_plain = allow_single_quoted = False
# Spaces followed by breaks, as well as special character are only
# allowed for double quoted scalars.
if special_characters:
allow_flow_plain = (
allow_block_plain
) = allow_single_quoted = allow_block = False
elif space_break:
allow_flow_plain = allow_block_plain = allow_single_quoted = False
if not self.allow_space_break:
allow_block = False
# Although the plain scalar writer supports breaks, we never emit
# multiline plain scalars.
if line_breaks:
allow_flow_plain = allow_block_plain = False
# Flow indicators are forbidden for flow plain scalars.
if flow_indicators:
allow_flow_plain = False
# Block indicators are forbidden for block plain scalars.
if block_indicators:
allow_block_plain = False
return ScalarAnalysis(
scalar=scalar,
empty=False,
multiline=line_breaks,
allow_flow_plain=allow_flow_plain,
allow_block_plain=allow_block_plain,
allow_single_quoted=allow_single_quoted,
allow_double_quoted=allow_double_quoted,
allow_block=allow_block,
)
# Writers.
def flush_stream(self):
# type: () -> None
if hasattr(self.stream, "flush"):
self.stream.flush()
def write_stream_start(self):
# type: () -> None
# Write BOM if needed.
if self.encoding and self.encoding.startswith("utf-16"):
self.stream.write(u"\uFEFF".encode(self.encoding))
def write_stream_end(self):
# type: () -> None
self.flush_stream()
def write_indicator(
self, indicator, need_whitespace, whitespace=False, indention=False
):
# type: (Any, Any, bool, bool) -> None
if self.whitespace or not need_whitespace:
data = indicator
else:
data = u" " + indicator
self.whitespace = whitespace
self.indention = self.indention and indention
self.column += len(data)
self.open_ended = False
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
def write_indent(self):
# type: () -> None
indent = self.indent or 0
if (
not self.indention
or self.column > indent
or (self.column == indent and not self.whitespace)
):
if bool(self.no_newline):
self.no_newline = False
else:
self.write_line_break()
if self.column < indent:
self.whitespace = True
data = u" " * (indent - self.column)
self.column = indent
if self.encoding:
data = data.encode(self.encoding)
self.stream.write(data)
def write_line_break(self, data=None):
# type: (Any) -> None
if data is None:
data = self.best_line_break
self.whitespace = True
self.indention = True
self.line += 1
self.column = 0
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
def write_version_directive(self, version_text):
# type: (Any) -> None
data = u"%%YAML %s" % version_text
if self.encoding:
data = data.encode(self.encoding)
self.stream.write(data)
self.write_line_break()
def write_tag_directive(self, handle_text, prefix_text):
# type: (Any, Any) -> None
data = u"%%TAG %s %s" % (handle_text, prefix_text)
if self.encoding:
data = data.encode(self.encoding)
self.stream.write(data)
self.write_line_break()
# Scalar streams.
def write_single_quoted(self, text, split=True):
# type: (Any, Any) -> None
if self.root_context:
if self.requested_indent is not None:
self.write_line_break()
if self.requested_indent != 0:
self.write_indent()
self.write_indicator(u"'", True)
spaces = False
breaks = False
start = end = 0
while end <= len(text):
ch = None
if end < len(text):
ch = text[end]
if spaces:
if ch is None or ch != u" ":
if (
start + 1 == end
and self.column > self.best_width
and split
and start != 0
and end != len(text)
):
self.write_indent()
else:
data = text[start:end]
self.column += len(data)
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
start = end
elif breaks:
if ch is None or ch not in u"\n\x85\u2028\u2029":
if text[start] == u"\n":
self.write_line_break()
for br in text[start:end]:
if br == u"\n":
self.write_line_break()
else:
self.write_line_break(br)
self.write_indent()
start = end
else:
if ch is None or ch in u" \n\x85\u2028\u2029" or ch == u"'":
if start < end:
data = text[start:end]
self.column += len(data)
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
start = end
if ch == u"'":
data = u"''"
self.column += 2
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
start = end + 1
if ch is not None:
spaces = ch == u" "
breaks = ch in u"\n\x85\u2028\u2029"
end += 1
self.write_indicator(u"'", False)
ESCAPE_REPLACEMENTS = {
u"\0": u"0",
u"\x07": u"a",
u"\x08": u"b",
u"\x09": u"t",
u"\x0A": u"n",
u"\x0B": u"v",
u"\x0C": u"f",
u"\x0D": u"r",
u"\x1B": u"e",
u'"': u'"',
u"\\": u"\\",
u"\x85": u"N",
u"\xA0": u"_",
u"\u2028": u"L",
u"\u2029": u"P",
}
def write_double_quoted(self, text, split=True):
# type: (Any, Any) -> None
if self.root_context:
if self.requested_indent is not None:
self.write_line_break()
if self.requested_indent != 0:
self.write_indent()
self.write_indicator(u'"', True)
start = end = 0
while end <= len(text):
ch = None
if end < len(text):
ch = text[end]
if (
ch is None
or ch in u'"\\\x85\u2028\u2029\uFEFF'
or not (
u"\x20" <= ch <= u"\x7E"
or (
self.allow_unicode
and (u"\xA0" <= ch <= u"\uD7FF" or u"\uE000" <= ch <= u"\uFFFD")
)
)
):
if start < end:
data = text[start:end]
self.column += len(data)
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
start = end
if ch is not None:
if ch in self.ESCAPE_REPLACEMENTS:
data = u"\\" + self.ESCAPE_REPLACEMENTS[ch]
elif ch <= u"\xFF":
data = u"\\x%02X" % ord(ch)
elif ch <= u"\uFFFF":
data = u"\\u%04X" % ord(ch)
else:
data = u"\\U%08X" % ord(ch)
self.column += len(data)
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
start = end + 1
if (
0 < end < len(text) - 1
and (ch == u" " or start >= end)
and self.column + (end - start) > self.best_width
and split
):
data = text[start:end] + u"\\"
if start < end:
start = end
self.column += len(data)
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
self.write_indent()
self.whitespace = False
self.indention = False
if text[start] == u" ":
data = u"\\"
self.column += len(data)
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
end += 1
self.write_indicator(u'"', False)
def determine_block_hints(self, text):
# type: (Any) -> Any
indent = 0
indicator = u""
hints = u""
if text:
if text[0] in u" \n\x85\u2028\u2029":
indent = self.best_sequence_indent
hints += text_type(indent)
elif self.root_context:
for end in ["\n---", "\n..."]:
pos = 0
while True:
pos = text.find(end, pos)
if pos == -1:
break
try:
if text[pos + 4] in " \r\n":
break
except IndexError:
pass
pos += 1
if pos > -1:
break
if pos > 0:
indent = self.best_sequence_indent
if text[-1] not in u"\n\x85\u2028\u2029":
indicator = u"-"
elif len(text) == 1 or text[-2] in u"\n\x85\u2028\u2029":
indicator = u"+"
hints += indicator
return hints, indent, indicator
def write_folded(self, text):
# type: (Any) -> None
hints, _indent, _indicator = self.determine_block_hints(text)
self.write_indicator(u">" + hints, True)
if _indicator == u"+":
self.open_ended = True
self.write_line_break()
leading_space = True
spaces = False
breaks = True
start = end = 0
while end <= len(text):
ch = None
if end < len(text):
ch = text[end]
if breaks:
if ch is None or ch not in u"\n\x85\u2028\u2029\a":
if (
not leading_space
and ch is not None
and ch != u" "
and text[start] == u"\n"
):
self.write_line_break()
leading_space = ch == u" "
for br in text[start:end]:
if br == u"\n":
self.write_line_break()
else:
self.write_line_break(br)
if ch is not None:
self.write_indent()
start = end
elif spaces:
if ch != u" ":
if start + 1 == end and self.column > self.best_width:
self.write_indent()
else:
data = text[start:end]
self.column += len(data)
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
start = end
else:
if ch is None or ch in u" \n\x85\u2028\u2029\a":
data = text[start:end]
self.column += len(data)
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
if ch == u"\a":
if end < (len(text) - 1) and not text[end + 2].isspace():
self.write_line_break()
self.write_indent()
end += 2 # \a and the space that is inserted on the fold
else:
raise EmitterError(
"unexcpected fold indicator \\a before space"
)
if ch is None:
self.write_line_break()
start = end
if ch is not None:
breaks = ch in u"\n\x85\u2028\u2029"
spaces = ch == u" "
end += 1
def write_literal(self, text, comment=None):
# type: (Any, Any) -> None
hints, _indent, _indicator = self.determine_block_hints(text)
self.write_indicator(u"|" + hints, True)
try:
comment = comment[1][0]
if comment:
self.stream.write(comment)
except (TypeError, IndexError):
pass
if _indicator == u"+":
self.open_ended = True
self.write_line_break()
breaks = True
start = end = 0
while end <= len(text):
ch = None
if end < len(text):
ch = text[end]
if breaks:
if ch is None or ch not in u"\n\x85\u2028\u2029":
for br in text[start:end]:
if br == u"\n":
self.write_line_break()
else:
self.write_line_break(br)
if ch is not None:
if self.root_context:
idnx = self.indent if self.indent is not None else 0
self.stream.write(u" " * (_indent + idnx))
else:
self.write_indent()
start = end
else:
if ch is None or ch in u"\n\x85\u2028\u2029":
data = text[start:end]
if bool(self.encoding):
data = data.encode(self.encoding)
self.stream.write(data)
if ch is None:
self.write_line_break()
start = end
if ch is not None:
breaks = ch in u"\n\x85\u2028\u2029"
end += 1
def write_plain(self, text, split=True):
# type: (Any, Any) -> None
if self.root_context:
if self.requested_indent is not None:
self.write_line_break()
if self.requested_indent != 0:
self.write_indent()
else:
self.open_ended = True
if not text:
return
if not self.whitespace:
data = u" "
self.column += len(data)
if self.encoding:
data = data.encode(self.encoding)
self.stream.write(data)
self.whitespace = False
self.indention = False
spaces = False
breaks = False
start = end = 0
while end <= len(text):
ch = None
if end < len(text):
ch = text[end]
if spaces:
if ch != u" ":
if start + 1 == end and self.column > self.best_width and split:
self.write_indent()
self.whitespace = False
self.indention = False
else:
data = text[start:end]
self.column += len(data)
if self.encoding:
data = data.encode(self.encoding)
self.stream.write(data)
start = end
elif breaks:
if ch not in u"\n\x85\u2028\u2029": # type: ignore
if text[start] == u"\n":
self.write_line_break()
for br in text[start:end]:
if br == u"\n":
self.write_line_break()
else:
self.write_line_break(br)
self.write_indent()
self.whitespace = False
self.indention = False
start = end
else:
if ch is None or ch in u" \n\x85\u2028\u2029":
data = text[start:end]
self.column += len(data)
if self.encoding:
data = data.encode(self.encoding)
try:
self.stream.write(data)
except: # NOQA
sys.stdout.write(repr(data) + "\n")
raise
start = end
if ch is not None:
spaces = ch == u" "
breaks = ch in u"\n\x85\u2028\u2029"
end += 1
def write_comment(self, comment, pre=False):
# type: (Any, bool) -> None
value = comment.value
# nprintf('{:02d} {:02d} {!r}'.format(self.column, comment.start_mark.column, value))
if not pre and value[-1] == "\n":
value = value[:-1]
try:
# get original column position
col = comment.start_mark.column
if comment.value and comment.value.startswith("\n"):
# never inject extra spaces if the comment starts with a newline
# and not a real comment (e.g. if you have an empty line following a key-value)
col = self.column
elif col < self.column + 1:
raise ValueError
except ValueError:
col = self.column + 1
# nprint('post_comment', self.line, self.column, value)
try:
# at least one space if the current column >= the start column of the comment
# but not at the start of a line
nr_spaces = col - self.column
if self.column and value.strip() and nr_spaces < 1 and value[0] != "\n":
nr_spaces = 1
value = " " * nr_spaces + value
try:
if bool(self.encoding):
value = value.encode(self.encoding)
except UnicodeDecodeError:
pass
self.stream.write(value)
except TypeError:
raise
if not pre:
self.write_line_break()
def write_pre_comment(self, event):
# type: (Any) -> bool
comments = event.comment[1]
if comments is None:
return False
try:
start_events = (MappingStartEvent, SequenceStartEvent)
for comment in comments:
if isinstance(event, start_events) and getattr(
comment, "pre_done", None
):
continue
if self.column != 0:
self.write_line_break()
self.write_comment(comment, pre=True)
if isinstance(event, start_events):
comment.pre_done = True
except TypeError:
sys.stdout.write("eventtt {} {}".format(type(event), event))
raise
return True
def write_post_comment(self, event):
# type: (Any) -> bool
if self.event.comment[0] is None:
return False
comment = event.comment[0]
self.write_comment(comment)
return True
srsly-release-v2.5.1/srsly/ruamel_yaml/error.py 0000775 0000000 0000000 00000021620 14742310675 0021702 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
import warnings
import textwrap
from .compat import utf8
if False: # MYPY
from typing import Any, Dict, Optional, List, Text # NOQA
__all__ = [
"FileMark",
"StringMark",
"CommentMark",
"YAMLError",
"MarkedYAMLError",
"ReusedAnchorWarning",
"UnsafeLoaderWarning",
"MarkedYAMLWarning",
"MarkedYAMLFutureWarning",
]
class StreamMark(object):
__slots__ = "name", "index", "line", "column"
def __init__(self, name, index, line, column):
# type: (Any, int, int, int) -> None
self.name = name
self.index = index
self.line = line
self.column = column
def __str__(self):
# type: () -> Any
where = ' in "%s", line %d, column %d' % (
self.name,
self.line + 1,
self.column + 1,
)
return where
def __eq__(self, other):
# type: (Any) -> bool
if self.line != other.line or self.column != other.column:
return False
if self.name != other.name or self.index != other.index:
return False
return True
def __ne__(self, other):
# type: (Any) -> bool
return not self.__eq__(other)
class FileMark(StreamMark):
__slots__ = ()
class StringMark(StreamMark):
__slots__ = "name", "index", "line", "column", "buffer", "pointer"
def __init__(self, name, index, line, column, buffer, pointer):
# type: (Any, int, int, int, Any, Any) -> None
StreamMark.__init__(self, name, index, line, column)
self.buffer = buffer
self.pointer = pointer
def get_snippet(self, indent=4, max_length=75):
# type: (int, int) -> Any
if self.buffer is None: # always False
return None
head = ""
start = self.pointer
while start > 0 and self.buffer[start - 1] not in u"\0\r\n\x85\u2028\u2029":
start -= 1
if self.pointer - start > max_length / 2 - 1:
head = " ... "
start += 5
break
tail = ""
end = self.pointer
while (
end < len(self.buffer) and self.buffer[end] not in u"\0\r\n\x85\u2028\u2029"
):
end += 1
if end - self.pointer > max_length / 2 - 1:
tail = " ... "
end -= 5
break
snippet = utf8(self.buffer[start:end])
caret = "^"
caret = "^ (line: {})".format(self.line + 1)
return (
" " * indent
+ head
+ snippet
+ tail
+ "\n"
+ " " * (indent + self.pointer - start + len(head))
+ caret
)
def __str__(self):
# type: () -> Any
snippet = self.get_snippet()
where = ' in "%s", line %d, column %d' % (
self.name,
self.line + 1,
self.column + 1,
)
if snippet is not None:
where += ":\n" + snippet
return where
class CommentMark(object):
__slots__ = ("column",)
def __init__(self, column):
# type: (Any) -> None
self.column = column
class YAMLError(Exception):
pass
class MarkedYAMLError(YAMLError):
def __init__(
self,
context=None,
context_mark=None,
problem=None,
problem_mark=None,
note=None,
warn=None,
):
# type: (Any, Any, Any, Any, Any, Any) -> None
self.context = context
self.context_mark = context_mark
self.problem = problem
self.problem_mark = problem_mark
self.note = note
# warn is ignored
def __str__(self):
# type: () -> Any
lines = [] # type: List[str]
if self.context is not None:
lines.append(self.context)
if self.context_mark is not None and (
self.problem is None
or self.problem_mark is None
or self.context_mark.name != self.problem_mark.name
or self.context_mark.line != self.problem_mark.line
or self.context_mark.column != self.problem_mark.column
):
lines.append(str(self.context_mark))
if self.problem is not None:
lines.append(self.problem)
if self.problem_mark is not None:
lines.append(str(self.problem_mark))
if self.note is not None and self.note:
note = textwrap.dedent(self.note)
lines.append(note)
return "\n".join(lines)
class YAMLStreamError(Exception):
pass
class YAMLWarning(Warning):
pass
class MarkedYAMLWarning(YAMLWarning):
def __init__(
self,
context=None,
context_mark=None,
problem=None,
problem_mark=None,
note=None,
warn=None,
):
# type: (Any, Any, Any, Any, Any, Any) -> None
self.context = context
self.context_mark = context_mark
self.problem = problem
self.problem_mark = problem_mark
self.note = note
self.warn = warn
def __str__(self):
# type: () -> Any
lines = [] # type: List[str]
if self.context is not None:
lines.append(self.context)
if self.context_mark is not None and (
self.problem is None
or self.problem_mark is None
or self.context_mark.name != self.problem_mark.name
or self.context_mark.line != self.problem_mark.line
or self.context_mark.column != self.problem_mark.column
):
lines.append(str(self.context_mark))
if self.problem is not None:
lines.append(self.problem)
if self.problem_mark is not None:
lines.append(str(self.problem_mark))
if self.note is not None and self.note:
note = textwrap.dedent(self.note)
lines.append(note)
if self.warn is not None and self.warn:
warn = textwrap.dedent(self.warn)
lines.append(warn)
return "\n".join(lines)
class ReusedAnchorWarning(YAMLWarning):
pass
class UnsafeLoaderWarning(YAMLWarning):
text = """
The default 'Loader' for 'load(stream)' without further arguments can be unsafe.
Use 'load(stream, Loader=srsly.ruamel_yaml.Loader)' explicitly if that is OK.
Alternatively include the following in your code:
import warnings
warnings.simplefilter('ignore', srsly.ruamel_yaml.error.UnsafeLoaderWarning)
In most other cases you should consider using 'safe_load(stream)'"""
pass
warnings.simplefilter("once", UnsafeLoaderWarning)
class MantissaNoDotYAML1_1Warning(YAMLWarning):
def __init__(self, node, flt_str):
# type: (Any, Any) -> None
self.node = node
self.flt = flt_str
def __str__(self):
# type: () -> Any
line = self.node.start_mark.line
col = self.node.start_mark.column
return """
In YAML 1.1 floating point values should have a dot ('.') in their mantissa.
See the Floating-Point Language-Independent Type for YAML™ Version 1.1 specification
( http://yaml.org/type/float.html ). This dot is not required for JSON nor for YAML 1.2
Correct your float: "{}" on line: {}, column: {}
or alternatively include the following in your code:
import warnings
warnings.simplefilter('ignore', srsly.ruamel_yaml.error.MantissaNoDotYAML1_1Warning)
""".format(
self.flt, line, col
)
warnings.simplefilter("once", MantissaNoDotYAML1_1Warning)
class YAMLFutureWarning(Warning):
pass
class MarkedYAMLFutureWarning(YAMLFutureWarning):
def __init__(
self,
context=None,
context_mark=None,
problem=None,
problem_mark=None,
note=None,
warn=None,
):
# type: (Any, Any, Any, Any, Any, Any) -> None
self.context = context
self.context_mark = context_mark
self.problem = problem
self.problem_mark = problem_mark
self.note = note
self.warn = warn
def __str__(self):
# type: () -> Any
lines = [] # type: List[str]
if self.context is not None:
lines.append(self.context)
if self.context_mark is not None and (
self.problem is None
or self.problem_mark is None
or self.context_mark.name != self.problem_mark.name
or self.context_mark.line != self.problem_mark.line
or self.context_mark.column != self.problem_mark.column
):
lines.append(str(self.context_mark))
if self.problem is not None:
lines.append(self.problem)
if self.problem_mark is not None:
lines.append(str(self.problem_mark))
if self.note is not None and self.note:
note = textwrap.dedent(self.note)
lines.append(note)
if self.warn is not None and self.warn:
warn = textwrap.dedent(self.warn)
lines.append(warn)
return "\n".join(lines)
srsly-release-v2.5.1/srsly/ruamel_yaml/events.py 0000775 0000000 0000000 00000007476 14742310675 0022072 0 ustar 00root root 0000000 0000000 # coding: utf-8
# Abstract classes.
if False: # MYPY
from typing import Any, Dict, Optional, List # NOQA
def CommentCheck():
# type: () -> None
pass
class Event(object):
__slots__ = 'start_mark', 'end_mark', 'comment'
def __init__(self, start_mark=None, end_mark=None, comment=CommentCheck):
# type: (Any, Any, Any) -> None
self.start_mark = start_mark
self.end_mark = end_mark
# assert comment is not CommentCheck
if comment is CommentCheck:
comment = None
self.comment = comment
def __repr__(self):
# type: () -> Any
attributes = [
key
for key in ['anchor', 'tag', 'implicit', 'value', 'flow_style', 'style']
if hasattr(self, key)
]
arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) for key in attributes])
if self.comment not in [None, CommentCheck]:
arguments += ', comment={!r}'.format(self.comment)
return '%s(%s)' % (self.__class__.__name__, arguments)
class NodeEvent(Event):
__slots__ = ('anchor',)
def __init__(self, anchor, start_mark=None, end_mark=None, comment=None):
# type: (Any, Any, Any, Any) -> None
Event.__init__(self, start_mark, end_mark, comment)
self.anchor = anchor
class CollectionStartEvent(NodeEvent):
__slots__ = 'tag', 'implicit', 'flow_style', 'nr_items'
def __init__(
self,
anchor,
tag,
implicit,
start_mark=None,
end_mark=None,
flow_style=None,
comment=None,
nr_items=None,
):
# type: (Any, Any, Any, Any, Any, Any, Any, Optional[int]) -> None
NodeEvent.__init__(self, anchor, start_mark, end_mark, comment)
self.tag = tag
self.implicit = implicit
self.flow_style = flow_style
self.nr_items = nr_items
class CollectionEndEvent(Event):
__slots__ = ()
# Implementations.
class StreamStartEvent(Event):
__slots__ = ('encoding',)
def __init__(self, start_mark=None, end_mark=None, encoding=None, comment=None):
# type: (Any, Any, Any, Any) -> None
Event.__init__(self, start_mark, end_mark, comment)
self.encoding = encoding
class StreamEndEvent(Event):
__slots__ = ()
class DocumentStartEvent(Event):
__slots__ = 'explicit', 'version', 'tags'
def __init__(
self,
start_mark=None,
end_mark=None,
explicit=None,
version=None,
tags=None,
comment=None,
):
# type: (Any, Any, Any, Any, Any, Any) -> None
Event.__init__(self, start_mark, end_mark, comment)
self.explicit = explicit
self.version = version
self.tags = tags
class DocumentEndEvent(Event):
__slots__ = ('explicit',)
def __init__(self, start_mark=None, end_mark=None, explicit=None, comment=None):
# type: (Any, Any, Any, Any) -> None
Event.__init__(self, start_mark, end_mark, comment)
self.explicit = explicit
class AliasEvent(NodeEvent):
__slots__ = ()
class ScalarEvent(NodeEvent):
__slots__ = 'tag', 'implicit', 'value', 'style'
def __init__(
self,
anchor,
tag,
implicit,
value,
start_mark=None,
end_mark=None,
style=None,
comment=None,
):
# type: (Any, Any, Any, Any, Any, Any, Any, Any) -> None
NodeEvent.__init__(self, anchor, start_mark, end_mark, comment)
self.tag = tag
self.implicit = implicit
self.value = value
self.style = style
class SequenceStartEvent(CollectionStartEvent):
__slots__ = ()
class SequenceEndEvent(CollectionEndEvent):
__slots__ = ()
class MappingStartEvent(CollectionStartEvent):
__slots__ = ()
class MappingEndEvent(CollectionEndEvent):
__slots__ = ()
srsly-release-v2.5.1/srsly/ruamel_yaml/loader.py 0000775 0000000 0000000 00000005045 14742310675 0022022 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
from .reader import Reader
from .scanner import Scanner, RoundTripScanner
from .parser import Parser, RoundTripParser
from .composer import Composer
from .constructor import (
BaseConstructor,
SafeConstructor,
Constructor,
RoundTripConstructor,
)
from .resolver import VersionedResolver
if False: # MYPY
from typing import Any, Dict, List, Union, Optional # NOQA
from .compat import StreamTextType, VersionType # NOQA
__all__ = ["BaseLoader", "SafeLoader", "Loader", "RoundTripLoader"]
class BaseLoader(Reader, Scanner, Parser, Composer, BaseConstructor, VersionedResolver):
def __init__(self, stream, version=None, preserve_quotes=None):
# type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
Reader.__init__(self, stream, loader=self)
Scanner.__init__(self, loader=self)
Parser.__init__(self, loader=self)
Composer.__init__(self, loader=self)
BaseConstructor.__init__(self, loader=self)
VersionedResolver.__init__(self, version, loader=self)
class SafeLoader(Reader, Scanner, Parser, Composer, SafeConstructor, VersionedResolver):
def __init__(self, stream, version=None, preserve_quotes=None):
# type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
Reader.__init__(self, stream, loader=self)
Scanner.__init__(self, loader=self)
Parser.__init__(self, loader=self)
Composer.__init__(self, loader=self)
SafeConstructor.__init__(self, loader=self)
VersionedResolver.__init__(self, version, loader=self)
class Loader(Reader, Scanner, Parser, Composer, Constructor, VersionedResolver):
def __init__(self, stream, version=None, preserve_quotes=None):
raise ValueError("Unsafe loader not implemented in this library.")
class RoundTripLoader(
Reader,
RoundTripScanner,
RoundTripParser,
Composer,
RoundTripConstructor,
VersionedResolver,
):
def __init__(self, stream, version=None, preserve_quotes=None):
# type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None
# self.reader = Reader.__init__(self, stream)
Reader.__init__(self, stream, loader=self)
RoundTripScanner.__init__(self, loader=self)
RoundTripParser.__init__(self, loader=self)
Composer.__init__(self, loader=self)
RoundTripConstructor.__init__(
self, preserve_quotes=preserve_quotes, loader=self
)
VersionedResolver.__init__(self, version, loader=self)
srsly-release-v2.5.1/srsly/ruamel_yaml/main.py 0000775 0000000 0000000 00000151272 14742310675 0021504 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import, unicode_literals, print_function
import sys
import os
import warnings
import glob
from importlib import import_module
from . import resolver
from . import emitter
from . import representer
from . import parser
from . import composer
from . import constructor
from . import serializer
from . import scanner
from . import loader
from . import dumper
from . import reader
from .error import UnsafeLoaderWarning, YAMLError # NOQA
from .tokens import * # NOQA
from .events import * # NOQA
from .nodes import * # NOQA
from .loader import BaseLoader, SafeLoader, Loader, RoundTripLoader # NOQA
from .dumper import BaseDumper, SafeDumper, Dumper, RoundTripDumper # NOQA
from .compat import StringIO, BytesIO, with_metaclass, PY3, nprint
from .resolver import VersionedResolver, Resolver # NOQA
from .representer import (
BaseRepresenter,
SafeRepresenter,
Representer,
RoundTripRepresenter,
)
from .constructor import (
BaseConstructor,
SafeConstructor,
Constructor,
RoundTripConstructor,
)
from .loader import Loader as UnsafeLoader
if False: # MYPY
from typing import List, Set, Dict, Union, Any, Callable, Optional, Text # NOQA
from .compat import StreamType, StreamTextType, VersionType # NOQA
if PY3:
from pathlib import Path
else:
Path = Any
try:
from _ruamel_yaml import CParser, CEmitter # type: ignore
except: # NOQA
CParser = CEmitter = None
# import io
enforce = object()
# YAML is an acronym (spoken, it rhymes with "camel") and thus an
# abbreviation, which should be all caps according to PEP8
class YAML(object):
def __init__(
self,
_kw=enforce,
typ=None,
pure=False,
output=None,
plug_ins=None, # input=None,
):
# type: (Any, Optional[Text], Any, Any, Any) -> None
"""
_kw: not used, forces keyword arguments in 2.7 (in 3 you can do (*, safe_load=..)
typ: 'rt'/None -> RoundTripLoader/RoundTripDumper, (default)
'safe' -> SafeLoader/SafeDumper,
'unsafe' -> normal/unsafe Loader/Dumper
'base' -> baseloader
pure: if True only use Python modules
input/output: needed to work as context manager
plug_ins: a list of plug-in files
"""
if _kw is not enforce:
raise TypeError(
"{}.__init__() takes no positional argument but at least "
"one was given ({!r})".format(self.__class__.__name__, _kw)
)
self.typ = ["rt"] if typ is None else (typ if isinstance(typ, list) else [typ])
self.pure = pure
# self._input = input
self._output = output
self._context_manager = None # type: Any
self.plug_ins = [] # type: List[Any]
for pu in ([] if plug_ins is None else plug_ins) + self.official_plug_ins():
file_name = pu.replace(os.sep, ".")
self.plug_ins.append(import_module(file_name))
self.Resolver = resolver.VersionedResolver # type: Any
self.allow_unicode = True
self.Reader = None # type: Any
self.Representer = None # type: Any
self.Constructor = None # type: Any
self.Scanner = None # type: Any
self.Serializer = None # type: Any
self.default_flow_style = None # type: Any
typ_found = 1
setup_rt = False
if "rt" in self.typ:
setup_rt = True
elif "safe" in self.typ:
self.Emitter = emitter.Emitter if pure or CEmitter is None else CEmitter
self.Representer = representer.SafeRepresenter
self.Parser = parser.Parser if pure or CParser is None else CParser
self.Composer = composer.Composer
self.Constructor = constructor.SafeConstructor
elif "base" in self.typ:
self.Emitter = emitter.Emitter
self.Representer = representer.BaseRepresenter
self.Parser = parser.Parser if pure or CParser is None else CParser
self.Composer = composer.Composer
self.Constructor = constructor.BaseConstructor
elif "unsafe" in self.typ:
self.Emitter = emitter.Emitter if pure or CEmitter is None else CEmitter
self.Representer = representer.Representer
self.Parser = parser.Parser if pure or CParser is None else CParser
self.Composer = composer.Composer
self.Constructor = constructor.Constructor
else:
setup_rt = True
typ_found = 0
if setup_rt:
self.default_flow_style = False
# no optimized rt-dumper yet
self.Emitter = emitter.Emitter
self.Serializer = serializer.Serializer
self.Representer = representer.RoundTripRepresenter
self.Scanner = scanner.RoundTripScanner
# no optimized rt-parser yet
self.Parser = parser.RoundTripParser
self.Composer = composer.Composer
self.Constructor = constructor.RoundTripConstructor
del setup_rt
self.stream = None
self.canonical = None
self.old_indent = None
self.width = None
self.line_break = None
self.map_indent = None
self.sequence_indent = None
self.sequence_dash_offset = 0
self.compact_seq_seq = None
self.compact_seq_map = None
self.sort_base_mapping_type_on_output = None # default: sort
self.top_level_colon_align = None
self.prefix_colon = None
self.version = None
self.preserve_quotes = None
self.allow_duplicate_keys = False # duplicate keys in map, set
self.encoding = "utf-8"
self.explicit_start = None
self.explicit_end = None
self.tags = None
self.default_style = None
self.top_level_block_style_scalar_no_indent_error_1_1 = False
# directives end indicator with single scalar document
self.scalar_after_indicator = None
# [a, b: 1, c: {d: 2}] vs. [a, {b: 1}, {c: {d: 2}}]
self.brace_single_entry_mapping_in_flow_sequence = False
for module in self.plug_ins:
if getattr(module, "typ", None) in self.typ:
typ_found += 1
module.init_typ(self)
break
if typ_found == 0:
raise NotImplementedError(
'typ "{}"not recognised (need to install plug-in?)'.format(self.typ)
)
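# Illustrative usage sketch (not executed; the file name and data are placeholders):
#
#   yaml = YAML()                              # round-trip loader/dumper (typ='rt')
#   yaml_safe = YAML(typ="safe", pure=True)    # pure-Python safe loader/dumper
#   with open("config.yaml") as fp:
#       data = yaml.load(fp)
#   yaml.dump(data, sys.stdout)                # dump() requires an explicit stream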
@property
def reader(self):
# type: () -> Any
try:
return self._reader # type: ignore
except AttributeError:
self._reader = self.Reader(None, loader=self)
return self._reader
@property
def scanner(self):
# type: () -> Any
try:
return self._scanner # type: ignore
except AttributeError:
self._scanner = self.Scanner(loader=self)
return self._scanner
@property
def parser(self):
# type: () -> Any
attr = "_" + sys._getframe().f_code.co_name
if not hasattr(self, attr):
if self.Parser is not CParser:
setattr(self, attr, self.Parser(loader=self))
else:
if getattr(self, "_stream", None) is None:
# wait for the stream
return None
else:
# if not hasattr(self._stream, 'read') and hasattr(self._stream, 'open'):
# # pathlib.Path() instance
# setattr(self, attr, CParser(self._stream))
# else:
setattr(self, attr, CParser(self._stream))
# self._parser = self._composer = self
# nprint('scanner', self.loader.scanner)
return getattr(self, attr)
@property
def composer(self):
# type: () -> Any
attr = "_" + sys._getframe().f_code.co_name
if not hasattr(self, attr):
setattr(self, attr, self.Composer(loader=self))
return getattr(self, attr)
@property
def constructor(self):
# type: () -> Any
attr = "_" + sys._getframe().f_code.co_name
if not hasattr(self, attr):
cnst = self.Constructor(preserve_quotes=self.preserve_quotes, loader=self)
cnst.allow_duplicate_keys = self.allow_duplicate_keys
setattr(self, attr, cnst)
return getattr(self, attr)
@property
def resolver(self):
# type: () -> Any
attr = "_" + sys._getframe().f_code.co_name
if not hasattr(self, attr):
setattr(self, attr, self.Resolver(version=self.version, loader=self))
return getattr(self, attr)
@property
def emitter(self):
# type: () -> Any
attr = "_" + sys._getframe().f_code.co_name
if not hasattr(self, attr):
if self.Emitter is not CEmitter:
_emitter = self.Emitter(
None,
canonical=self.canonical,
indent=self.old_indent,
width=self.width,
allow_unicode=self.allow_unicode,
line_break=self.line_break,
prefix_colon=self.prefix_colon,
brace_single_entry_mapping_in_flow_sequence=self.brace_single_entry_mapping_in_flow_sequence, # NOQA
dumper=self,
)
setattr(self, attr, _emitter)
if self.map_indent is not None:
_emitter.best_map_indent = self.map_indent
if self.sequence_indent is not None:
_emitter.best_sequence_indent = self.sequence_indent
if self.sequence_dash_offset is not None:
_emitter.sequence_dash_offset = self.sequence_dash_offset
# _emitter.block_seq_indent = self.sequence_dash_offset
if self.compact_seq_seq is not None:
_emitter.compact_seq_seq = self.compact_seq_seq
if self.compact_seq_map is not None:
_emitter.compact_seq_map = self.compact_seq_map
else:
if getattr(self, "_stream", None) is None:
# wait for the stream
return None
return None
return getattr(self, attr)
@property
def serializer(self):
# type: () -> Any
attr = "_" + sys._getframe().f_code.co_name
if not hasattr(self, attr):
setattr(
self,
attr,
self.Serializer(
encoding=self.encoding,
explicit_start=self.explicit_start,
explicit_end=self.explicit_end,
version=self.version,
tags=self.tags,
dumper=self,
),
)
return getattr(self, attr)
@property
def representer(self):
# type: () -> Any
attr = "_" + sys._getframe().f_code.co_name
if not hasattr(self, attr):
repres = self.Representer(
default_style=self.default_style,
default_flow_style=self.default_flow_style,
dumper=self,
)
if self.sort_base_mapping_type_on_output is not None:
repres.sort_base_mapping_type_on_output = (
self.sort_base_mapping_type_on_output
)
setattr(self, attr, repres)
return getattr(self, attr)
# separate output resolver?
# def load(self, stream=None):
# if self._context_manager:
# if not self._input:
# raise TypeError("Missing input stream while dumping from context manager")
# for data in self._context_manager.load():
# yield data
# return
# if stream is None:
# raise TypeError("Need a stream argument when not loading from context manager")
# return self.load_one(stream)
def load(self, stream):
# type: (Union[Path, StreamTextType]) -> Any
"""
at this point you either have the non-pure Parser (which has its own reader and
scanner) or you have the pure Parser.
If the pure Parser is set, then set the Reader and Scanner, if not already set.
If either the Scanner or Reader are set, you cannot use the non-pure Parser,
so reset it to the pure parser and set the Reader resp. Scanner if necessary
"""
if not hasattr(stream, "read") and hasattr(stream, "open"):
# pathlib.Path() instance
with stream.open("rb") as fp:
return self.load(fp)
constructor, parser = self.get_constructor_parser(stream)
try:
return constructor.get_single_data()
finally:
parser.dispose()
try:
self._reader.reset_reader()
except AttributeError:
pass
try:
self._scanner.reset_scanner()
except AttributeError:
pass
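# Commented example of the pathlib handling above (the path is a placeholder):
# an object with an ``open`` attribute but no ``read`` is treated as a
# pathlib.Path, opened in binary mode and passed back through load().
#
#   from pathlib import Path
#   data = YAML(typ="safe").load(Path("example.yaml"))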
def load_all(self, stream, _kw=enforce): # , skip=None):
# type: (Union[Path, StreamTextType], Any) -> Any
if _kw is not enforce:
raise TypeError(
"{}.__init__() takes no positional argument but at least "
"one was given ({!r})".format(self.__class__.__name__, _kw)
)
if not hasattr(stream, "read") and hasattr(stream, "open"):
# pathlib.Path() instance
with stream.open("r") as fp:
for d in self.load_all(fp, _kw=enforce):
yield d
return
# if skip is None:
# skip = []
# elif isinstance(skip, int):
# skip = [skip]
constructor, parser = self.get_constructor_parser(stream)
try:
while constructor.check_data():
yield constructor.get_data()
finally:
parser.dispose()
try:
self._reader.reset_reader()
except AttributeError:
pass
try:
self._scanner.reset_scanner()
except AttributeError:
pass
def get_constructor_parser(self, stream):
# type: (StreamTextType) -> Any
"""
the old cyaml needs special setup, and therefore needs the stream
"""
if self.Parser is not CParser:
if self.Reader is None:
self.Reader = reader.Reader
if self.Scanner is None:
self.Scanner = scanner.Scanner
self.reader.stream = stream
else:
if self.Reader is not None:
if self.Scanner is None:
self.Scanner = scanner.Scanner
self.Parser = parser.Parser
self.reader.stream = stream
elif self.Scanner is not None:
if self.Reader is None:
self.Reader = reader.Reader
self.Parser = parser.Parser
self.reader.stream = stream
else:
# combined C level reader>scanner>parser
# makes some calls to the resolver, e.g. BaseResolver.descend_resolver
# if you just initialise the CParser, too much of resolver.py
# is actually used
rslvr = self.Resolver
# if rslvr is srsly.ruamel_yaml.resolver.VersionedResolver:
# rslvr = srsly.ruamel_yaml.resolver.Resolver
class XLoader(self.Parser, self.Constructor, rslvr): # type: ignore
def __init__(
selfx, stream, version=self.version, preserve_quotes=None
):
# type: (StreamTextType, Optional[VersionType], Optional[bool]) -> None # NOQA
CParser.__init__(selfx, stream)
selfx._parser = selfx._composer = selfx
self.Constructor.__init__(selfx, loader=selfx)
selfx.allow_duplicate_keys = self.allow_duplicate_keys
rslvr.__init__(selfx, version=version, loadumper=selfx)
self._stream = stream
loader = XLoader(stream)
return loader, loader
return self.constructor, self.parser
def dump(self, data, stream=None, _kw=enforce, transform=None):
# type: (Any, Union[Path, StreamType], Any, Any) -> Any
if self._context_manager:
if not self._output:
raise TypeError(
"Missing output stream while dumping from context manager"
)
if _kw is not enforce:
raise TypeError(
"{}.dump() takes one positional argument but at least "
"two were given ({!r})".format(self.__class__.__name__, _kw)
)
if transform is not None:
raise TypeError(
"{}.dump() in the context manager cannot have transform keyword "
"".format(self.__class__.__name__)
)
self._context_manager.dump(data)
else: # old style
if stream is None:
raise TypeError(
"Need a stream argument when not dumping from context manager"
)
return self.dump_all([data], stream, _kw, transform=transform)
def dump_all(self, documents, stream, _kw=enforce, transform=None):
# type: (Any, Union[Path, StreamType], Any, Any) -> Any
if self._context_manager:
raise NotImplementedError
if _kw is not enforce:
raise TypeError(
"{}.dump(_all) takes two positional argument but at least "
"three were given ({!r})".format(self.__class__.__name__, _kw)
)
self._output = stream
self._context_manager = YAMLContextManager(self, transform=transform)
for data in documents:
self._context_manager.dump(data)
self._context_manager.teardown_output()
self._output = None
self._context_manager = None
def Xdump_all(self, documents, stream, _kw=enforce, transform=None):
# type: (Any, Union[Path, StreamType], Any, Any) -> Any
"""
Serialize a sequence of Python objects into a YAML stream.
"""
if not hasattr(stream, "write") and hasattr(stream, "open"):
# pathlib.Path() instance
with stream.open("w") as fp:
return self.dump_all(documents, fp, _kw, transform=transform)
if _kw is not enforce:
raise TypeError(
"{}.dump(_all) takes two positional argument but at least "
"three were given ({!r})".format(self.__class__.__name__, _kw)
)
# The stream should have the methods `write` and possibly `flush`.
if self.top_level_colon_align is True:
tlca = max([len(str(x)) for x in documents[0]]) # type: Any
else:
tlca = self.top_level_colon_align
if transform is not None:
fstream = stream
if self.encoding is None:
stream = StringIO()
else:
stream = BytesIO()
serializer, representer, emitter = self.get_serializer_representer_emitter(
stream, tlca
)
try:
self.serializer.open()
for data in documents:
try:
self.representer.represent(data)
except AttributeError:
# nprint(dir(dumper._representer))
raise
self.serializer.close()
finally:
try:
self.emitter.dispose()
except AttributeError:
raise
# self.dumper.dispose() # cyaml
delattr(self, "_serializer")
delattr(self, "_emitter")
if transform:
val = stream.getvalue()
if self.encoding:
val = val.decode(self.encoding)
if fstream is None:
transform(val)
else:
fstream.write(transform(val))
return None
def get_serializer_representer_emitter(self, stream, tlca):
# type: (StreamType, Any) -> Any
# we have only .Serializer to deal with (vs .Reader & .Scanner), much simpler
if self.Emitter is not CEmitter:
if self.Serializer is None:
self.Serializer = serializer.Serializer
self.emitter.stream = stream
self.emitter.top_level_colon_align = tlca
if self.scalar_after_indicator is not None:
self.emitter.scalar_after_indicator = self.scalar_after_indicator
return self.serializer, self.representer, self.emitter
if self.Serializer is not None:
# cannot set serializer with CEmitter
self.Emitter = emitter.Emitter
self.emitter.stream = stream
self.emitter.top_level_colon_align = tlca
if self.scalar_after_indicator is not None:
self.emitter.scalar_after_indicator = self.scalar_after_indicator
return self.serializer, self.representer, self.emitter
# C routines
rslvr = resolver.BaseResolver if "base" in self.typ else resolver.Resolver
class XDumper(CEmitter, self.Representer, rslvr): # type: ignore
def __init__(
selfx,
stream,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=None,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (StreamType, Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> None # NOQA
CEmitter.__init__(
selfx,
stream,
canonical=canonical,
indent=indent,
width=width,
encoding=encoding,
allow_unicode=allow_unicode,
line_break=line_break,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
)
selfx._emitter = selfx._serializer = selfx._representer = selfx
self.Representer.__init__(
selfx,
default_style=default_style,
default_flow_style=default_flow_style,
)
rslvr.__init__(selfx)
self._stream = stream
dumper = XDumper(
stream,
default_style=self.default_style,
default_flow_style=self.default_flow_style,
canonical=self.canonical,
indent=self.old_indent,
width=self.width,
allow_unicode=self.allow_unicode,
line_break=self.line_break,
explicit_start=self.explicit_start,
explicit_end=self.explicit_end,
version=self.version,
tags=self.tags,
)
self._emitter = self._serializer = dumper
return dumper, dumper, dumper
# basic types
def map(self, **kw):
# type: (Any) -> Any
if "rt" in self.typ:
from .comments import CommentedMap
return CommentedMap(**kw)
else:
return dict(**kw)
def seq(self, *args):
# type: (Any) -> Any
if "rt" in self.typ:
from .comments import CommentedSeq
return CommentedSeq(*args)
else:
return list(*args)
# helpers
def official_plug_ins(self):
# type: () -> Any
bd = os.path.dirname(__file__)
gpbd = os.path.dirname(os.path.dirname(bd))
res = [x.replace(gpbd, "")[1:-3] for x in glob.glob(bd + "/*/__plug_in__.py")]
return res
def register_class(self, cls):
# type:(Any) -> Any
"""
register a class for dumping loading
- if it has attribute yaml_tag use that to register, else use class name
- if it has methods to_yaml/from_yaml use those to dump/load else dump attributes
as mapping
"""
tag = getattr(cls, "yaml_tag", "!" + cls.__name__)
try:
self.representer.add_representer(cls, cls.to_yaml)
except AttributeError:
def t_y(representer, data):
# type: (Any, Any) -> Any
return representer.represent_yaml_object(
tag, data, cls, flow_style=representer.default_flow_style
)
self.representer.add_representer(cls, t_y)
try:
self.constructor.add_constructor(tag, cls.from_yaml)
except AttributeError:
def f_y(constructor, node):
# type: (Any, Any) -> Any
return constructor.construct_yaml_object(node, cls)
self.constructor.add_constructor(tag, f_y)
return cls
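# Minimal sketch of register_class(); the Monster class, its yaml_tag and
# attributes are invented for illustration. Without to_yaml/from_yaml the
# fallback representer/constructor dump and load the instance attributes
# as a mapping.
#
#   yaml = YAML()
#   @yaml.register_class
#   class Monster(object):
#       yaml_tag = u"!Monster"
#       def __init__(self, name, hp):
#           self.name = name
#           self.hp = hp
#   yaml.dump([Monster(name="cave troll", hp=18)], sys.stdout)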
def parse(self, stream):
# type: (StreamTextType) -> Any
"""
Parse a YAML stream and produce parsing events.
"""
_, parser = self.get_constructor_parser(stream)
try:
while parser.check_event():
yield parser.get_event()
finally:
parser.dispose()
try:
self._reader.reset_reader()
except AttributeError:
pass
try:
self._scanner.reset_scanner()
except AttributeError:
pass
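# Event-level sketch (assumes the pure loader; the document string is a
# placeholder): parse() yields events such as StreamStartEvent,
# MappingStartEvent and ScalarEvent as defined in events.py.
#
#   for event in YAML(typ="safe", pure=True).parse(StringIO(u"a: 1")):
#       print(type(event).__name__)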
# ### context manager
def __enter__(self):
# type: () -> Any
self._context_manager = YAMLContextManager(self)
return self
def __exit__(self, typ, value, traceback):
# type: (Any, Any, Any) -> None
if typ:
nprint("typ", typ)
self._context_manager.teardown_output()
# self._context_manager.teardown_input()
self._context_manager = None
# ### backwards compatibility
def _indent(self, mapping=None, sequence=None, offset=None):
# type: (Any, Any, Any) -> None
if mapping is not None:
self.map_indent = mapping
if sequence is not None:
self.sequence_indent = sequence
if offset is not None:
self.sequence_dash_offset = offset
@property
def indent(self):
# type: () -> Any
return self._indent
@indent.setter
def indent(self, val):
# type: (Any) -> None
self.old_indent = val
@property
def block_seq_indent(self):
# type: () -> Any
return self.sequence_dash_offset
@block_seq_indent.setter
def block_seq_indent(self, val):
# type: (Any) -> None
self.sequence_dash_offset = val
def compact(self, seq_seq=None, seq_map=None):
# type: (Any, Any) -> None
self.compact_seq_seq = seq_seq
self.compact_seq_map = seq_map
class YAMLContextManager(object):
def __init__(self, yaml, transform=None):
# type: (Any, Any) -> None # used to be: (Any, Optional[Callable]) -> None
self._yaml = yaml
self._output_inited = False
self._output_path = None
self._output = self._yaml._output
self._transform = transform
# self._input_inited = False
# self._input = input
# self._input_path = None
# self._transform = yaml.transform
# self._fstream = None
if not hasattr(self._output, "write") and hasattr(self._output, "open"):
# pathlib.Path() instance, open with the same mode
self._output_path = self._output
self._output = self._output_path.open("w")
# if not hasattr(self._stream, 'write') and hasattr(stream, 'open'):
# if not hasattr(self._input, 'read') and hasattr(self._input, 'open'):
# # pathlib.Path() instance, open with the same mode
# self._input_path = self._input
# self._input = self._input_path.open('r')
if self._transform is not None:
self._fstream = self._output
if self._yaml.encoding is None:
self._output = StringIO()
else:
self._output = BytesIO()
def teardown_output(self):
# type: () -> None
if self._output_inited:
self._yaml.serializer.close()
else:
return
try:
self._yaml.emitter.dispose()
except AttributeError:
raise
# self.dumper.dispose() # cyaml
try:
delattr(self._yaml, "_serializer")
delattr(self._yaml, "_emitter")
except AttributeError:
raise
if self._transform:
val = self._output.getvalue()
if self._yaml.encoding:
val = val.decode(self._yaml.encoding)
if self._fstream is None:
self._transform(val)
else:
self._fstream.write(self._transform(val))
self._fstream.flush()
self._output = self._fstream # maybe not necessary
if self._output_path is not None:
self._output.close()
def init_output(self, first_data):
# type: (Any) -> None
if self._yaml.top_level_colon_align is True:
tlca = max([len(str(x)) for x in first_data]) # type: Any
else:
tlca = self._yaml.top_level_colon_align
self._yaml.get_serializer_representer_emitter(self._output, tlca)
self._yaml.serializer.open()
self._output_inited = True
def dump(self, data):
# type: (Any) -> None
if not self._output_inited:
self.init_output(data)
try:
self._yaml.representer.represent(data)
except AttributeError:
# nprint(dir(dumper._representer))
raise
# def teardown_input(self):
# pass
#
# def init_input(self):
# # set the constructor and parser on YAML() instance
# self._yaml.get_constructor_parser(stream)
#
# def load(self):
# if not self._input_inited:
# self.init_input()
# try:
# while self._yaml.constructor.check_data():
# yield self._yaml.constructor.get_data()
# finally:
# parser.dispose()
# try:
# self._reader.reset_reader() # type: ignore
# except AttributeError:
# pass
# try:
# self._scanner.reset_scanner() # type: ignore
# except AttributeError:
# pass
def yaml_object(yml):
# type: (Any) -> Any
""" decorator for classes that needs to dump/load objects
The tag for such objects is taken from the class attribute yaml_tag (or the
class name in lowercase in case unavailable)
If methods to_yaml and/or from_yaml are available, these are called for dumping resp.
loading, default routines (dumping a mapping of the attributes) used otherwise.
"""
def yo_deco(cls):
# type: (Any) -> Any
tag = getattr(cls, "yaml_tag", "!" + cls.__name__)
try:
yml.representer.add_representer(cls, cls.to_yaml)
except AttributeError:
def t_y(representer, data):
# type: (Any, Any) -> Any
return representer.represent_yaml_object(
tag, data, cls, flow_style=representer.default_flow_style
)
yml.representer.add_representer(cls, t_y)
try:
yml.constructor.add_constructor(tag, cls.from_yaml)
except AttributeError:
def f_y(constructor, node):
# type: (Any, Any) -> Any
return constructor.construct_yaml_object(node, cls)
yml.constructor.add_constructor(tag, f_y)
return cls
return yo_deco
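# Commented sketch of the decorator described above; the Point class and its
# fields are hypothetical. It mirrors YAML.register_class() but takes the YAML
# instance as an argument.
#
#   yaml = YAML()
#   @yaml_object(yaml)
#   class Point(object):
#       yaml_tag = u"!Point"
#       def __init__(self, x, y):
#           self.x = x
#           self.y = y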
########################################################################################
def scan(stream, Loader=Loader):
# type: (StreamTextType, Any) -> Any
"""
Scan a YAML stream and produce scanning tokens.
"""
loader = Loader(stream)
try:
while loader.scanner.check_token():
yield loader.scanner.get_token()
finally:
loader._parser.dispose()
def parse(stream, Loader=Loader):
# type: (StreamTextType, Any) -> Any
"""
Parse a YAML stream and produce parsing events.
"""
loader = Loader(stream)
try:
while loader._parser.check_event():
yield loader._parser.get_event()
finally:
loader._parser.dispose()
def compose(stream, Loader=Loader):
# type: (StreamTextType, Any) -> Any
"""
Parse the first YAML document in a stream
and produce the corresponding representation tree.
"""
loader = Loader(stream)
try:
return loader.get_single_node()
finally:
loader.dispose()
def compose_all(stream, Loader=Loader):
# type: (StreamTextType, Any) -> Any
"""
Parse all YAML documents in a stream
and produce corresponding representation trees.
"""
loader = Loader(stream)
try:
while loader.check_node():
yield loader._composer.get_node()
finally:
loader._parser.dispose()
def load(stream, Loader=None, version=None, preserve_quotes=None):
# type: (StreamTextType, Any, Optional[VersionType], Any) -> Any
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
"""
if Loader is None:
warnings.warn(UnsafeLoaderWarning.text, UnsafeLoaderWarning, stacklevel=2)
Loader = UnsafeLoader
loader = Loader(stream, version, preserve_quotes=preserve_quotes)
try:
return loader._constructor.get_single_data()
finally:
loader._parser.dispose()
try:
loader._reader.reset_reader()
except AttributeError:
pass
try:
loader._scanner.reset_scanner()
except AttributeError:
pass
def load_all(stream, Loader=None, version=None, preserve_quotes=None):
# type: (Optional[StreamTextType], Any, Optional[VersionType], Optional[bool]) -> Any # NOQA
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
"""
if Loader is None:
warnings.warn(UnsafeLoaderWarning.text, UnsafeLoaderWarning, stacklevel=2)
Loader = UnsafeLoader
loader = Loader(stream, version, preserve_quotes=preserve_quotes)
try:
while loader._constructor.check_data():
yield loader._constructor.get_data()
finally:
loader._parser.dispose()
try:
loader._reader.reset_reader()
except AttributeError:
pass
try:
loader._scanner.reset_scanner()
except AttributeError:
pass
def safe_load(stream, version=None):
# type: (StreamTextType, Optional[VersionType]) -> Any
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
Resolve only basic YAML tags.
"""
return load(stream, SafeLoader, version)
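# Hedged example (the document string is a placeholder): safe_load() resolves
# only basic YAML tags and returns plain Python containers.
#
#   safe_load(u"a: 1\nb: [2, 3]\n")   # -> {'a': 1, 'b': [2, 3]}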
def safe_load_all(stream, version=None):
# type: (StreamTextType, Optional[VersionType]) -> Any
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
Resolve only basic YAML tags.
"""
return load_all(stream, SafeLoader, version)
def round_trip_load(stream, version=None, preserve_quotes=None):
# type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any
"""
Parse the first YAML document in a stream
and produce the corresponding Python object.
Resolve only basic YAML tags.
"""
return load(stream, RoundTripLoader, version, preserve_quotes=preserve_quotes)
def round_trip_load_all(stream, version=None, preserve_quotes=None):
# type: (StreamTextType, Optional[VersionType], Optional[bool]) -> Any
"""
Parse all YAML documents in a stream
and produce corresponding Python objects.
Resolve only basic YAML tags.
"""
return load_all(stream, RoundTripLoader, version, preserve_quotes=preserve_quotes)
def emit(
events,
stream=None,
Dumper=Dumper,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
):
# type: (Any, Optional[StreamType], Any, Optional[bool], Union[int, None], Optional[int], Optional[bool], Any) -> Any # NOQA
"""
Emit YAML parsing events into a stream.
If stream is None, return the produced string instead.
"""
getvalue = None
if stream is None:
stream = StringIO()
getvalue = stream.getvalue
dumper = Dumper(
stream,
canonical=canonical,
indent=indent,
width=width,
allow_unicode=allow_unicode,
line_break=line_break,
)
try:
for event in events:
dumper.emit(event)
finally:
try:
dumper._emitter.dispose()
except AttributeError:
raise
dumper.dispose() # cyaml
if getvalue is not None:
return getvalue()
enc = None if PY3 else "utf-8"
def serialize_all(
nodes,
stream=None,
Dumper=Dumper,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=enc,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
):
# type: (Any, Optional[StreamType], Any, Any, Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any) -> Any # NOQA
"""
Serialize a sequence of representation trees into a YAML stream.
If stream is None, return the produced string instead.
"""
getvalue = None
if stream is None:
if encoding is None:
stream = StringIO()
else:
stream = BytesIO()
getvalue = stream.getvalue
dumper = Dumper(
stream,
canonical=canonical,
indent=indent,
width=width,
allow_unicode=allow_unicode,
line_break=line_break,
encoding=encoding,
version=version,
tags=tags,
explicit_start=explicit_start,
explicit_end=explicit_end,
)
try:
dumper._serializer.open()
for node in nodes:
dumper.serialize(node)
dumper._serializer.close()
finally:
try:
dumper._emitter.dispose()
except AttributeError:
raise
dumper.dispose() # cyaml
if getvalue is not None:
return getvalue()
def serialize(node, stream=None, Dumper=Dumper, **kwds):
# type: (Any, Optional[StreamType], Any, Any) -> Any
"""
Serialize a representation tree into a YAML stream.
If stream is None, return the produced string instead.
"""
return serialize_all([node], stream, Dumper=Dumper, **kwds)
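# Example (a minimal sketch, not executed here; it assumes the module-level
# compose() helper defined earlier in this file): serialize() works one level
# above emit(), on representation nodes instead of events.
#
#     node = compose(u"a: 1\n", Loader=SafeLoader)   # MappingNode of the document
#     text = serialize(node)                         # stream is None, so the text is returned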
def dump_all(
documents,
stream=None,
Dumper=Dumper,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=enc,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Any, Any, Any, Any, Any) -> Optional[str] # NOQA
"""
Serialize a sequence of Python objects into a YAML stream.
If stream is None, return the produced string instead.
"""
getvalue = None
if top_level_colon_align is True:
top_level_colon_align = max([len(str(x)) for x in documents[0]])
if stream is None:
if encoding is None:
stream = StringIO()
else:
stream = BytesIO()
getvalue = stream.getvalue
dumper = Dumper(
stream,
default_style=default_style,
default_flow_style=default_flow_style,
canonical=canonical,
indent=indent,
width=width,
allow_unicode=allow_unicode,
line_break=line_break,
encoding=encoding,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
block_seq_indent=block_seq_indent,
top_level_colon_align=top_level_colon_align,
prefix_colon=prefix_colon,
)
try:
dumper._serializer.open()
for data in documents:
try:
dumper._representer.represent(data)
except AttributeError:
# nprint(dir(dumper._representer))
raise
dumper._serializer.close()
finally:
try:
dumper._emitter.dispose()
except AttributeError:
raise
dumper.dispose() # cyaml
if getvalue is not None:
return getvalue()
return None
def dump(
data,
stream=None,
Dumper=Dumper,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=enc,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
):
# type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any) -> Optional[str] # NOQA
"""
Serialize a Python object into a YAML stream.
If stream is None, return the produced string instead.
default_style ∈ None, '', '"', "'", '|', '>'
"""
return dump_all(
[data],
stream,
Dumper=Dumper,
default_style=default_style,
default_flow_style=default_flow_style,
canonical=canonical,
indent=indent,
width=width,
allow_unicode=allow_unicode,
line_break=line_break,
encoding=encoding,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
block_seq_indent=block_seq_indent,
)
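# Example (a minimal sketch, not executed here; 'out.yaml' is a hypothetical path):
#
#     text = dump({"a": 1})              # stream is None, so the YAML text is returned
#     with open("out.yaml", "w") as fp:
#         dump({"a": 1}, fp)             # written to fp, returns None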
def safe_dump_all(documents, stream=None, **kwds):
# type: (Any, Optional[StreamType], Any) -> Optional[str]
"""
Serialize a sequence of Python objects into a YAML stream.
Produce only basic YAML tags.
If stream is None, return the produced string instead.
"""
return dump_all(documents, stream, Dumper=SafeDumper, **kwds)
def safe_dump(data, stream=None, **kwds):
# type: (Any, Optional[StreamType], Any) -> Optional[str]
"""
Serialize a Python object into a YAML stream.
Produce only basic YAML tags.
If stream is None, return the produced string instead.
"""
return dump_all([data], stream, Dumper=SafeDumper, **kwds)
def round_trip_dump(
data,
stream=None,
Dumper=RoundTripDumper,
default_style=None,
default_flow_style=None,
canonical=None,
indent=None,
width=None,
allow_unicode=None,
line_break=None,
encoding=enc,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
):
# type: (Any, Optional[StreamType], Any, Any, Any, Optional[bool], Optional[int], Optional[int], Optional[bool], Any, Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any, Any, Any) -> Optional[str] # NOQA
allow_unicode = True if allow_unicode is None else allow_unicode
return dump_all(
[data],
stream,
Dumper=Dumper,
default_style=default_style,
default_flow_style=default_flow_style,
canonical=canonical,
indent=indent,
width=width,
allow_unicode=allow_unicode,
line_break=line_break,
encoding=encoding,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
tags=tags,
block_seq_indent=block_seq_indent,
top_level_colon_align=top_level_colon_align,
prefix_colon=prefix_colon,
)
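# Example (a minimal sketch, not executed here): the extra keyword arguments
# shape the output, e.g.
#
#     round_trip_dump({"seq": [1, 2]}, explicit_start=True)   # document starts with '---'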
# Loader/Dumper are no longer composites; to get to the associated
# Resolver()/Representer(), etc., you need to instantiate the class
def add_implicit_resolver(
tag, regexp, first=None, Loader=None, Dumper=None, resolver=Resolver
):
# type: (Any, Any, Any, Any, Any, Any) -> None
"""
Add an implicit scalar detector.
If an implicit scalar value matches the given regexp,
the corresponding tag is assigned to the scalar.
first is a sequence of possible initial characters or None.
"""
if Loader is None and Dumper is None:
resolver.add_implicit_resolver(tag, regexp, first)
return
if Loader:
if hasattr(Loader, "add_implicit_resolver"):
Loader.add_implicit_resolver(tag, regexp, first)
elif issubclass(
Loader, (BaseLoader, SafeLoader, loader.Loader, RoundTripLoader)
):
Resolver.add_implicit_resolver(tag, regexp, first)
else:
raise NotImplementedError
if Dumper:
if hasattr(Dumper, "add_implicit_resolver"):
Dumper.add_implicit_resolver(tag, regexp, first)
elif issubclass(
Dumper, (BaseDumper, SafeDumper, dumper.Dumper, RoundTripDumper)
):
Resolver.add_implicit_resolver(tag, regexp, first)
else:
raise NotImplementedError
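# Example (a minimal sketch, not executed here; the '!temperature' tag and the
# pattern are made up for illustration): plain scalars matching the regexp get
# the given tag, which still needs a constructor/representer to be useful.
#
#     import re
#     add_implicit_resolver(u"!temperature", re.compile(r"^-?\d+C$"),
#                           list(u"-0123456789"))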
# this code currently not tested
def add_path_resolver(
tag, path, kind=None, Loader=None, Dumper=None, resolver=Resolver
):
# type: (Any, Any, Any, Any, Any, Any) -> None
"""
Add a path based resolver for the given tag.
A path is a list of keys that forms a path
to a node in the representation tree.
Keys can be string values, integers, or None.
"""
if Loader is None and Dumper is None:
resolver.add_path_resolver(tag, path, kind)
return
if Loader:
if hasattr(Loader, "add_path_resolver"):
Loader.add_path_resolver(tag, path, kind)
elif issubclass(
Loader, (BaseLoader, SafeLoader, loader.Loader, RoundTripLoader)
):
Resolver.add_path_resolver(tag, path, kind)
else:
raise NotImplementedError
if Dumper:
if hasattr(Dumper, "add_path_resolver"):
Dumper.add_path_resolver(tag, path, kind)
elif issubclass(
Dumper, (BaseDumper, SafeDumper, dumper.Dumper, RoundTripDumper)
):
Resolver.add_path_resolver(tag, path, kind)
else:
raise NotImplementedError
def add_constructor(tag, object_constructor, Loader=None, constructor=Constructor):
# type: (Any, Any, Any, Any) -> None
"""
Add an object constructor for the given tag.
object_constructor is a function that accepts a Loader instance
and a node object and produces the corresponding Python object.
"""
if Loader is None:
constructor.add_constructor(tag, object_constructor)
else:
if hasattr(Loader, "add_constructor"):
Loader.add_constructor(tag, object_constructor)
return
if issubclass(Loader, BaseLoader):
BaseConstructor.add_constructor(tag, object_constructor)
elif issubclass(Loader, SafeLoader):
SafeConstructor.add_constructor(tag, object_constructor)
elif issubclass(Loader, loader.Loader):
Constructor.add_constructor(tag, object_constructor)
elif issubclass(Loader, RoundTripLoader):
RoundTripConstructor.add_constructor(tag, object_constructor)
else:
raise NotImplementedError
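# Example (a minimal sketch, not executed here; the '!point' tag and helper are
# made up for illustration):
#
#     def construct_point(constructor, node):
#         x, y = constructor.construct_sequence(node)
#         return (x, y)
#
#     add_constructor(u"!point", construct_point, Loader=SafeLoader)
#     safe_load(u"!point [1, 2]\n")      # -> (1, 2)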
def add_multi_constructor(
tag_prefix, multi_constructor, Loader=None, constructor=Constructor
):
# type: (Any, Any, Any, Any) -> None
"""
Add a multi-constructor for the given tag prefix.
Multi-constructor is called for a node if its tag starts with tag_prefix.
Multi-constructor accepts a Loader instance, a tag suffix,
and a node object and produces the corresponding Python object.
"""
if Loader is None:
constructor.add_multi_constructor(tag_prefix, multi_constructor)
else:
if False and hasattr(Loader, "add_multi_constructor"):
Loader.add_multi_constructor(tag_prefix, multi_constructor)
return
if issubclass(Loader, BaseLoader):
BaseConstructor.add_multi_constructor(tag_prefix, multi_constructor)
elif issubclass(Loader, SafeLoader):
SafeConstructor.add_multi_constructor(tag_prefix, multi_constructor)
elif issubclass(Loader, loader.Loader):
Constructor.add_multi_constructor(tag_prefix, multi_constructor)
elif issubclass(Loader, RoundTripLoader):
RoundTripConstructor.add_multi_constructor(tag_prefix, multi_constructor)
else:
raise NotImplementedError
def add_representer(
data_type, object_representer, Dumper=None, representer=Representer
):
# type: (Any, Any, Any, Any) -> None
"""
Add a representer for the given type.
object_representer is a function accepting a Dumper instance
and an instance of the given data type
and producing the corresponding representation node.
"""
if Dumper is None:
representer.add_representer(data_type, object_representer)
else:
if hasattr(Dumper, "add_representer"):
Dumper.add_representer(data_type, object_representer)
return
if issubclass(Dumper, BaseDumper):
BaseRepresenter.add_representer(data_type, object_representer)
elif issubclass(Dumper, SafeDumper):
SafeRepresenter.add_representer(data_type, object_representer)
elif issubclass(Dumper, dumper.Dumper):
Representer.add_representer(data_type, object_representer)
elif issubclass(Dumper, RoundTripDumper):
RoundTripRepresenter.add_representer(data_type, object_representer)
else:
raise NotImplementedError
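# Example (a minimal sketch, not executed here; the Point class and '!point'
# tag are made up for illustration):
#
#     class Point(object):
#         def __init__(self, x, y):
#             self.x, self.y = x, y
#
#     def represent_point(representer, data):
#         return representer.represent_sequence(u"!point", [data.x, data.y])
#
#     add_representer(Point, represent_point, Dumper=SafeDumper)
#     safe_dump(Point(1, 2))             # roughly '!point [1, 2]\n'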
# this code currently not tested
def add_multi_representer(
data_type, multi_representer, Dumper=None, representer=Representer
):
# type: (Any, Any, Any, Any) -> None
"""
Add a representer for the given type.
multi_representer is a function accepting a Dumper instance
and an instance of the given data type or subtype
and producing the corresponding representation node.
"""
if Dumper is None:
representer.add_multi_representer(data_type, multi_representer)
else:
if hasattr(Dumper, "add_multi_representer"):
Dumper.add_multi_representer(data_type, multi_representer)
return
if issubclass(Dumper, BaseDumper):
BaseRepresenter.add_multi_representer(data_type, multi_representer)
elif issubclass(Dumper, SafeDumper):
SafeRepresenter.add_multi_representer(data_type, multi_representer)
elif issubclass(Dumper, dumper.Dumper):
Representer.add_multi_representer(data_type, multi_representer)
elif issubclass(Dumper, RoundTripDumper):
RoundTripRepresenter.add_multi_representer(data_type, multi_representer)
else:
raise NotImplementedError
class YAMLObjectMetaclass(type):
"""
The metaclass for YAMLObject.
"""
def __init__(cls, name, bases, kwds):
# type: (Any, Any, Any) -> None
super(YAMLObjectMetaclass, cls).__init__(name, bases, kwds)
if "yaml_tag" in kwds and kwds["yaml_tag"] is not None:
cls.yaml_constructor.add_constructor(
cls.yaml_tag, cls.from_yaml
) # type: ignore
cls.yaml_representer.add_representer(cls, cls.to_yaml) # type: ignore
class YAMLObject(with_metaclass(YAMLObjectMetaclass)): # type: ignore
"""
An object that can dump itself to a YAML stream
and load itself from a YAML stream.
"""
__slots__ = () # no direct instantiation, so allow immutable subclasses
yaml_constructor = Constructor
yaml_representer = Representer
yaml_tag = None # type: Any
yaml_flow_style = None # type: Any
@classmethod
def from_yaml(cls, constructor, node):
# type: (Any, Any) -> Any
"""
Convert a representation node to a Python object.
"""
return constructor.construct_yaml_object(node, cls)
@classmethod
def to_yaml(cls, representer, data):
# type: (Any, Any) -> Any
"""
Convert a Python object to a representation node.
"""
return representer.represent_yaml_object(
cls.yaml_tag, data, cls, flow_style=cls.yaml_flow_style
)
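# Example (a minimal sketch, not executed here; the Monster class and tag are
# made up for illustration). Setting yaml_tag registers the class via the
# metaclass above, so it can be dumped and loaded with the default
# Constructor/Representer pair:
#
#     class Monster(YAMLObject):
#         yaml_tag = u"!Monster"
#         def __init__(self, name, hp):
#             self.name = name
#             self.hp = hp
#
#     dump(Monster("dragon", 50))                        # roughly '!Monster {hp: 50, name: dragon}\n'
#     load(u"!Monster {name: dragon, hp: 50}\n", Loader=Loader)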
srsly-release-v2.5.1/srsly/ruamel_yaml/nodes.py 0000775 0000000 0000000 00000007204 14742310675 0021663 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
import sys
from .compat import string_types
if False: # MYPY
from typing import Dict, Any, Text # NOQA
class Node(object):
__slots__ = 'tag', 'value', 'start_mark', 'end_mark', 'comment', 'anchor'
def __init__(self, tag, value, start_mark, end_mark, comment=None, anchor=None):
# type: (Any, Any, Any, Any, Any, Any) -> None
self.tag = tag
self.value = value
self.start_mark = start_mark
self.end_mark = end_mark
self.comment = comment
self.anchor = anchor
def __repr__(self):
# type: () -> str
value = self.value
# if isinstance(value, list):
# if len(value) == 0:
# value = ''
# elif len(value) == 1:
# value = '<1 item>'
# else:
# value = '<%d items>' % len(value)
# else:
# if len(value) > 75:
# value = repr(value[:70]+u' ... ')
# else:
# value = repr(value)
value = repr(value)
return '%s(tag=%r, value=%s)' % (self.__class__.__name__, self.tag, value)
def dump(self, indent=0):
# type: (int) -> None
if isinstance(self.value, string_types):
sys.stdout.write(
'{}{}(tag={!r}, value={!r})\n'.format(
' ' * indent, self.__class__.__name__, self.tag, self.value
)
)
if self.comment:
sys.stdout.write(' {}comment: {})\n'.format(' ' * indent, self.comment))
return
sys.stdout.write(
'{}{}(tag={!r})\n'.format(' ' * indent, self.__class__.__name__, self.tag)
)
if self.comment:
sys.stdout.write(' {}comment: {})\n'.format(' ' * indent, self.comment))
for v in self.value:
if isinstance(v, tuple):
for v1 in v:
v1.dump(indent + 1)
elif isinstance(v, Node):
v.dump(indent + 1)
else:
sys.stdout.write('Node value type? {}\n'.format(type(v)))
class ScalarNode(Node):
"""
styles:
? -> set() ? key, no value
" -> double quoted
' -> single quoted
| -> literal style
> -> folding style
"""
__slots__ = ('style',)
id = 'scalar'
def __init__(
self, tag, value, start_mark=None, end_mark=None, style=None, comment=None, anchor=None
):
# type: (Any, Any, Any, Any, Any, Any, Any) -> None
Node.__init__(self, tag, value, start_mark, end_mark, comment=comment, anchor=anchor)
self.style = style
class CollectionNode(Node):
__slots__ = ('flow_style',)
def __init__(
self,
tag,
value,
start_mark=None,
end_mark=None,
flow_style=None,
comment=None,
anchor=None,
):
# type: (Any, Any, Any, Any, Any, Any, Any) -> None
Node.__init__(self, tag, value, start_mark, end_mark, comment=comment)
self.flow_style = flow_style
self.anchor = anchor
class SequenceNode(CollectionNode):
__slots__ = ()
id = 'sequence'
class MappingNode(CollectionNode):
__slots__ = ('merge',)
id = 'mapping'
def __init__(
self,
tag,
value,
start_mark=None,
end_mark=None,
flow_style=None,
comment=None,
anchor=None,
):
# type: (Any, Any, Any, Any, Any, Any, Any) -> None
CollectionNode.__init__(
self, tag, value, start_mark, end_mark, flow_style, comment, anchor
)
self.merge = None
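# Example (a minimal sketch, not executed here): a small representation tree
# built by hand, as the composer/representer would produce it.
#
#     key = ScalarNode(u"tag:yaml.org,2002:str", u"a")
#     val = ScalarNode(u"tag:yaml.org,2002:int", u"1")
#     root = MappingNode(u"tag:yaml.org,2002:map", [(key, val)])
#     root.dump()                        # prints the node tree to stdout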
srsly-release-v2.5.1/srsly/ruamel_yaml/parser.py 0000775 0000000 0000000 00000102032 14742310675 0022042 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
# The following YAML grammar is LL(1) and is parsed by a recursive descent
# parser.
#
# stream ::= STREAM-START implicit_document? explicit_document*
# STREAM-END
# implicit_document ::= block_node DOCUMENT-END*
# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*
# block_node_or_indentless_sequence ::=
# ALIAS
# | properties (block_content |
# indentless_block_sequence)?
# | block_content
# | indentless_block_sequence
# block_node ::= ALIAS
# | properties block_content?
# | block_content
# flow_node ::= ALIAS
# | properties flow_content?
# | flow_content
# properties ::= TAG ANCHOR? | ANCHOR TAG?
# block_content ::= block_collection | flow_collection | SCALAR
# flow_content ::= flow_collection | SCALAR
# block_collection ::= block_sequence | block_mapping
# flow_collection ::= flow_sequence | flow_mapping
# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)*
# BLOCK-END
# indentless_sequence ::= (BLOCK-ENTRY block_node?)+
# block_mapping ::= BLOCK-MAPPING_START
# ((KEY block_node_or_indentless_sequence?)?
# (VALUE block_node_or_indentless_sequence?)?)*
# BLOCK-END
# flow_sequence ::= FLOW-SEQUENCE-START
# (flow_sequence_entry FLOW-ENTRY)*
# flow_sequence_entry?
# FLOW-SEQUENCE-END
# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
# flow_mapping ::= FLOW-MAPPING-START
# (flow_mapping_entry FLOW-ENTRY)*
# flow_mapping_entry?
# FLOW-MAPPING-END
# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
#
# FIRST sets:
#
# stream: { STREAM-START }
# explicit_document: { DIRECTIVE DOCUMENT-START }
# implicit_document: FIRST(block_node)
# block_node: { ALIAS TAG ANCHOR SCALAR BLOCK-SEQUENCE-START
# BLOCK-MAPPING-START FLOW-SEQUENCE-START FLOW-MAPPING-START }
# flow_node: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START FLOW-MAPPING-START }
# block_content: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START
# FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
# flow_content: { FLOW-SEQUENCE-START FLOW-MAPPING-START SCALAR }
# block_collection: { BLOCK-SEQUENCE-START BLOCK-MAPPING-START }
# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
# block_sequence: { BLOCK-SEQUENCE-START }
# block_mapping: { BLOCK-MAPPING-START }
# block_node_or_indentless_sequence: { ALIAS ANCHOR TAG SCALAR
# BLOCK-SEQUENCE-START BLOCK-MAPPING-START FLOW-SEQUENCE-START
# FLOW-MAPPING-START BLOCK-ENTRY }
# indentless_sequence: { ENTRY }
# flow_collection: { FLOW-SEQUENCE-START FLOW-MAPPING-START }
# flow_sequence: { FLOW-SEQUENCE-START }
# flow_mapping: { FLOW-MAPPING-START }
# flow_sequence_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START
# FLOW-MAPPING-START KEY }
# flow_mapping_entry: { ALIAS ANCHOR TAG SCALAR FLOW-SEQUENCE-START
# FLOW-MAPPING-START KEY }
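# As a worked example of the grammar above, the two-line document
#
#     a: 1
#     b: [2, 3]
#
# is parsed into (roughly) this event sequence:
#
#     StreamStartEvent
#       DocumentStartEvent (implicit)
#         MappingStartEvent (block)
#           ScalarEvent 'a'   ScalarEvent '1'
#           ScalarEvent 'b'   SequenceStartEvent (flow)
#                               ScalarEvent '2'   ScalarEvent '3'
#                             SequenceEndEvent
#         MappingEndEvent
#       DocumentEndEvent (implicit)
#     StreamEndEvent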
# need to have full path with import, as pkg_resources tries to load parser.py in __init__.py
# only to not do anything with the package afterwards
# and for Jython too
from .error import MarkedYAMLError
from .tokens import * # NOQA
from .events import * # NOQA
from .scanner import Scanner, RoundTripScanner, ScannerError # NOQA
from .compat import utf8, nprint, nprintf # NOQA
if False: # MYPY
from typing import Any, Dict, Optional, List # NOQA
__all__ = ["Parser", "RoundTripParser", "ParserError"]
class ParserError(MarkedYAMLError):
pass
class Parser(object):
# Since writing a recursive descent parser is a straightforward task, we
# do not give many comments here.
DEFAULT_TAGS = {u"!": u"!", u"!!": u"tag:yaml.org,2002:"}
def __init__(self, loader):
# type: (Any) -> None
self.loader = loader
if self.loader is not None and getattr(self.loader, "_parser", None) is None:
self.loader._parser = self
self.reset_parser()
def reset_parser(self):
# type: () -> None
# Reset the state attributes (to clear self-references)
self.current_event = None
self.tag_handles = {} # type: Dict[Any, Any]
self.states = [] # type: List[Any]
self.marks = [] # type: List[Any]
self.state = self.parse_stream_start # type: Any
def dispose(self):
# type: () -> None
self.reset_parser()
@property
def scanner(self):
# type: () -> Any
if hasattr(self.loader, "typ"):
return self.loader.scanner
return self.loader._scanner
@property
def resolver(self):
# type: () -> Any
if hasattr(self.loader, "typ"):
return self.loader.resolver
return self.loader._resolver
def check_event(self, *choices):
# type: (Any) -> bool
# Check the type of the next event.
if self.current_event is None:
if self.state:
self.current_event = self.state()
if self.current_event is not None:
if not choices:
return True
for choice in choices:
if isinstance(self.current_event, choice):
return True
return False
def peek_event(self):
# type: () -> Any
# Get the next event.
if self.current_event is None:
if self.state:
self.current_event = self.state()
return self.current_event
def get_event(self):
# type: () -> Any
# Get the next event and proceed further.
if self.current_event is None:
if self.state:
self.current_event = self.state()
value = self.current_event
self.current_event = None
return value
# stream ::= STREAM-START implicit_document? explicit_document*
# STREAM-END
# implicit_document ::= block_node DOCUMENT-END*
# explicit_document ::= DIRECTIVE* DOCUMENT-START block_node? DOCUMENT-END*
def parse_stream_start(self):
# type: () -> Any
# Parse the stream start.
token = self.scanner.get_token()
token.move_comment(self.scanner.peek_token())
event = StreamStartEvent(
token.start_mark, token.end_mark, encoding=token.encoding
)
# Prepare the next state.
self.state = self.parse_implicit_document_start
return event
def parse_implicit_document_start(self):
# type: () -> Any
# Parse an implicit document.
if not self.scanner.check_token(
DirectiveToken, DocumentStartToken, StreamEndToken
):
self.tag_handles = self.DEFAULT_TAGS
token = self.scanner.peek_token()
start_mark = end_mark = token.start_mark
event = DocumentStartEvent(start_mark, end_mark, explicit=False)
# Prepare the next state.
self.states.append(self.parse_document_end)
self.state = self.parse_block_node
return event
else:
return self.parse_document_start()
def parse_document_start(self):
# type: () -> Any
# Parse any extra document end indicators.
while self.scanner.check_token(DocumentEndToken):
self.scanner.get_token()
# Parse an explicit document.
if not self.scanner.check_token(StreamEndToken):
token = self.scanner.peek_token()
start_mark = token.start_mark
version, tags = self.process_directives()
if not self.scanner.check_token(DocumentStartToken):
raise ParserError(
None,
None,
"expected '', but found %r"
% self.scanner.peek_token().id,
self.scanner.peek_token().start_mark,
)
token = self.scanner.get_token()
end_mark = token.end_mark
# if self.loader is not None and \
# end_mark.line != self.scanner.peek_token().start_mark.line:
# self.loader.scalar_after_indicator = False
event = DocumentStartEvent(
start_mark, end_mark, explicit=True, version=version, tags=tags
) # type: Any
self.states.append(self.parse_document_end)
self.state = self.parse_document_content
else:
# Parse the end of the stream.
token = self.scanner.get_token()
event = StreamEndEvent(
token.start_mark, token.end_mark, comment=token.comment
)
assert not self.states
assert not self.marks
self.state = None
return event
def parse_document_end(self):
# type: () -> Any
# Parse the document end.
token = self.scanner.peek_token()
start_mark = end_mark = token.start_mark
explicit = False
if self.scanner.check_token(DocumentEndToken):
token = self.scanner.get_token()
end_mark = token.end_mark
explicit = True
event = DocumentEndEvent(start_mark, end_mark, explicit=explicit)
# Prepare the next state.
if self.resolver.processing_version == (1, 1):
self.state = self.parse_document_start
else:
self.state = self.parse_implicit_document_start
return event
def parse_document_content(self):
# type: () -> Any
if self.scanner.check_token(
DirectiveToken, DocumentStartToken, DocumentEndToken, StreamEndToken
):
event = self.process_empty_scalar(self.scanner.peek_token().start_mark)
self.state = self.states.pop()
return event
else:
return self.parse_block_node()
def process_directives(self):
# type: () -> Any
yaml_version = None
self.tag_handles = {}
while self.scanner.check_token(DirectiveToken):
token = self.scanner.get_token()
if token.name == u"YAML":
if yaml_version is not None:
raise ParserError(
None, None, "found duplicate YAML directive", token.start_mark
)
major, minor = token.value
if major != 1:
raise ParserError(
None,
None,
"found incompatible YAML document (version 1.* is " "required)",
token.start_mark,
)
yaml_version = token.value
elif token.name == u"TAG":
handle, prefix = token.value
if handle in self.tag_handles:
raise ParserError(
None,
None,
"duplicate tag handle %r" % utf8(handle),
token.start_mark,
)
self.tag_handles[handle] = prefix
if bool(self.tag_handles):
value = yaml_version, self.tag_handles.copy() # type: Any
else:
value = yaml_version, None
if self.loader is not None and hasattr(self.loader, "tags"):
self.loader.version = yaml_version
if self.loader.tags is None:
self.loader.tags = {}
for k in self.tag_handles:
self.loader.tags[k] = self.tag_handles[k]
for key in self.DEFAULT_TAGS:
if key not in self.tag_handles:
self.tag_handles[key] = self.DEFAULT_TAGS[key]
return value
# block_node_or_indentless_sequence ::= ALIAS
# | properties (block_content | indentless_block_sequence)?
# | block_content
# | indentless_block_sequence
# block_node ::= ALIAS
# | properties block_content?
# | block_content
# flow_node ::= ALIAS
# | properties flow_content?
# | flow_content
# properties ::= TAG ANCHOR? | ANCHOR TAG?
# block_content ::= block_collection | flow_collection | SCALAR
# flow_content ::= flow_collection | SCALAR
# block_collection ::= block_sequence | block_mapping
# flow_collection ::= flow_sequence | flow_mapping
def parse_block_node(self):
# type: () -> Any
return self.parse_node(block=True)
def parse_flow_node(self):
# type: () -> Any
return self.parse_node()
def parse_block_node_or_indentless_sequence(self):
# type: () -> Any
return self.parse_node(block=True, indentless_sequence=True)
def transform_tag(self, handle, suffix):
# type: (Any, Any) -> Any
return self.tag_handles[handle] + suffix
def parse_node(self, block=False, indentless_sequence=False):
# type: (bool, bool) -> Any
if self.scanner.check_token(AliasToken):
token = self.scanner.get_token()
event = AliasEvent(
token.value, token.start_mark, token.end_mark
) # type: Any
self.state = self.states.pop()
return event
anchor = None
tag = None
start_mark = end_mark = tag_mark = None
if self.scanner.check_token(AnchorToken):
token = self.scanner.get_token()
start_mark = token.start_mark
end_mark = token.end_mark
anchor = token.value
if self.scanner.check_token(TagToken):
token = self.scanner.get_token()
tag_mark = token.start_mark
end_mark = token.end_mark
tag = token.value
elif self.scanner.check_token(TagToken):
token = self.scanner.get_token()
start_mark = tag_mark = token.start_mark
end_mark = token.end_mark
tag = token.value
if self.scanner.check_token(AnchorToken):
token = self.scanner.get_token()
start_mark = tag_mark = token.start_mark
end_mark = token.end_mark
anchor = token.value
if tag is not None:
handle, suffix = tag
if handle is not None:
if handle not in self.tag_handles:
raise ParserError(
"while parsing a node",
start_mark,
"found undefined tag handle %r" % utf8(handle),
tag_mark,
)
tag = self.transform_tag(handle, suffix)
else:
tag = suffix
# if tag == u'!':
# raise ParserError("while parsing a node", start_mark,
# "found non-specific tag '!'", tag_mark,
# "Please check 'http://pyyaml.org/wiki/YAMLNonSpecificTag'
# and share your opinion.")
if start_mark is None:
start_mark = end_mark = self.scanner.peek_token().start_mark
event = None
implicit = tag is None or tag == u"!"
if indentless_sequence and self.scanner.check_token(BlockEntryToken):
comment = None
pt = self.scanner.peek_token()
if pt.comment and pt.comment[0]:
comment = [pt.comment[0], []]
pt.comment[0] = None
end_mark = self.scanner.peek_token().end_mark
event = SequenceStartEvent(
anchor,
tag,
implicit,
start_mark,
end_mark,
flow_style=False,
comment=comment,
)
self.state = self.parse_indentless_sequence_entry
return event
if self.scanner.check_token(ScalarToken):
token = self.scanner.get_token()
# self.scanner.peek_token_same_line_comment(token)
end_mark = token.end_mark
if (token.plain and tag is None) or tag == u"!":
implicit = (True, False)
elif tag is None:
implicit = (False, True)
else:
implicit = (False, False)
# nprint('se', token.value, token.comment)
event = ScalarEvent(
anchor,
tag,
implicit,
token.value,
start_mark,
end_mark,
style=token.style,
comment=token.comment,
)
self.state = self.states.pop()
elif self.scanner.check_token(FlowSequenceStartToken):
pt = self.scanner.peek_token()
end_mark = pt.end_mark
event = SequenceStartEvent(
anchor,
tag,
implicit,
start_mark,
end_mark,
flow_style=True,
comment=pt.comment,
)
self.state = self.parse_flow_sequence_first_entry
elif self.scanner.check_token(FlowMappingStartToken):
pt = self.scanner.peek_token()
end_mark = pt.end_mark
event = MappingStartEvent(
anchor,
tag,
implicit,
start_mark,
end_mark,
flow_style=True,
comment=pt.comment,
)
self.state = self.parse_flow_mapping_first_key
elif block and self.scanner.check_token(BlockSequenceStartToken):
end_mark = self.scanner.peek_token().start_mark
# should inserting the comment be dependent on the
# indentation?
pt = self.scanner.peek_token()
comment = pt.comment
# nprint('pt0', type(pt))
if comment is None or comment[1] is None:
comment = pt.split_comment()
# nprint('pt1', comment)
event = SequenceStartEvent(
anchor,
tag,
implicit,
start_mark,
end_mark,
flow_style=False,
comment=comment,
)
self.state = self.parse_block_sequence_first_entry
elif block and self.scanner.check_token(BlockMappingStartToken):
end_mark = self.scanner.peek_token().start_mark
comment = self.scanner.peek_token().comment
event = MappingStartEvent(
anchor,
tag,
implicit,
start_mark,
end_mark,
flow_style=False,
comment=comment,
)
self.state = self.parse_block_mapping_first_key
elif anchor is not None or tag is not None:
# Empty scalars are allowed even if a tag or an anchor is
# specified.
event = ScalarEvent(
anchor, tag, (implicit, False), "", start_mark, end_mark
)
self.state = self.states.pop()
else:
if block:
node = "block"
else:
node = "flow"
token = self.scanner.peek_token()
raise ParserError(
"while parsing a %s node" % node,
start_mark,
"expected the node content, but found %r" % token.id,
token.start_mark,
)
return event
# block_sequence ::= BLOCK-SEQUENCE-START (BLOCK-ENTRY block_node?)*
# BLOCK-END
def parse_block_sequence_first_entry(self):
# type: () -> Any
token = self.scanner.get_token()
# move any comment from start token
# token.move_comment(self.scanner.peek_token())
self.marks.append(token.start_mark)
return self.parse_block_sequence_entry()
def parse_block_sequence_entry(self):
# type: () -> Any
if self.scanner.check_token(BlockEntryToken):
token = self.scanner.get_token()
token.move_comment(self.scanner.peek_token())
if not self.scanner.check_token(BlockEntryToken, BlockEndToken):
self.states.append(self.parse_block_sequence_entry)
return self.parse_block_node()
else:
self.state = self.parse_block_sequence_entry
return self.process_empty_scalar(token.end_mark)
if not self.scanner.check_token(BlockEndToken):
token = self.scanner.peek_token()
raise ParserError(
"while parsing a block collection",
self.marks[-1],
"expected , but found %r" % token.id,
token.start_mark,
)
token = self.scanner.get_token() # BlockEndToken
event = SequenceEndEvent(
token.start_mark, token.end_mark, comment=token.comment
)
self.state = self.states.pop()
self.marks.pop()
return event
# indentless_sequence ::= (BLOCK-ENTRY block_node?)+
# indentless_sequence?
# sequence:
# - entry
# - nested
def parse_indentless_sequence_entry(self):
# type: () -> Any
if self.scanner.check_token(BlockEntryToken):
token = self.scanner.get_token()
token.move_comment(self.scanner.peek_token())
if not self.scanner.check_token(
BlockEntryToken, KeyToken, ValueToken, BlockEndToken
):
self.states.append(self.parse_indentless_sequence_entry)
return self.parse_block_node()
else:
self.state = self.parse_indentless_sequence_entry
return self.process_empty_scalar(token.end_mark)
token = self.scanner.peek_token()
event = SequenceEndEvent(
token.start_mark, token.start_mark, comment=token.comment
)
self.state = self.states.pop()
return event
# block_mapping ::= BLOCK-MAPPING_START
# ((KEY block_node_or_indentless_sequence?)?
# (VALUE block_node_or_indentless_sequence?)?)*
# BLOCK-END
def parse_block_mapping_first_key(self):
# type: () -> Any
token = self.scanner.get_token()
self.marks.append(token.start_mark)
return self.parse_block_mapping_key()
def parse_block_mapping_key(self):
# type: () -> Any
if self.scanner.check_token(KeyToken):
token = self.scanner.get_token()
token.move_comment(self.scanner.peek_token())
if not self.scanner.check_token(KeyToken, ValueToken, BlockEndToken):
self.states.append(self.parse_block_mapping_value)
return self.parse_block_node_or_indentless_sequence()
else:
self.state = self.parse_block_mapping_value
return self.process_empty_scalar(token.end_mark)
if self.resolver.processing_version > (1, 1) and self.scanner.check_token(
ValueToken
):
self.state = self.parse_block_mapping_value
return self.process_empty_scalar(self.scanner.peek_token().start_mark)
if not self.scanner.check_token(BlockEndToken):
token = self.scanner.peek_token()
raise ParserError(
"while parsing a block mapping",
self.marks[-1],
"expected , but found %r" % token.id,
token.start_mark,
)
token = self.scanner.get_token()
token.move_comment(self.scanner.peek_token())
event = MappingEndEvent(token.start_mark, token.end_mark, comment=token.comment)
self.state = self.states.pop()
self.marks.pop()
return event
def parse_block_mapping_value(self):
# type: () -> Any
if self.scanner.check_token(ValueToken):
token = self.scanner.get_token()
# the value token might have a post comment; move it e.g. to the following block
if self.scanner.check_token(ValueToken):
token.move_comment(self.scanner.peek_token())
else:
if not self.scanner.check_token(KeyToken):
token.move_comment(self.scanner.peek_token(), empty=True)
# else: empty value for this key cannot move token.comment
if not self.scanner.check_token(KeyToken, ValueToken, BlockEndToken):
self.states.append(self.parse_block_mapping_key)
return self.parse_block_node_or_indentless_sequence()
else:
self.state = self.parse_block_mapping_key
comment = token.comment
if comment is None:
token = self.scanner.peek_token()
comment = token.comment
if comment:
token._comment = [None, comment[1]]
comment = [comment[0], None]
return self.process_empty_scalar(token.end_mark, comment=comment)
else:
self.state = self.parse_block_mapping_key
token = self.scanner.peek_token()
return self.process_empty_scalar(token.start_mark)
# flow_sequence ::= FLOW-SEQUENCE-START
# (flow_sequence_entry FLOW-ENTRY)*
# flow_sequence_entry?
# FLOW-SEQUENCE-END
# flow_sequence_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
#
# Note that while production rules for both flow_sequence_entry and
# flow_mapping_entry are equal, their interpretations are different.
# For `flow_sequence_entry`, the part `KEY flow_node? (VALUE flow_node?)?`
# generates an inline mapping (set syntax).
def parse_flow_sequence_first_entry(self):
# type: () -> Any
token = self.scanner.get_token()
self.marks.append(token.start_mark)
return self.parse_flow_sequence_entry(first=True)
def parse_flow_sequence_entry(self, first=False):
# type: (bool) -> Any
if not self.scanner.check_token(FlowSequenceEndToken):
if not first:
if self.scanner.check_token(FlowEntryToken):
self.scanner.get_token()
else:
token = self.scanner.peek_token()
raise ParserError(
"while parsing a flow sequence",
self.marks[-1],
"expected ',' or ']', but got %r" % token.id,
token.start_mark,
)
if self.scanner.check_token(KeyToken):
token = self.scanner.peek_token()
event = MappingStartEvent(
None, None, True, token.start_mark, token.end_mark, flow_style=True
) # type: Any
self.state = self.parse_flow_sequence_entry_mapping_key
return event
elif not self.scanner.check_token(FlowSequenceEndToken):
self.states.append(self.parse_flow_sequence_entry)
return self.parse_flow_node()
token = self.scanner.get_token()
event = SequenceEndEvent(
token.start_mark, token.end_mark, comment=token.comment
)
self.state = self.states.pop()
self.marks.pop()
return event
def parse_flow_sequence_entry_mapping_key(self):
# type: () -> Any
token = self.scanner.get_token()
if not self.scanner.check_token(
ValueToken, FlowEntryToken, FlowSequenceEndToken
):
self.states.append(self.parse_flow_sequence_entry_mapping_value)
return self.parse_flow_node()
else:
self.state = self.parse_flow_sequence_entry_mapping_value
return self.process_empty_scalar(token.end_mark)
def parse_flow_sequence_entry_mapping_value(self):
# type: () -> Any
if self.scanner.check_token(ValueToken):
token = self.scanner.get_token()
if not self.scanner.check_token(FlowEntryToken, FlowSequenceEndToken):
self.states.append(self.parse_flow_sequence_entry_mapping_end)
return self.parse_flow_node()
else:
self.state = self.parse_flow_sequence_entry_mapping_end
return self.process_empty_scalar(token.end_mark)
else:
self.state = self.parse_flow_sequence_entry_mapping_end
token = self.scanner.peek_token()
return self.process_empty_scalar(token.start_mark)
def parse_flow_sequence_entry_mapping_end(self):
# type: () -> Any
self.state = self.parse_flow_sequence_entry
token = self.scanner.peek_token()
return MappingEndEvent(token.start_mark, token.start_mark)
# flow_mapping ::= FLOW-MAPPING-START
# (flow_mapping_entry FLOW-ENTRY)*
# flow_mapping_entry?
# FLOW-MAPPING-END
# flow_mapping_entry ::= flow_node | KEY flow_node? (VALUE flow_node?)?
def parse_flow_mapping_first_key(self):
# type: () -> Any
token = self.scanner.get_token()
self.marks.append(token.start_mark)
return self.parse_flow_mapping_key(first=True)
def parse_flow_mapping_key(self, first=False):
# type: (Any) -> Any
if not self.scanner.check_token(FlowMappingEndToken):
if not first:
if self.scanner.check_token(FlowEntryToken):
self.scanner.get_token()
else:
token = self.scanner.peek_token()
raise ParserError(
"while parsing a flow mapping",
self.marks[-1],
"expected ',' or '}', but got %r" % token.id,
token.start_mark,
)
if self.scanner.check_token(KeyToken):
token = self.scanner.get_token()
if not self.scanner.check_token(
ValueToken, FlowEntryToken, FlowMappingEndToken
):
self.states.append(self.parse_flow_mapping_value)
return self.parse_flow_node()
else:
self.state = self.parse_flow_mapping_value
return self.process_empty_scalar(token.end_mark)
elif self.resolver.processing_version > (1, 1) and self.scanner.check_token(
ValueToken
):
self.state = self.parse_flow_mapping_value
return self.process_empty_scalar(self.scanner.peek_token().end_mark)
elif not self.scanner.check_token(FlowMappingEndToken):
self.states.append(self.parse_flow_mapping_empty_value)
return self.parse_flow_node()
token = self.scanner.get_token()
event = MappingEndEvent(token.start_mark, token.end_mark, comment=token.comment)
self.state = self.states.pop()
self.marks.pop()
return event
def parse_flow_mapping_value(self):
# type: () -> Any
if self.scanner.check_token(ValueToken):
token = self.scanner.get_token()
if not self.scanner.check_token(FlowEntryToken, FlowMappingEndToken):
self.states.append(self.parse_flow_mapping_key)
return self.parse_flow_node()
else:
self.state = self.parse_flow_mapping_key
return self.process_empty_scalar(token.end_mark)
else:
self.state = self.parse_flow_mapping_key
token = self.scanner.peek_token()
return self.process_empty_scalar(token.start_mark)
def parse_flow_mapping_empty_value(self):
# type: () -> Any
self.state = self.parse_flow_mapping_key
return self.process_empty_scalar(self.scanner.peek_token().start_mark)
def process_empty_scalar(self, mark, comment=None):
# type: (Any, Any) -> Any
return ScalarEvent(None, None, (True, False), "", mark, mark, comment=comment)
class RoundTripParser(Parser):
"""roundtrip is a safe loader, that wants to see the unmangled tag"""
def transform_tag(self, handle, suffix):
# type: (Any, Any) -> Any
# return self.tag_handles[handle]+suffix
if handle == "!!" and suffix in (
u"null",
u"bool",
u"int",
u"float",
u"binary",
u"timestamp",
u"omap",
u"pairs",
u"set",
u"str",
u"seq",
u"map",
):
return Parser.transform_tag(self, handle, suffix)
return handle + suffix
srsly-release-v2.5.1/srsly/ruamel_yaml/py.typed 0000775 0000000 0000000 00000000000 14742310675 0021663 0 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/ruamel_yaml/reader.py 0000775 0000000 0000000 00000025651 14742310675 0022023 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
# This module contains abstractions for the input stream. You don't have to
# look further, there is no pretty code.
#
# We define two classes here.
#
# Mark(source, line, column)
# It's just a record and its only use is producing nice error messages.
# Parser does not use it for any other purposes.
#
# Reader(stream, loader=None)
# Reader determines the encoding of `stream` and converts it to unicode.
# Reader provides the following methods and attributes:
# reader.peek(length=1) - return the next `length` characters
# reader.forward(length=1) - move the current position to `length`
# characters.
# reader.index - the number of the current character.
# reader.line, reader.column - the line and the column of the current
# character.
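# Example (a minimal sketch, not executed here):
#
#     reader = Reader(u"a: 1\n")
#     reader.peek()        # -> 'a' (does not advance)
#     reader.forward(3)    # skip 'a', ':' and ' '
#     reader.peek()        # -> '1'
#     reader.index, reader.line, reader.column   # -> (3, 0, 3)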
import codecs
from .error import YAMLError, FileMark, StringMark, YAMLStreamError
from .compat import text_type, binary_type, PY3, UNICODE_SIZE
from .util import RegExp
if False: # MYPY
from typing import Any, Dict, Optional, List, Union, Text, Tuple, Optional # NOQA
# from srsly.ruamel_yaml.compat import StreamTextType # NOQA
__all__ = ["Reader", "ReaderError"]
class ReaderError(YAMLError):
def __init__(self, name, position, character, encoding, reason):
# type: (Any, Any, Any, Any, Any) -> None
self.name = name
self.character = character
self.position = position
self.encoding = encoding
self.reason = reason
def __str__(self):
# type: () -> str
if isinstance(self.character, binary_type):
return (
"'%s' codec can't decode byte #x%02x: %s\n"
' in "%s", position %d'
% (
self.encoding,
ord(self.character),
self.reason,
self.name,
self.position,
)
)
else:
return "unacceptable character #x%04x: %s\n" ' in "%s", position %d' % (
self.character,
self.reason,
self.name,
self.position,
)
class Reader(object):
# Reader:
# - determines the data encoding and converts it to a unicode string,
# - checks if characters are in allowed range,
# - adds '\0' to the end.
# Reader accepts
# - a `str` object (PY2) / a `bytes` object (PY3),
# - a `unicode` object (PY2) / a `str` object (PY3),
# - a file-like object with its `read` method returning `str`,
# - a file-like object with its `read` method returning `unicode`.
# Yeah, it's ugly and slow.
def __init__(self, stream, loader=None):
# type: (Any, Any) -> None
self.loader = loader
if self.loader is not None and getattr(self.loader, "_reader", None) is None:
self.loader._reader = self
self.reset_reader()
self.stream = stream # type: Any # as .read is called
def reset_reader(self):
# type: () -> None
self.name = None # type: Any
self.stream_pointer = 0
self.eof = True
self.buffer = ""
self.pointer = 0
self.raw_buffer = None # type: Any
self.raw_decode = None
self.encoding = None # type: Optional[Text]
self.index = 0
self.line = 0
self.column = 0
@property
def stream(self):
# type: () -> Any
try:
return self._stream
except AttributeError:
raise YAMLStreamError("input stream needs to specified")
@stream.setter
def stream(self, val):
# type: (Any) -> None
if val is None:
return
self._stream = None
if isinstance(val, text_type):
self.name = ""
self.check_printable(val)
self.buffer = val + u"\0" # type: ignore
elif isinstance(val, binary_type):
self.name = ""
self.raw_buffer = val
self.determine_encoding()
else:
if not hasattr(val, "read"):
raise YAMLStreamError("stream argument needs to have a read() method")
self._stream = val
self.name = getattr(self.stream, "name", "")
self.eof = False
self.raw_buffer = None
self.determine_encoding()
def peek(self, index=0):
# type: (int) -> Text
try:
return self.buffer[self.pointer + index]
except IndexError:
self.update(index + 1)
return self.buffer[self.pointer + index]
def prefix(self, length=1):
# type: (int) -> Any
if self.pointer + length >= len(self.buffer):
self.update(length)
return self.buffer[self.pointer : self.pointer + length]
def forward_1_1(self, length=1):
# type: (int) -> None
if self.pointer + length + 1 >= len(self.buffer):
self.update(length + 1)
while length != 0:
ch = self.buffer[self.pointer]
self.pointer += 1
self.index += 1
if ch in u"\n\x85\u2028\u2029" or (
ch == u"\r" and self.buffer[self.pointer] != u"\n"
):
self.line += 1
self.column = 0
elif ch != u"\uFEFF":
self.column += 1
length -= 1
def forward(self, length=1):
# type: (int) -> None
if self.pointer + length + 1 >= len(self.buffer):
self.update(length + 1)
while length != 0:
ch = self.buffer[self.pointer]
self.pointer += 1
self.index += 1
if ch == u"\n" or (ch == u"\r" and self.buffer[self.pointer] != u"\n"):
self.line += 1
self.column = 0
elif ch != u"\uFEFF":
self.column += 1
length -= 1
def get_mark(self):
# type: () -> Any
if self.stream is None:
return StringMark(
self.name, self.index, self.line, self.column, self.buffer, self.pointer
)
else:
return FileMark(self.name, self.index, self.line, self.column)
def determine_encoding(self):
# type: () -> None
while not self.eof and (self.raw_buffer is None or len(self.raw_buffer) < 2):
self.update_raw()
if isinstance(self.raw_buffer, binary_type):
if self.raw_buffer.startswith(codecs.BOM_UTF16_LE):
self.raw_decode = codecs.utf_16_le_decode # type: ignore
self.encoding = "utf-16-le"
elif self.raw_buffer.startswith(codecs.BOM_UTF16_BE):
self.raw_decode = codecs.utf_16_be_decode # type: ignore
self.encoding = "utf-16-be"
else:
self.raw_decode = codecs.utf_8_decode # type: ignore
self.encoding = "utf-8"
self.update(1)
if UNICODE_SIZE == 2:
NON_PRINTABLE = RegExp(
u"[^\x09\x0A\x0D\x20-\x7E\x85" u"\xA0-\uD7FF" u"\uE000-\uFFFD" u"]"
)
else:
NON_PRINTABLE = RegExp(
u"[^\x09\x0A\x0D\x20-\x7E\x85"
u"\xA0-\uD7FF"
u"\uE000-\uFFFD"
u"\U00010000-\U0010FFFF"
u"]"
)
_printable_ascii = ("\x09\x0A\x0D" + "".join(map(chr, range(0x20, 0x7F)))).encode(
"ascii"
)
@classmethod
def _get_non_printable_ascii(cls, data): # type: ignore
# type: (Text, bytes) -> Optional[Tuple[int, Text]]
ascii_bytes = data.encode("ascii")
non_printables = ascii_bytes.translate(
None, cls._printable_ascii
) # type: ignore
if not non_printables:
return None
non_printable = non_printables[:1]
return ascii_bytes.index(non_printable), non_printable.decode("ascii")
@classmethod
def _get_non_printable_regex(cls, data):
# type: (Text) -> Optional[Tuple[int, Text]]
match = cls.NON_PRINTABLE.search(data)
if not bool(match):
return None
return match.start(), match.group()
@classmethod
def _get_non_printable(cls, data):
# type: (Text) -> Optional[Tuple[int, Text]]
try:
return cls._get_non_printable_ascii(data) # type: ignore
except UnicodeEncodeError:
return cls._get_non_printable_regex(data)
def check_printable(self, data):
# type: (Any) -> None
non_printable_match = self._get_non_printable(data)
if non_printable_match is not None:
start, character = non_printable_match
position = self.index + (len(self.buffer) - self.pointer) + start
raise ReaderError(
self.name,
position,
ord(character),
"unicode",
"special characters are not allowed",
)
def update(self, length):
# type: (int) -> None
if self.raw_buffer is None:
return
self.buffer = self.buffer[self.pointer :]
self.pointer = 0
while len(self.buffer) < length:
if not self.eof:
self.update_raw()
if self.raw_decode is not None:
try:
data, converted = self.raw_decode(
self.raw_buffer, "strict", self.eof
)
except UnicodeDecodeError as exc:
if PY3:
character = self.raw_buffer[exc.start]
else:
character = exc.object[exc.start]
if self.stream is not None:
position = (
self.stream_pointer - len(self.raw_buffer) + exc.start
)
else:
position = exc.start
raise ReaderError(
self.name, position, character, exc.encoding, exc.reason
)
else:
data = self.raw_buffer
converted = len(data)
self.check_printable(data)
self.buffer += data
self.raw_buffer = self.raw_buffer[converted:]
if self.eof:
self.buffer += "\0"
self.raw_buffer = None
break
def update_raw(self, size=None):
# type: (Optional[int]) -> None
if size is None:
size = 4096 if PY3 else 1024
data = self.stream.read(size)
if self.raw_buffer is None:
self.raw_buffer = data
else:
self.raw_buffer += data
self.stream_pointer += len(data)
if not data:
self.eof = True
# try:
# import psyco
# psyco.bind(Reader)
# except ImportError:
# pass
srsly-release-v2.5.1/srsly/ruamel_yaml/representer.py 0000775 0000000 0000000 00000140051 14742310675 0023107 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division
from .error import * # NOQA
from .nodes import * # NOQA
from .compat import text_type, binary_type, to_unicode, PY2, PY3
from .compat import ordereddict # type: ignore
from .compat import nprint, nprintf # NOQA
from .scalarstring import (
LiteralScalarString,
FoldedScalarString,
SingleQuotedScalarString,
DoubleQuotedScalarString,
PlainScalarString,
)
from .scalarint import ScalarInt, BinaryInt, OctalInt, HexInt, HexCapsInt
from .scalarfloat import ScalarFloat
from .scalarbool import ScalarBoolean
from .timestamp import TimeStamp
import datetime
import sys
import types
if PY3:
import copyreg
import base64
else:
import copy_reg as copyreg # type: ignore
if False: # MYPY
from typing import Dict, List, Any, Union, Text, Optional # NOQA
# fmt: off
__all__ = ['BaseRepresenter', 'SafeRepresenter', 'Representer',
'RepresenterError', 'RoundTripRepresenter']
# fmt: on
class RepresenterError(YAMLError):
pass
if PY2:
def get_classobj_bases(cls):
# type: (Any) -> Any
bases = [cls]
for base in cls.__bases__:
bases.extend(get_classobj_bases(base))
return bases
class BaseRepresenter(object):
yaml_representers = {} # type: Dict[Any, Any]
yaml_multi_representers = {} # type: Dict[Any, Any]
def __init__(self, default_style=None, default_flow_style=None, dumper=None):
# type: (Any, Any, Any, Any) -> None
self.dumper = dumper
if self.dumper is not None:
self.dumper._representer = self
self.default_style = default_style
self.default_flow_style = default_flow_style
self.represented_objects = {} # type: Dict[Any, Any]
self.object_keeper = [] # type: List[Any]
self.alias_key = None # type: Optional[int]
self.sort_base_mapping_type_on_output = True
@property
def serializer(self):
# type: () -> Any
try:
if hasattr(self.dumper, "typ"):
return self.dumper.serializer
return self.dumper._serializer
except AttributeError:
return self # cyaml
def represent(self, data):
# type: (Any) -> None
node = self.represent_data(data)
self.serializer.serialize(node)
self.represented_objects = {}
self.object_keeper = []
self.alias_key = None
def represent_data(self, data):
# type: (Any) -> Any
if self.ignore_aliases(data):
self.alias_key = None
else:
self.alias_key = id(data)
if self.alias_key is not None:
if self.alias_key in self.represented_objects:
node = self.represented_objects[self.alias_key]
# if node is None:
# raise RepresenterError(
# "recursive objects are not allowed: %r" % data)
return node
# self.represented_objects[alias_key] = None
self.object_keeper.append(data)
data_types = type(data).__mro__
if PY2:
# if type(data) is types.InstanceType:
if isinstance(data, types.InstanceType):
data_types = get_classobj_bases(data.__class__) + list(data_types)
if data_types[0] in self.yaml_representers:
node = self.yaml_representers[data_types[0]](self, data)
else:
for data_type in data_types:
if data_type in self.yaml_multi_representers:
node = self.yaml_multi_representers[data_type](self, data)
break
else:
if None in self.yaml_multi_representers:
node = self.yaml_multi_representers[None](self, data)
elif None in self.yaml_representers:
node = self.yaml_representers[None](self, data)
else:
node = ScalarNode(None, text_type(data))
# if alias_key is not None:
# self.represented_objects[alias_key] = node
return node
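# Dispatch note for represent_data() above: an exact type match in
# yaml_representers wins; otherwise the type's MRO is walked through
# yaml_multi_representers, then the catch-all None entries of either table
# are tried, and finally the data is stringified into a plain ScalarNode.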
def represent_key(self, data):
# type: (Any) -> Any
"""
David Fraser: Extract a method to represent keys in mappings, so that
a subclass can choose not to quote them (for example)
used in represent_mapping
https://bitbucket.org/davidfraser/pyyaml/commits/d81df6eb95f20cac4a79eed95ae553b5c6f77b8c
"""
return self.represent_data(data)
@classmethod
def add_representer(cls, data_type, representer):
# type: (Any, Any) -> None
if "yaml_representers" not in cls.__dict__:
cls.yaml_representers = cls.yaml_representers.copy()
cls.yaml_representers[data_type] = representer
@classmethod
def add_multi_representer(cls, data_type, representer):
# type: (Any, Any) -> None
if "yaml_multi_representers" not in cls.__dict__:
cls.yaml_multi_representers = cls.yaml_multi_representers.copy()
cls.yaml_multi_representers[data_type] = representer
def represent_scalar(self, tag, value, style=None, anchor=None):
# type: (Any, Any, Any, Any) -> Any
if style is None:
style = self.default_style
comment = None
if style and style[0] in "|>":
comment = getattr(value, "comment", None)
if comment:
comment = [None, [comment]]
node = ScalarNode(tag, value, style=style, comment=comment, anchor=anchor)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
return node
def represent_sequence(self, tag, sequence, flow_style=None):
# type: (Any, Any, Any) -> Any
value = [] # type: List[Any]
node = SequenceNode(tag, value, flow_style=flow_style)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
for item in sequence:
node_item = self.represent_data(item)
if not (isinstance(node_item, ScalarNode) and not node_item.style):
best_style = False
value.append(node_item)
if flow_style is None:
if self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
return node
def represent_omap(self, tag, omap, flow_style=None):
# type: (Any, Any, Any) -> Any
value = [] # type: List[Any]
node = SequenceNode(tag, value, flow_style=flow_style)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
for item_key in omap:
item_val = omap[item_key]
node_item = self.represent_data({item_key: item_val})
# if not (isinstance(node_item, ScalarNode) \
# and not node_item.style):
# best_style = False
value.append(node_item)
if flow_style is None:
if self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
return node
def represent_mapping(self, tag, mapping, flow_style=None):
# type: (Any, Any, Any) -> Any
value = [] # type: List[Any]
node = MappingNode(tag, value, flow_style=flow_style)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
if hasattr(mapping, "items"):
mapping = list(mapping.items())
if self.sort_base_mapping_type_on_output:
try:
mapping = sorted(mapping)
except TypeError:
pass
for item_key, item_value in mapping:
node_key = self.represent_key(item_key)
node_value = self.represent_data(item_value)
if not (isinstance(node_key, ScalarNode) and not node_key.style):
best_style = False
if not (isinstance(node_value, ScalarNode) and not node_value.style):
best_style = False
value.append((node_key, node_value))
if flow_style is None:
if self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
return node
def ignore_aliases(self, data):
# type: (Any) -> bool
return False
class SafeRepresenter(BaseRepresenter):
def ignore_aliases(self, data):
# type: (Any) -> bool
# https://docs.python.org/3/reference/expressions.html#parenthesized-forms :
# "i.e. two occurrences of the empty tuple may or may not yield the same object"
# so "data is ()" should not be used
if data is None or (isinstance(data, tuple) and data == ()):
return True
if isinstance(data, (binary_type, text_type, bool, int, float)):
return True
return False
def represent_none(self, data):
# type: (Any) -> Any
return self.represent_scalar(u"tag:yaml.org,2002:null", u"null")
if PY3:
def represent_str(self, data):
# type: (Any) -> Any
return self.represent_scalar(u"tag:yaml.org,2002:str", data)
def represent_binary(self, data):
# type: (Any) -> Any
if hasattr(base64, "encodebytes"):
data = base64.encodebytes(data).decode("ascii")
else:
data = base64.encodestring(data).decode("ascii")
return self.represent_scalar(u"tag:yaml.org,2002:binary", data, style="|")
else:
def represent_str(self, data):
# type: (Any) -> Any
tag = None
style = None
try:
data = unicode(data, "ascii")
tag = u"tag:yaml.org,2002:str"
except UnicodeDecodeError:
try:
data = unicode(data, "utf-8")
tag = u"tag:yaml.org,2002:str"
except UnicodeDecodeError:
data = data.encode("base64")
tag = u"tag:yaml.org,2002:binary"
style = "|"
return self.represent_scalar(tag, data, style=style)
def represent_unicode(self, data):
# type: (Any) -> Any
return self.represent_scalar(u"tag:yaml.org,2002:str", data)
def represent_bool(self, data, anchor=None):
# type: (Any, Optional[Any]) -> Any
try:
value = self.dumper.boolean_representation[bool(data)]
except AttributeError:
if data:
value = u"true"
else:
value = u"false"
return self.represent_scalar(u"tag:yaml.org,2002:bool", value, anchor=anchor)
def represent_int(self, data):
# type: (Any) -> Any
return self.represent_scalar(u"tag:yaml.org,2002:int", text_type(data))
if PY2:
def represent_long(self, data):
# type: (Any) -> Any
return self.represent_scalar(u"tag:yaml.org,2002:int", text_type(data))
inf_value = 1e300
while repr(inf_value) != repr(inf_value * inf_value):
inf_value *= inf_value
def represent_float(self, data):
# type: (Any) -> Any
if data != data or (data == 0.0 and data == 1.0):
value = u".nan"
elif data == self.inf_value:
value = u".inf"
elif data == -self.inf_value:
value = u"-.inf"
else:
value = to_unicode(repr(data)).lower()
if getattr(self.serializer, "use_version", None) == (1, 1):
if u"." not in value and u"e" in value:
# Note that in some cases `repr(data)` represents a float number
# without the decimal parts. For instance:
# >>> repr(1e17)
# '1e17'
# Unfortunately, this is not a valid float representation according
# to the definition of the `!!float` tag in YAML 1.1. We fix
# this by adding '.0' before the 'e' symbol.
value = value.replace(u"e", u".0e", 1)
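                # e.g. (informal note, not in the original source):
                # repr(1e17) == '1e17' becomes '1.0e17' after this replacement,
                # which the YAML 1.1 resolver again recognizes as a float.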
return self.represent_scalar(u"tag:yaml.org,2002:float", value)
def represent_list(self, data):
# type: (Any) -> Any
# pairs = (len(data) > 0 and isinstance(data, list))
# if pairs:
# for item in data:
# if not isinstance(item, tuple) or len(item) != 2:
# pairs = False
# break
# if not pairs:
return self.represent_sequence(u"tag:yaml.org,2002:seq", data)
# value = []
# for item_key, item_value in data:
# value.append(self.represent_mapping(u'tag:yaml.org,2002:map',
# [(item_key, item_value)]))
# return SequenceNode(u'tag:yaml.org,2002:pairs', value)
def represent_dict(self, data):
# type: (Any) -> Any
return self.represent_mapping(u"tag:yaml.org,2002:map", data)
def represent_ordereddict(self, data):
# type: (Any) -> Any
return self.represent_omap(u"tag:yaml.org,2002:omap", data)
def represent_set(self, data):
# type: (Any) -> Any
value = {} # type: Dict[Any, None]
for key in data:
value[key] = None
return self.represent_mapping(u"tag:yaml.org,2002:set", value)
def represent_date(self, data):
# type: (Any) -> Any
value = to_unicode(data.isoformat())
return self.represent_scalar(u"tag:yaml.org,2002:timestamp", value)
def represent_datetime(self, data):
# type: (Any) -> Any
value = to_unicode(data.isoformat(" "))
return self.represent_scalar(u"tag:yaml.org,2002:timestamp", value)
def represent_yaml_object(self, tag, data, cls, flow_style=None):
# type: (Any, Any, Any, Any) -> Any
if hasattr(data, "__getstate__"):
state = data.__getstate__()
else:
state = data.__dict__.copy()
return self.represent_mapping(tag, state, flow_style=flow_style)
def represent_undefined(self, data):
# type: (Any) -> None
raise RepresenterError("cannot represent an object: %s" % (data,))
SafeRepresenter.add_representer(type(None), SafeRepresenter.represent_none)
SafeRepresenter.add_representer(str, SafeRepresenter.represent_str)
if PY2:
SafeRepresenter.add_representer(unicode, SafeRepresenter.represent_unicode)
else:
SafeRepresenter.add_representer(bytes, SafeRepresenter.represent_binary)
SafeRepresenter.add_representer(bool, SafeRepresenter.represent_bool)
SafeRepresenter.add_representer(int, SafeRepresenter.represent_int)
if PY2:
SafeRepresenter.add_representer(long, SafeRepresenter.represent_long)
SafeRepresenter.add_representer(float, SafeRepresenter.represent_float)
SafeRepresenter.add_representer(list, SafeRepresenter.represent_list)
SafeRepresenter.add_representer(tuple, SafeRepresenter.represent_list)
SafeRepresenter.add_representer(dict, SafeRepresenter.represent_dict)
SafeRepresenter.add_representer(set, SafeRepresenter.represent_set)
SafeRepresenter.add_representer(ordereddict, SafeRepresenter.represent_ordereddict)
if sys.version_info >= (2, 7):
import collections
SafeRepresenter.add_representer(
collections.OrderedDict, SafeRepresenter.represent_ordereddict
)
SafeRepresenter.add_representer(datetime.date, SafeRepresenter.represent_date)
SafeRepresenter.add_representer(datetime.datetime, SafeRepresenter.represent_datetime)
SafeRepresenter.add_representer(None, SafeRepresenter.represent_undefined)
class Representer(SafeRepresenter):
if PY2:
def represent_str(self, data):
# type: (Any) -> Any
tag = None
style = None
try:
data = unicode(data, "ascii")
tag = u"tag:yaml.org,2002:str"
except UnicodeDecodeError:
try:
data = unicode(data, "utf-8")
tag = u"tag:yaml.org,2002:python/str"
except UnicodeDecodeError:
data = data.encode("base64")
tag = u"tag:yaml.org,2002:binary"
style = "|"
return self.represent_scalar(tag, data, style=style)
def represent_unicode(self, data):
# type: (Any) -> Any
tag = None
try:
data.encode("ascii")
tag = u"tag:yaml.org,2002:python/unicode"
except UnicodeEncodeError:
tag = u"tag:yaml.org,2002:str"
return self.represent_scalar(tag, data)
def represent_long(self, data):
# type: (Any) -> Any
tag = u"tag:yaml.org,2002:int"
if int(data) is not data:
tag = u"tag:yaml.org,2002:python/long"
return self.represent_scalar(tag, to_unicode(data))
def represent_complex(self, data):
# type: (Any) -> Any
if data.imag == 0.0:
data = u"%r" % data.real
elif data.real == 0.0:
data = u"%rj" % data.imag
elif data.imag > 0:
data = u"%r+%rj" % (data.real, data.imag)
else:
data = u"%r%rj" % (data.real, data.imag)
return self.represent_scalar(u"tag:yaml.org,2002:python/complex", data)
def represent_tuple(self, data):
# type: (Any) -> Any
return self.represent_sequence(u"tag:yaml.org,2002:python/tuple", data)
def represent_name(self, data):
# type: (Any) -> Any
try:
name = u"%s.%s" % (data.__module__, data.__qualname__)
except AttributeError:
# probably PY2
name = u"%s.%s" % (data.__module__, data.__name__)
return self.represent_scalar(u"tag:yaml.org,2002:python/name:" + name, "")
def represent_module(self, data):
# type: (Any) -> Any
return self.represent_scalar(
u"tag:yaml.org,2002:python/module:" + data.__name__, ""
)
if PY2:
def represent_instance(self, data):
# type: (Any) -> Any
# For instances of classic classes, we use __getinitargs__ and
# __getstate__ to serialize the data.
# If data.__getinitargs__ exists, the object must be reconstructed
# by calling cls(**args), where args is a tuple returned by
# __getinitargs__. Otherwise, the cls.__init__ method should never
# be called and the class instance is created by instantiating a
# trivial class and assigning to the instance's __class__ variable.
# If data.__getstate__ exists, it returns the state of the object.
# Otherwise, the state of the object is data.__dict__.
# We produce either a !!python/object or !!python/object/new node.
# If data.__getinitargs__ does not exist and state is a dictionary,
            # we produce a !!python/object node. Otherwise we produce a
# !!python/object/new node.
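            # Illustrative sketch (assumption, not from the original source):
            # an old-style class mymod.Point whose __getinitargs__ returns
            # (1, 2) and whose __dict__ is {'x': 1, 'y': 2} would come out as
            #   !!python/object/new:mymod.Point {args: [1, 2], state: {x: 1, y: 2}}
            # while the same instance without __getinitargs__ would come out as
            #   !!python/object:mymod.Point {x: 1, y: 2}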
cls = data.__class__
class_name = u"%s.%s" % (cls.__module__, cls.__name__)
args = None
state = None
if hasattr(data, "__getinitargs__"):
args = list(data.__getinitargs__())
if hasattr(data, "__getstate__"):
state = data.__getstate__()
else:
state = data.__dict__
if args is None and isinstance(state, dict):
return self.represent_mapping(
u"tag:yaml.org,2002:python/object:" + class_name, state
)
if isinstance(state, dict) and not state:
return self.represent_sequence(
u"tag:yaml.org,2002:python/object/new:" + class_name, args
)
value = {}
if bool(args):
value["args"] = args
value["state"] = state # type: ignore
return self.represent_mapping(
u"tag:yaml.org,2002:python/object/new:" + class_name, value
)
def represent_object(self, data):
# type: (Any) -> Any
# We use __reduce__ API to save the data. data.__reduce__ returns
# a tuple of length 2-5:
# (function, args, state, listitems, dictitems)
        # For reconstructing, we call function(*args), then set its state,
# listitems, and dictitems if they are not None.
# A special case is when function.__name__ == '__newobj__'. In this
# case we create the object with args[0].__new__(*args).
# Another special case is when __reduce__ returns a string - we don't
# support it.
# We produce a !!python/object, !!python/object/new or
# !!python/object/apply node.
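        # Illustrative sketch (assumption, not from the original source): for a
        # typical new-style object, obj.__reduce_ex__(2) returns roughly
        #   (copyreg.__newobj__, (cls,), obj.__dict__, None, None)
        # so function.__name__ == '__newobj__', args becomes empty, and the
        # object is emitted as a !!python/object:module.cls mapping of its
        # __dict__ (see the first return below).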
cls = type(data)
if cls in copyreg.dispatch_table:
reduce = copyreg.dispatch_table[cls](data)
elif hasattr(data, "__reduce_ex__"):
reduce = data.__reduce_ex__(2)
elif hasattr(data, "__reduce__"):
reduce = data.__reduce__()
else:
raise RepresenterError("cannot represent object: %r" % (data,))
reduce = (list(reduce) + [None] * 5)[:5]
function, args, state, listitems, dictitems = reduce
args = list(args)
if state is None:
state = {}
if listitems is not None:
listitems = list(listitems)
if dictitems is not None:
dictitems = dict(dictitems)
if function.__name__ == "__newobj__":
function = args[0]
args = args[1:]
tag = u"tag:yaml.org,2002:python/object/new:"
newobj = True
else:
tag = u"tag:yaml.org,2002:python/object/apply:"
newobj = False
try:
function_name = u"%s.%s" % (function.__module__, function.__qualname__)
except AttributeError:
# probably PY2
function_name = u"%s.%s" % (function.__module__, function.__name__)
if (
not args
and not listitems
and not dictitems
and isinstance(state, dict)
and newobj
):
return self.represent_mapping(
u"tag:yaml.org,2002:python/object:" + function_name, state
)
if not listitems and not dictitems and isinstance(state, dict) and not state:
return self.represent_sequence(tag + function_name, args)
value = {}
if args:
value["args"] = args
if state or not isinstance(state, dict):
value["state"] = state
if listitems:
value["listitems"] = listitems
if dictitems:
value["dictitems"] = dictitems
return self.represent_mapping(tag + function_name, value)
if PY2:
Representer.add_representer(str, Representer.represent_str)
Representer.add_representer(unicode, Representer.represent_unicode)
Representer.add_representer(long, Representer.represent_long)
Representer.add_representer(complex, Representer.represent_complex)
Representer.add_representer(tuple, Representer.represent_tuple)
Representer.add_representer(type, Representer.represent_name)
if PY2:
Representer.add_representer(types.ClassType, Representer.represent_name)
Representer.add_representer(types.FunctionType, Representer.represent_name)
Representer.add_representer(types.BuiltinFunctionType, Representer.represent_name)
Representer.add_representer(types.ModuleType, Representer.represent_module)
if PY2:
Representer.add_multi_representer(
types.InstanceType, Representer.represent_instance
)
Representer.add_multi_representer(object, Representer.represent_object)
Representer.add_multi_representer(type, Representer.represent_name)
from .comments import (
CommentedMap,
CommentedOrderedMap,
CommentedSeq,
CommentedKeySeq,
CommentedKeyMap,
CommentedSet,
comment_attrib,
merge_attrib,
TaggedScalar,
) # NOQA
class RoundTripRepresenter(SafeRepresenter):
# need to add type here and write out the .comment
# in serializer and emitter
def __init__(self, default_style=None, default_flow_style=None, dumper=None):
# type: (Any, Any, Any) -> None
if not hasattr(dumper, "typ") and default_flow_style is None:
default_flow_style = False
SafeRepresenter.__init__(
self,
default_style=default_style,
default_flow_style=default_flow_style,
dumper=dumper,
)
def ignore_aliases(self, data):
# type: (Any) -> bool
try:
if data.anchor is not None and data.anchor.value is not None:
return False
except AttributeError:
pass
return SafeRepresenter.ignore_aliases(self, data)
def represent_none(self, data):
# type: (Any) -> Any
if (
len(self.represented_objects) == 0
and not self.serializer.use_explicit_start
):
# this will be open ended (although it is not yet)
return self.represent_scalar(u"tag:yaml.org,2002:null", u"null")
return self.represent_scalar(u"tag:yaml.org,2002:null", "")
def represent_literal_scalarstring(self, data):
# type: (Any) -> Any
tag = None
style = "|"
anchor = data.yaml_anchor(any=True)
if PY2 and not isinstance(data, unicode):
data = unicode(data, "ascii")
tag = u"tag:yaml.org,2002:str"
return self.represent_scalar(tag, data, style=style, anchor=anchor)
represent_preserved_scalarstring = represent_literal_scalarstring
def represent_folded_scalarstring(self, data):
# type: (Any) -> Any
tag = None
style = ">"
anchor = data.yaml_anchor(any=True)
for fold_pos in reversed(getattr(data, "fold_pos", [])):
if (
data[fold_pos] == " "
and (fold_pos > 0 and not data[fold_pos - 1].isspace())
and (fold_pos < len(data) and not data[fold_pos + 1].isspace())
):
data = data[:fold_pos] + "\a" + data[fold_pos:]
if PY2 and not isinstance(data, unicode):
data = unicode(data, "ascii")
tag = u"tag:yaml.org,2002:str"
return self.represent_scalar(tag, data, style=style, anchor=anchor)
def represent_single_quoted_scalarstring(self, data):
# type: (Any) -> Any
tag = None
style = "'"
anchor = data.yaml_anchor(any=True)
if PY2 and not isinstance(data, unicode):
data = unicode(data, "ascii")
tag = u"tag:yaml.org,2002:str"
return self.represent_scalar(tag, data, style=style, anchor=anchor)
def represent_double_quoted_scalarstring(self, data):
# type: (Any) -> Any
tag = None
style = '"'
anchor = data.yaml_anchor(any=True)
if PY2 and not isinstance(data, unicode):
data = unicode(data, "ascii")
tag = u"tag:yaml.org,2002:str"
return self.represent_scalar(tag, data, style=style, anchor=anchor)
def represent_plain_scalarstring(self, data):
# type: (Any) -> Any
tag = None
style = ""
anchor = data.yaml_anchor(any=True)
if PY2 and not isinstance(data, unicode):
data = unicode(data, "ascii")
tag = u"tag:yaml.org,2002:str"
return self.represent_scalar(tag, data, style=style, anchor=anchor)
def insert_underscore(self, prefix, s, underscore, anchor=None):
# type: (Any, Any, Any, Any) -> Any
if underscore is None:
return self.represent_scalar(
u"tag:yaml.org,2002:int", prefix + s, anchor=anchor
)
if underscore[0]:
sl = list(s)
pos = len(s) - underscore[0]
while pos > 0:
sl.insert(pos, "_")
pos -= underscore[0]
s = "".join(sl)
if underscore[1]:
s = "_" + s
if underscore[2]:
s += "_"
return self.represent_scalar(
u"tag:yaml.org,2002:int", prefix + s, anchor=anchor
)
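        # Informal example (not part of the original source): with
        # underscore == [3, False, False] and s == '1234567' the loop above
        # inserts '_' every three digits counted from the right, giving
        # '1_234_567'; underscore[1] / underscore[2] control a leading /
        # trailing underscore respectively.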
def represent_scalar_int(self, data):
# type: (Any) -> Any
if data._width is not None:
s = "{:0{}d}".format(data, data._width)
else:
s = format(data, "d")
anchor = data.yaml_anchor(any=True)
return self.insert_underscore("", s, data._underscore, anchor=anchor)
def represent_binary_int(self, data):
# type: (Any) -> Any
if data._width is not None:
# cannot use '{:#0{}b}', that strips the zeros
s = "{:0{}b}".format(data, data._width)
else:
s = format(data, "b")
anchor = data.yaml_anchor(any=True)
return self.insert_underscore("0b", s, data._underscore, anchor=anchor)
def represent_octal_int(self, data):
# type: (Any) -> Any
if data._width is not None:
# cannot use '{:#0{}o}', that strips the zeros
s = "{:0{}o}".format(data, data._width)
else:
s = format(data, "o")
anchor = data.yaml_anchor(any=True)
return self.insert_underscore("0o", s, data._underscore, anchor=anchor)
def represent_hex_int(self, data):
# type: (Any) -> Any
if data._width is not None:
# cannot use '{:#0{}x}', that strips the zeros
s = "{:0{}x}".format(data, data._width)
else:
s = format(data, "x")
anchor = data.yaml_anchor(any=True)
return self.insert_underscore("0x", s, data._underscore, anchor=anchor)
def represent_hex_caps_int(self, data):
# type: (Any) -> Any
if data._width is not None:
# cannot use '{:#0{}X}', that strips the zeros
s = "{:0{}X}".format(data, data._width)
else:
s = format(data, "X")
anchor = data.yaml_anchor(any=True)
return self.insert_underscore("0x", s, data._underscore, anchor=anchor)
def represent_scalar_float(self, data):
# type: (Any) -> Any
""" this is way more complicated """
value = None
anchor = data.yaml_anchor(any=True)
if data != data or (data == 0.0 and data == 1.0):
value = u".nan"
elif data == self.inf_value:
value = u".inf"
elif data == -self.inf_value:
value = u"-.inf"
if value:
return self.represent_scalar(
u"tag:yaml.org,2002:float", value, anchor=anchor
)
if data._exp is None and data._prec > 0 and data._prec == data._width - 1:
# no exponent, but trailing dot
value = u"{}{:d}.".format(
data._m_sign if data._m_sign else "", abs(int(data))
)
elif data._exp is None:
# no exponent, "normal" dot
prec = data._prec
ms = data._m_sign if data._m_sign else ""
# -1 for the dot
value = u"{}{:0{}.{}f}".format(
ms, abs(data), data._width - len(ms), data._width - prec - 1
)
if prec == 0 or (prec == 1 and ms != ""):
value = value.replace(u"0.", u".")
while len(value) < data._width:
value += u"0"
else:
# exponent
m, es = u"{:{}.{}e}".format(
# data, data._width, data._width - data._prec + (1 if data._m_sign else 0)
data,
data._width,
data._width + (1 if data._m_sign else 0),
).split("e")
w = data._width if data._prec > 0 else (data._width + 1)
if data < 0:
w += 1
m = m[:w]
e = int(es)
m1, m2 = m.split(".") # always second?
while len(m1) + len(m2) < data._width - (1 if data._prec >= 0 else 0):
m2 += u"0"
if data._m_sign and data > 0:
m1 = "+" + m1
esgn = u"+" if data._e_sign else ""
if data._prec < 0: # mantissa without dot
if m2 != u"0":
e -= len(m2)
else:
m2 = ""
while (len(m1) + len(m2) - (1 if data._m_sign else 0)) < data._width:
m2 += u"0"
e -= 1
value = (
m1 + m2 + data._exp + u"{:{}0{}d}".format(e, esgn, data._e_width)
)
elif data._prec == 0: # mantissa with trailing dot
e -= len(m2)
value = (
m1
+ m2
+ u"."
+ data._exp
+ u"{:{}0{}d}".format(e, esgn, data._e_width)
)
else:
if data._m_lead0 > 0:
m2 = u"0" * (data._m_lead0 - 1) + m1 + m2
m1 = u"0"
m2 = m2[: -data._m_lead0] # these should be zeros
e += data._m_lead0
while len(m1) < data._prec:
m1 += m2[0]
m2 = m2[1:]
e -= 1
value = (
m1
+ u"."
+ m2
+ data._exp
+ u"{:{}0{}d}".format(e, esgn, data._e_width)
)
if value is None:
value = to_unicode(repr(data)).lower()
return self.represent_scalar(u"tag:yaml.org,2002:float", value, anchor=anchor)
def represent_sequence(self, tag, sequence, flow_style=None):
# type: (Any, Any, Any) -> Any
value = [] # type: List[Any]
        # if flow_style is None, the flow style explicitly attached to the
        # object is used. If that is None as well, the default flow style
        # applies.
try:
flow_style = sequence.fa.flow_style(flow_style)
except AttributeError:
flow_style = flow_style
try:
anchor = sequence.yaml_anchor()
except AttributeError:
anchor = None
node = SequenceNode(tag, value, flow_style=flow_style, anchor=anchor)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
try:
comment = getattr(sequence, comment_attrib)
node.comment = comment.comment
# reset any comment already printed information
if node.comment and node.comment[1]:
for ct in node.comment[1]:
ct.reset()
item_comments = comment.items
for v in item_comments.values():
if v and v[1]:
for ct in v[1]:
ct.reset()
item_comments = comment.items
node.comment = comment.comment
try:
node.comment.append(comment.end)
except AttributeError:
pass
except AttributeError:
item_comments = {}
for idx, item in enumerate(sequence):
node_item = self.represent_data(item)
self.merge_comments(node_item, item_comments.get(idx))
if not (isinstance(node_item, ScalarNode) and not node_item.style):
best_style = False
value.append(node_item)
if flow_style is None:
if len(sequence) != 0 and self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
return node
def merge_comments(self, node, comments):
# type: (Any, Any) -> Any
if comments is None:
assert hasattr(node, "comment")
return node
if getattr(node, "comment", None) is not None:
for idx, val in enumerate(comments):
if idx >= len(node.comment):
continue
nc = node.comment[idx]
if nc is not None:
assert val is None or val == nc
comments[idx] = nc
node.comment = comments
return node
def represent_key(self, data):
# type: (Any) -> Any
if isinstance(data, CommentedKeySeq):
self.alias_key = None
return self.represent_sequence(
u"tag:yaml.org,2002:seq", data, flow_style=True
)
if isinstance(data, CommentedKeyMap):
self.alias_key = None
return self.represent_mapping(
u"tag:yaml.org,2002:map", data, flow_style=True
)
return SafeRepresenter.represent_key(self, data)
def represent_mapping(self, tag, mapping, flow_style=None):
# type: (Any, Any, Any) -> Any
value = [] # type: List[Any]
try:
flow_style = mapping.fa.flow_style(flow_style)
except AttributeError:
flow_style = flow_style
try:
anchor = mapping.yaml_anchor()
except AttributeError:
anchor = None
node = MappingNode(tag, value, flow_style=flow_style, anchor=anchor)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
# no sorting! !!
try:
comment = getattr(mapping, comment_attrib)
node.comment = comment.comment
if node.comment and node.comment[1]:
for ct in node.comment[1]:
ct.reset()
item_comments = comment.items
for v in item_comments.values():
if v and v[1]:
for ct in v[1]:
ct.reset()
try:
node.comment.append(comment.end)
except AttributeError:
pass
except AttributeError:
item_comments = {}
merge_list = [m[1] for m in getattr(mapping, merge_attrib, [])]
try:
merge_pos = getattr(mapping, merge_attrib, [[0]])[0][0]
except IndexError:
merge_pos = 0
item_count = 0
if bool(merge_list):
items = mapping.non_merged_items()
else:
items = mapping.items()
for item_key, item_value in items:
item_count += 1
node_key = self.represent_key(item_key)
node_value = self.represent_data(item_value)
item_comment = item_comments.get(item_key)
if item_comment:
assert getattr(node_key, "comment", None) is None
node_key.comment = item_comment[:2]
nvc = getattr(node_value, "comment", None)
if nvc is not None: # end comment already there
nvc[0] = item_comment[2]
nvc[1] = item_comment[3]
else:
node_value.comment = item_comment[2:]
if not (isinstance(node_key, ScalarNode) and not node_key.style):
best_style = False
if not (isinstance(node_value, ScalarNode) and not node_value.style):
best_style = False
value.append((node_key, node_value))
if flow_style is None:
if (
(item_count != 0) or bool(merge_list)
) and self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
if bool(merge_list):
# because of the call to represent_data here, the anchors
# are marked as being used and thereby created
if len(merge_list) == 1:
arg = self.represent_data(merge_list[0])
else:
arg = self.represent_data(merge_list)
arg.flow_style = True
value.insert(merge_pos, (ScalarNode(u"tag:yaml.org,2002:merge", "<<"), arg))
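            # Informal note (not part of the original source): this re-inserts a
            # merge key that was present on loading, e.g. a mapping read as
            #   <<: *defaults
            #   port: 8080
            # gets its '<<' entry put back at the original position, with the
            # alias (or a flow sequence of aliases) as its value.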
return node
def represent_omap(self, tag, omap, flow_style=None):
# type: (Any, Any, Any) -> Any
value = [] # type: List[Any]
try:
flow_style = omap.fa.flow_style(flow_style)
except AttributeError:
flow_style = flow_style
try:
anchor = omap.yaml_anchor()
except AttributeError:
anchor = None
node = SequenceNode(tag, value, flow_style=flow_style, anchor=anchor)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
try:
comment = getattr(omap, comment_attrib)
node.comment = comment.comment
if node.comment and node.comment[1]:
for ct in node.comment[1]:
ct.reset()
item_comments = comment.items
for v in item_comments.values():
if v and v[1]:
for ct in v[1]:
ct.reset()
try:
node.comment.append(comment.end)
except AttributeError:
pass
except AttributeError:
item_comments = {}
for item_key in omap:
item_val = omap[item_key]
node_item = self.represent_data({item_key: item_val})
# node_item.flow_style = False
# node item has two scalars in value: node_key and node_value
item_comment = item_comments.get(item_key)
if item_comment:
if item_comment[1]:
node_item.comment = [None, item_comment[1]]
assert getattr(node_item.value[0][0], "comment", None) is None
node_item.value[0][0].comment = [item_comment[0], None]
nvc = getattr(node_item.value[0][1], "comment", None)
if nvc is not None: # end comment already there
nvc[0] = item_comment[2]
nvc[1] = item_comment[3]
else:
node_item.value[0][1].comment = item_comment[2:]
# if not (isinstance(node_item, ScalarNode) \
# and not node_item.style):
# best_style = False
value.append(node_item)
if flow_style is None:
if self.default_flow_style is not None:
node.flow_style = self.default_flow_style
else:
node.flow_style = best_style
return node
def represent_set(self, setting):
# type: (Any) -> Any
flow_style = False
tag = u"tag:yaml.org,2002:set"
# return self.represent_mapping(tag, value)
value = [] # type: List[Any]
flow_style = setting.fa.flow_style(flow_style)
try:
anchor = setting.yaml_anchor()
except AttributeError:
anchor = None
node = MappingNode(tag, value, flow_style=flow_style, anchor=anchor)
if self.alias_key is not None:
self.represented_objects[self.alias_key] = node
best_style = True
# no sorting! !!
try:
comment = getattr(setting, comment_attrib)
node.comment = comment.comment
if node.comment and node.comment[1]:
for ct in node.comment[1]:
ct.reset()
item_comments = comment.items
for v in item_comments.values():
if v and v[1]:
for ct in v[1]:
ct.reset()
try:
node.comment.append(comment.end)
except AttributeError:
pass
except AttributeError:
item_comments = {}
for item_key in setting.odict:
node_key = self.represent_key(item_key)
node_value = self.represent_data(None)
item_comment = item_comments.get(item_key)
if item_comment:
assert getattr(node_key, "comment", None) is None
node_key.comment = item_comment[:2]
node_key.style = node_value.style = "?"
if not (isinstance(node_key, ScalarNode) and not node_key.style):
best_style = False
if not (isinstance(node_value, ScalarNode) and not node_value.style):
best_style = False
value.append((node_key, node_value))
best_style = best_style
return node
def represent_dict(self, data):
# type: (Any) -> Any
"""write out tag if saved on loading"""
try:
t = data.tag.value
except AttributeError:
t = None
if t:
if t.startswith("!!"):
tag = "tag:yaml.org,2002:" + t[2:]
else:
tag = t
else:
tag = u"tag:yaml.org,2002:map"
return self.represent_mapping(tag, data)
def represent_list(self, data):
# type: (Any) -> Any
try:
t = data.tag.value
except AttributeError:
t = None
if t:
if t.startswith("!!"):
tag = "tag:yaml.org,2002:" + t[2:]
else:
tag = t
else:
tag = u"tag:yaml.org,2002:seq"
return self.represent_sequence(tag, data)
def represent_datetime(self, data):
# type: (Any) -> Any
inter = "T" if data._yaml["t"] else " "
_yaml = data._yaml
if _yaml["delta"]:
data += _yaml["delta"]
value = data.isoformat(inter)
else:
value = data.isoformat(inter)
if _yaml["tz"]:
value += _yaml["tz"]
return self.represent_scalar(u"tag:yaml.org,2002:timestamp", to_unicode(value))
def represent_tagged_scalar(self, data):
# type: (Any) -> Any
try:
tag = data.tag.value
except AttributeError:
tag = None
try:
anchor = data.yaml_anchor()
except AttributeError:
anchor = None
return self.represent_scalar(tag, data.value, style=data.style, anchor=anchor)
def represent_scalar_bool(self, data):
# type: (Any) -> Any
try:
anchor = data.yaml_anchor()
except AttributeError:
anchor = None
return SafeRepresenter.represent_bool(self, data, anchor=anchor)
RoundTripRepresenter.add_representer(type(None), RoundTripRepresenter.represent_none)
RoundTripRepresenter.add_representer(
LiteralScalarString, RoundTripRepresenter.represent_literal_scalarstring
)
RoundTripRepresenter.add_representer(
FoldedScalarString, RoundTripRepresenter.represent_folded_scalarstring
)
RoundTripRepresenter.add_representer(
SingleQuotedScalarString, RoundTripRepresenter.represent_single_quoted_scalarstring
)
RoundTripRepresenter.add_representer(
DoubleQuotedScalarString, RoundTripRepresenter.represent_double_quoted_scalarstring
)
RoundTripRepresenter.add_representer(
PlainScalarString, RoundTripRepresenter.represent_plain_scalarstring
)
# RoundTripRepresenter.add_representer(tuple, Representer.represent_tuple)
RoundTripRepresenter.add_representer(
ScalarInt, RoundTripRepresenter.represent_scalar_int
)
RoundTripRepresenter.add_representer(
BinaryInt, RoundTripRepresenter.represent_binary_int
)
RoundTripRepresenter.add_representer(OctalInt, RoundTripRepresenter.represent_octal_int)
RoundTripRepresenter.add_representer(HexInt, RoundTripRepresenter.represent_hex_int)
RoundTripRepresenter.add_representer(
HexCapsInt, RoundTripRepresenter.represent_hex_caps_int
)
RoundTripRepresenter.add_representer(
ScalarFloat, RoundTripRepresenter.represent_scalar_float
)
RoundTripRepresenter.add_representer(
ScalarBoolean, RoundTripRepresenter.represent_scalar_bool
)
RoundTripRepresenter.add_representer(CommentedSeq, RoundTripRepresenter.represent_list)
RoundTripRepresenter.add_representer(CommentedMap, RoundTripRepresenter.represent_dict)
RoundTripRepresenter.add_representer(
CommentedOrderedMap, RoundTripRepresenter.represent_ordereddict
)
if sys.version_info >= (2, 7):
import collections
RoundTripRepresenter.add_representer(
collections.OrderedDict, RoundTripRepresenter.represent_ordereddict
)
RoundTripRepresenter.add_representer(CommentedSet, RoundTripRepresenter.represent_set)
RoundTripRepresenter.add_representer(
TaggedScalar, RoundTripRepresenter.represent_tagged_scalar
)
RoundTripRepresenter.add_representer(TimeStamp, RoundTripRepresenter.represent_datetime)
srsly-release-v2.5.1/srsly/ruamel_yaml/resolver.py 0000775 0000000 0000000 00000036140 14742310675 0022415 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
import re
if False: # MYPY
from typing import Any, Dict, List, Union, Text, Optional # NOQA
from .compat import VersionType # NOQA
from .compat import string_types, _DEFAULT_YAML_VERSION # NOQA
from .error import * # NOQA
from .nodes import MappingNode, ScalarNode, SequenceNode # NOQA
from .util import RegExp # NOQA
__all__ = ["BaseResolver", "Resolver", "VersionedResolver"]
# fmt: off
# resolvers consist of
# - a list of applicable versions
# - a tag
# - a regexp
# - a list of first characters to match
implicit_resolvers = [
([(1, 2)],
u'tag:yaml.org,2002:bool',
RegExp(u'''^(?:true|True|TRUE|false|False|FALSE)$''', re.X),
list(u'tTfF')),
([(1, 1)],
u'tag:yaml.org,2002:bool',
RegExp(u'''^(?:y|Y|yes|Yes|YES|n|N|no|No|NO
|true|True|TRUE|false|False|FALSE
|on|On|ON|off|Off|OFF)$''', re.X),
list(u'yYnNtTfFoO')),
([(1, 2)],
u'tag:yaml.org,2002:float',
RegExp(u'''^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|[-+]?\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$''', re.X),
list(u'-+0123456789.')),
([(1, 1)],
u'tag:yaml.org,2002:float',
RegExp(u'''^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* # sexagesimal float
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$''', re.X),
list(u'-+0123456789.')),
([(1, 2)],
u'tag:yaml.org,2002:int',
RegExp(u'''^(?:[-+]?0b[0-1_]+
|[-+]?0o?[0-7_]+
|[-+]?[0-9_]+
|[-+]?0x[0-9a-fA-F_]+)$''', re.X),
list(u'-+0123456789')),
([(1, 1)],
u'tag:yaml.org,2002:int',
RegExp(u'''^(?:[-+]?0b[0-1_]+
|[-+]?0?[0-7_]+
|[-+]?(?:0|[1-9][0-9_]*)
|[-+]?0x[0-9a-fA-F_]+
|[-+]?[1-9][0-9_]*(?::[0-5]?[0-9])+)$''', re.X), # sexagesimal int
list(u'-+0123456789')),
([(1, 2), (1, 1)],
u'tag:yaml.org,2002:merge',
RegExp(u'^(?:<<)$'),
[u'<']),
([(1, 2), (1, 1)],
u'tag:yaml.org,2002:null',
RegExp(u'''^(?: ~
|null|Null|NULL
| )$''', re.X),
[u'~', u'n', u'N', u'']),
([(1, 2), (1, 1)],
u'tag:yaml.org,2002:timestamp',
RegExp(u'''^(?:[0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9]
|[0-9][0-9][0-9][0-9] -[0-9][0-9]? -[0-9][0-9]?
(?:[Tt]|[ \\t]+)[0-9][0-9]?
:[0-9][0-9] :[0-9][0-9] (?:\\.[0-9]*)?
(?:[ \\t]*(?:Z|[-+][0-9][0-9]?(?::[0-9][0-9])?))?)$''', re.X),
list(u'0123456789')),
([(1, 2), (1, 1)],
u'tag:yaml.org,2002:value',
RegExp(u'^(?:=)$'),
[u'=']),
# The following resolver is only for documentation purposes. It cannot work
# because plain scalars cannot start with '!', '&', or '*'.
([(1, 2), (1, 1)],
u'tag:yaml.org,2002:yaml',
RegExp(u'^(?:!|&|\\*)$'),
list(u'!&*')),
]
# fmt: on
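# Informal note (not part of the original source): the version lists above are
# what make resolution version dependent, e.g. the plain scalar `yes` matches
# the 1.1 bool pattern but under 1.2 falls through to the default !!str tag,
# and a sexagesimal like `1:30` only matches the int/float patterns of 1.1.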
class ResolverError(YAMLError):
pass
class BaseResolver(object):
DEFAULT_SCALAR_TAG = u"tag:yaml.org,2002:str"
DEFAULT_SEQUENCE_TAG = u"tag:yaml.org,2002:seq"
DEFAULT_MAPPING_TAG = u"tag:yaml.org,2002:map"
yaml_implicit_resolvers = {} # type: Dict[Any, Any]
yaml_path_resolvers = {} # type: Dict[Any, Any]
def __init__(self, loadumper=None):
# type: (Any, Any) -> None
self.loadumper = loadumper
if (
self.loadumper is not None
and getattr(self.loadumper, "_resolver", None) is None
):
self.loadumper._resolver = self.loadumper
self._loader_version = None # type: Any
self.resolver_exact_paths = [] # type: List[Any]
self.resolver_prefix_paths = [] # type: List[Any]
@property
def parser(self):
# type: () -> Any
if self.loadumper is not None:
if hasattr(self.loadumper, "typ"):
return self.loadumper.parser
return self.loadumper._parser
return None
@classmethod
def add_implicit_resolver_base(cls, tag, regexp, first):
# type: (Any, Any, Any) -> None
if "yaml_implicit_resolvers" not in cls.__dict__:
# deepcopy doesn't work here
cls.yaml_implicit_resolvers = dict(
(k, cls.yaml_implicit_resolvers[k][:])
for k in cls.yaml_implicit_resolvers
)
if first is None:
first = [None]
for ch in first:
cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp))
@classmethod
def add_implicit_resolver(cls, tag, regexp, first):
# type: (Any, Any, Any) -> None
if "yaml_implicit_resolvers" not in cls.__dict__:
# deepcopy doesn't work here
cls.yaml_implicit_resolvers = dict(
(k, cls.yaml_implicit_resolvers[k][:])
for k in cls.yaml_implicit_resolvers
)
if first is None:
first = [None]
for ch in first:
cls.yaml_implicit_resolvers.setdefault(ch, []).append((tag, regexp))
implicit_resolvers.append(([(1, 2), (1, 1)], tag, regexp, first))
# @classmethod
# def add_implicit_resolver(cls, tag, regexp, first):
@classmethod
def add_path_resolver(cls, tag, path, kind=None):
# type: (Any, Any, Any) -> None
# Note: `add_path_resolver` is experimental. The API could be changed.
        # `path` is a pattern that is matched against the path from the
        # root to the node that is being considered. Its elements are
        # tuples `(node_check, index_check)`. `node_check` is a node class:
# `ScalarNode`, `SequenceNode`, `MappingNode` or `None`. `None`
# matches any kind of a node. `index_check` could be `None`, a boolean
# value, a string value, or a number. `None` and `False` match against
# any _value_ of sequence and mapping nodes. `True` matches against
# any _key_ of a mapping node. A string `index_check` matches against
# a mapping value that corresponds to a scalar key which content is
# equal to the `index_check` value. An integer `index_check` matches
# against a sequence value with the index equal to `index_check`.
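        # Illustrative sketch (assumption, not part of the original source):
        #   Resolver.add_path_resolver(u'!coord', ['point', 0], dict)
        # would make the mapping found at root['point'][0] resolve to the
        # !coord tag instead of the default !!map tag.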
if "yaml_path_resolvers" not in cls.__dict__:
cls.yaml_path_resolvers = cls.yaml_path_resolvers.copy()
new_path = [] # type: List[Any]
for element in path:
if isinstance(element, (list, tuple)):
if len(element) == 2:
node_check, index_check = element
elif len(element) == 1:
node_check = element[0]
index_check = True
else:
raise ResolverError("Invalid path element: %s" % (element,))
else:
node_check = None
index_check = element
if node_check is str:
node_check = ScalarNode
elif node_check is list:
node_check = SequenceNode
elif node_check is dict:
node_check = MappingNode
elif (
node_check not in [ScalarNode, SequenceNode, MappingNode]
and not isinstance(node_check, string_types)
and node_check is not None
):
raise ResolverError("Invalid node checker: %s" % (node_check,))
if (
not isinstance(index_check, (string_types, int))
and index_check is not None
):
raise ResolverError("Invalid index checker: %s" % (index_check,))
new_path.append((node_check, index_check))
if kind is str:
kind = ScalarNode
elif kind is list:
kind = SequenceNode
elif kind is dict:
kind = MappingNode
elif kind not in [ScalarNode, SequenceNode, MappingNode] and kind is not None:
raise ResolverError("Invalid node kind: %s" % (kind,))
cls.yaml_path_resolvers[tuple(new_path), kind] = tag
def descend_resolver(self, current_node, current_index):
# type: (Any, Any) -> None
if not self.yaml_path_resolvers:
return
exact_paths = {}
prefix_paths = []
if current_node:
depth = len(self.resolver_prefix_paths)
for path, kind in self.resolver_prefix_paths[-1]:
if self.check_resolver_prefix(
depth, path, kind, current_node, current_index
):
if len(path) > depth:
prefix_paths.append((path, kind))
else:
exact_paths[kind] = self.yaml_path_resolvers[path, kind]
else:
for path, kind in self.yaml_path_resolvers:
if not path:
exact_paths[kind] = self.yaml_path_resolvers[path, kind]
else:
prefix_paths.append((path, kind))
self.resolver_exact_paths.append(exact_paths)
self.resolver_prefix_paths.append(prefix_paths)
def ascend_resolver(self):
# type: () -> None
if not self.yaml_path_resolvers:
return
self.resolver_exact_paths.pop()
self.resolver_prefix_paths.pop()
def check_resolver_prefix(self, depth, path, kind, current_node, current_index):
# type: (int, Text, Any, Any, Any) -> bool
node_check, index_check = path[depth - 1]
if isinstance(node_check, string_types):
if current_node.tag != node_check:
return False
elif node_check is not None:
if not isinstance(current_node, node_check):
return False
if index_check is True and current_index is not None:
return False
if (index_check is False or index_check is None) and current_index is None:
return False
if isinstance(index_check, string_types):
if not (
isinstance(current_index, ScalarNode)
and index_check == current_index.value
):
return False
elif isinstance(index_check, int) and not isinstance(index_check, bool):
if index_check != current_index:
return False
return True
def resolve(self, kind, value, implicit):
# type: (Any, Any, Any) -> Any
if kind is ScalarNode and implicit[0]:
if value == "":
resolvers = self.yaml_implicit_resolvers.get("", [])
else:
resolvers = self.yaml_implicit_resolvers.get(value[0], [])
resolvers += self.yaml_implicit_resolvers.get(None, [])
for tag, regexp in resolvers:
if regexp.match(value):
return tag
implicit = implicit[1]
if bool(self.yaml_path_resolvers):
exact_paths = self.resolver_exact_paths[-1]
if kind in exact_paths:
return exact_paths[kind]
if None in exact_paths:
return exact_paths[None]
if kind is ScalarNode:
return self.DEFAULT_SCALAR_TAG
elif kind is SequenceNode:
return self.DEFAULT_SEQUENCE_TAG
elif kind is MappingNode:
return self.DEFAULT_MAPPING_TAG
@property
def processing_version(self):
# type: () -> Any
return None
class Resolver(BaseResolver):
pass
for ir in implicit_resolvers:
if (1, 2) in ir[0]:
Resolver.add_implicit_resolver_base(*ir[1:])
class VersionedResolver(BaseResolver):
"""
    contrary to the "normal" resolver, this versioned resolver delays loading
    the pattern matching rules. That way it can decide whether to load the 1.1
    rules or the (default) 1.2 rules, which no longer support octal without 0o,
    sexagesimal numbers, or the Yes/No/On/Off booleans.
"""
def __init__(self, version=None, loader=None, loadumper=None):
# type: (Optional[VersionType], Any, Any) -> None
if loader is None and loadumper is not None:
loader = loadumper
BaseResolver.__init__(self, loader)
self._loader_version = self.get_loader_version(version)
self._version_implicit_resolver = {} # type: Dict[Any, Any]
def add_version_implicit_resolver(self, version, tag, regexp, first):
# type: (VersionType, Any, Any, Any) -> None
if first is None:
first = [None]
impl_resolver = self._version_implicit_resolver.setdefault(version, {})
for ch in first:
impl_resolver.setdefault(ch, []).append((tag, regexp))
def get_loader_version(self, version):
# type: (Optional[VersionType]) -> Any
if version is None or isinstance(version, tuple):
return version
if isinstance(version, list):
return tuple(version)
# assume string
return tuple(map(int, version.split(u".")))
@property
def versioned_resolver(self):
# type: () -> Any
"""
select the resolver based on the version we are parsing
"""
version = self.processing_version
if version not in self._version_implicit_resolver:
for x in implicit_resolvers:
if version in x[0]:
self.add_version_implicit_resolver(version, x[1], x[2], x[3])
return self._version_implicit_resolver[version]
def resolve(self, kind, value, implicit):
# type: (Any, Any, Any) -> Any
if kind is ScalarNode and implicit[0]:
if value == "":
resolvers = self.versioned_resolver.get("", [])
else:
resolvers = self.versioned_resolver.get(value[0], [])
resolvers += self.versioned_resolver.get(None, [])
for tag, regexp in resolvers:
if regexp.match(value):
return tag
implicit = implicit[1]
if bool(self.yaml_path_resolvers):
exact_paths = self.resolver_exact_paths[-1]
if kind in exact_paths:
return exact_paths[kind]
if None in exact_paths:
return exact_paths[None]
if kind is ScalarNode:
return self.DEFAULT_SCALAR_TAG
elif kind is SequenceNode:
return self.DEFAULT_SEQUENCE_TAG
elif kind is MappingNode:
return self.DEFAULT_MAPPING_TAG
@property
def processing_version(self):
# type: () -> Any
try:
version = self.loadumper._scanner.yaml_version
except AttributeError:
try:
if hasattr(self.loadumper, "typ"):
version = self.loadumper.version
else:
version = self.loadumper._serializer.use_version # dumping
except AttributeError:
version = None
if version is None:
version = self._loader_version
if version is None:
version = _DEFAULT_YAML_VERSION
return version
srsly-release-v2.5.1/srsly/ruamel_yaml/scalarbool.py 0000775 0000000 0000000 00000002760 14742310675 0022676 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division, unicode_literals
"""
You cannot subclass bool, yet subclassing would be needed for round-tripping
anchored bool values (and also if you want to preserve the original way of
writing). bool.__bases__ is type 'int', so that is what is used as the basis
for ScalarBoolean as well.
You can use these in an if statement, but not when testing identity against
True or False, as they remain ints rather than bools.
"""
from .anchor import Anchor
if False: # MYPY
from typing import Text, Any, Dict, List # NOQA
__all__ = ["ScalarBoolean"]
# no need for no_limit_int -> int
class ScalarBoolean(int):
def __new__(cls, *args, **kw):
# type: (Any, Any, Any) -> Any
anchor = kw.pop("anchor", None) # type: ignore
b = int.__new__(cls, *args, **kw) # type: ignore
if anchor is not None:
b.yaml_set_anchor(anchor, always_dump=True)
return b
@property
def anchor(self):
# type: () -> Any
if not hasattr(self, Anchor.attrib):
setattr(self, Anchor.attrib, Anchor())
return getattr(self, Anchor.attrib)
def yaml_anchor(self, any=False):
# type: (bool) -> Any
if not hasattr(self, Anchor.attrib):
return None
if any or self.anchor.always_dump:
return self.anchor
return None
def yaml_set_anchor(self, value, always_dump=False):
# type: (Any, bool) -> None
self.anchor.value = value
self.anchor.always_dump = always_dump
srsly-release-v2.5.1/srsly/ruamel_yaml/scalarfloat.py 0000775 0000000 0000000 00000010643 14742310675 0023047 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division, unicode_literals
import sys
from .compat import no_limit_int # NOQA
from .anchor import Anchor
if False: # MYPY
from typing import Text, Any, Dict, List # NOQA
__all__ = ["ScalarFloat", "ExponentialFloat", "ExponentialCapsFloat"]
class ScalarFloat(float):
def __new__(cls, *args, **kw):
# type: (Any, Any, Any) -> Any
width = kw.pop("width", None) # type: ignore
prec = kw.pop("prec", None) # type: ignore
m_sign = kw.pop("m_sign", None) # type: ignore
m_lead0 = kw.pop("m_lead0", 0) # type: ignore
exp = kw.pop("exp", None) # type: ignore
e_width = kw.pop("e_width", None) # type: ignore
e_sign = kw.pop("e_sign", None) # type: ignore
underscore = kw.pop("underscore", None) # type: ignore
anchor = kw.pop("anchor", None) # type: ignore
v = float.__new__(cls, *args, **kw) # type: ignore
v._width = width
v._prec = prec
v._m_sign = m_sign
v._m_lead0 = m_lead0
v._exp = exp
v._e_width = e_width
v._e_sign = e_sign
v._underscore = underscore
if anchor is not None:
v.yaml_set_anchor(anchor, always_dump=True)
return v
def __iadd__(self, a): # type: ignore
# type: (Any) -> Any
return float(self) + a
x = type(self)(self + a)
x._width = self._width
x._underscore = (
self._underscore[:] if self._underscore is not None else None
) # NOQA
return x
def __ifloordiv__(self, a): # type: ignore
# type: (Any) -> Any
return float(self) // a
x = type(self)(self // a)
x._width = self._width
x._underscore = (
self._underscore[:] if self._underscore is not None else None
) # NOQA
return x
def __imul__(self, a): # type: ignore
# type: (Any) -> Any
return float(self) * a
x = type(self)(self * a)
x._width = self._width
x._underscore = (
self._underscore[:] if self._underscore is not None else None
) # NOQA
x._prec = self._prec # check for others
return x
def __ipow__(self, a): # type: ignore
# type: (Any) -> Any
return float(self) ** a
x = type(self)(self ** a)
x._width = self._width
x._underscore = (
self._underscore[:] if self._underscore is not None else None
) # NOQA
return x
def __isub__(self, a): # type: ignore
# type: (Any) -> Any
return float(self) - a
x = type(self)(self - a)
x._width = self._width
x._underscore = (
self._underscore[:] if self._underscore is not None else None
) # NOQA
return x
@property
def anchor(self):
# type: () -> Any
if not hasattr(self, Anchor.attrib):
setattr(self, Anchor.attrib, Anchor())
return getattr(self, Anchor.attrib)
def yaml_anchor(self, any=False):
# type: (bool) -> Any
if not hasattr(self, Anchor.attrib):
return None
if any or self.anchor.always_dump:
return self.anchor
return None
def yaml_set_anchor(self, value, always_dump=False):
# type: (Any, bool) -> None
self.anchor.value = value
self.anchor.always_dump = always_dump
def dump(self, out=sys.stdout):
# type: (Any) -> Any
out.write(
"ScalarFloat({}| w:{}, p:{}, s:{}, lz:{}, _:{}|{}, w:{}, s:{})\n".format(
self,
self._width, # type: ignore
self._prec, # type: ignore
self._m_sign, # type: ignore
self._m_lead0, # type: ignore
self._underscore, # type: ignore
self._exp, # type: ignore
self._e_width, # type: ignore
self._e_sign, # type: ignore
)
)
class ExponentialFloat(ScalarFloat):
def __new__(cls, value, width=None, underscore=None):
# type: (Any, Any, Any) -> Any
return ScalarFloat.__new__(cls, value, width=width, underscore=underscore)
class ExponentialCapsFloat(ScalarFloat):
def __new__(cls, value, width=None, underscore=None):
# type: (Any, Any, Any) -> Any
return ScalarFloat.__new__(cls, value, width=width, underscore=underscore)
srsly-release-v2.5.1/srsly/ruamel_yaml/scalarint.py 0000775 0000000 0000000 00000011111 14742310675 0022523 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division, unicode_literals
from .compat import no_limit_int # NOQA
from .anchor import Anchor
if False: # MYPY
from typing import Text, Any, Dict, List # NOQA
__all__ = ["ScalarInt", "BinaryInt", "OctalInt", "HexInt", "HexCapsInt", "DecimalInt"]
class ScalarInt(no_limit_int):
def __new__(cls, *args, **kw):
# type: (Any, Any, Any) -> Any
width = kw.pop("width", None) # type: ignore
underscore = kw.pop("underscore", None) # type: ignore
anchor = kw.pop("anchor", None) # type: ignore
v = no_limit_int.__new__(cls, *args, **kw) # type: ignore
v._width = width
v._underscore = underscore
if anchor is not None:
v.yaml_set_anchor(anchor, always_dump=True)
return v
def __iadd__(self, a): # type: ignore
# type: (Any) -> Any
x = type(self)(self + a)
x._width = self._width # type: ignore
x._underscore = ( # type: ignore
self._underscore[:]
if self._underscore is not None
else None # type: ignore
) # NOQA
return x
def __ifloordiv__(self, a): # type: ignore
# type: (Any) -> Any
x = type(self)(self // a)
x._width = self._width # type: ignore
x._underscore = ( # type: ignore
self._underscore[:]
if self._underscore is not None
else None # type: ignore
) # NOQA
return x
def __imul__(self, a): # type: ignore
# type: (Any) -> Any
x = type(self)(self * a)
x._width = self._width # type: ignore
x._underscore = ( # type: ignore
self._underscore[:]
if self._underscore is not None
else None # type: ignore
) # NOQA
return x
def __ipow__(self, a): # type: ignore
# type: (Any) -> Any
x = type(self)(self ** a)
x._width = self._width # type: ignore
x._underscore = ( # type: ignore
self._underscore[:]
if self._underscore is not None
else None # type: ignore
) # NOQA
return x
def __isub__(self, a): # type: ignore
# type: (Any) -> Any
x = type(self)(self - a)
x._width = self._width # type: ignore
x._underscore = ( # type: ignore
self._underscore[:]
if self._underscore is not None
else None # type: ignore
) # NOQA
return x
@property
def anchor(self):
# type: () -> Any
if not hasattr(self, Anchor.attrib):
setattr(self, Anchor.attrib, Anchor())
return getattr(self, Anchor.attrib)
def yaml_anchor(self, any=False):
# type: (bool) -> Any
if not hasattr(self, Anchor.attrib):
return None
if any or self.anchor.always_dump:
return self.anchor
return None
def yaml_set_anchor(self, value, always_dump=False):
# type: (Any, bool) -> None
self.anchor.value = value
self.anchor.always_dump = always_dump
class BinaryInt(ScalarInt):
def __new__(cls, value, width=None, underscore=None, anchor=None):
# type: (Any, Any, Any, Any) -> Any
return ScalarInt.__new__(
cls, value, width=width, underscore=underscore, anchor=anchor
)
class OctalInt(ScalarInt):
def __new__(cls, value, width=None, underscore=None, anchor=None):
# type: (Any, Any, Any, Any) -> Any
return ScalarInt.__new__(
cls, value, width=width, underscore=underscore, anchor=anchor
)
# mixed casing of A-F is not supported; when loading, the first non-digit
# determines the case
class HexInt(ScalarInt):
"""uses lower case (a-f)"""
def __new__(cls, value, width=None, underscore=None, anchor=None):
# type: (Any, Any, Any, Any) -> Any
return ScalarInt.__new__(
cls, value, width=width, underscore=underscore, anchor=anchor
)
class HexCapsInt(ScalarInt):
"""uses upper case (A-F)"""
def __new__(cls, value, width=None, underscore=None, anchor=None):
# type: (Any, Any, Any, Any) -> Any
return ScalarInt.__new__(
cls, value, width=width, underscore=underscore, anchor=anchor
)
class DecimalInt(ScalarInt):
"""needed if anchor"""
def __new__(cls, value, width=None, underscore=None, anchor=None):
# type: (Any, Any, Any, Any) -> Any
return ScalarInt.__new__(
cls, value, width=width, underscore=underscore, anchor=anchor
)
srsly-release-v2.5.1/srsly/ruamel_yaml/scalarstring.py 0000775 0000000 0000000 00000010557 14742310675 0023254 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division, unicode_literals
from .compat import text_type
from .anchor import Anchor
if False: # MYPY
from typing import Text, Any, Dict, List # NOQA
__all__ = [
"ScalarString",
"LiteralScalarString",
"FoldedScalarString",
"SingleQuotedScalarString",
"DoubleQuotedScalarString",
"PlainScalarString",
# PreservedScalarString is the old name, as it was the first to be preserved on rt,
# use LiteralScalarString instead
"PreservedScalarString",
]
class ScalarString(text_type):
__slots__ = Anchor.attrib
def __new__(cls, *args, **kw):
# type: (Any, Any) -> Any
anchor = kw.pop("anchor", None) # type: ignore
ret_val = text_type.__new__(cls, *args, **kw) # type: ignore
if anchor is not None:
ret_val.yaml_set_anchor(anchor, always_dump=True)
return ret_val
def replace(self, old, new, maxreplace=-1):
# type: (Any, Any, int) -> Any
return type(self)((text_type.replace(self, old, new, maxreplace)))
@property
def anchor(self):
# type: () -> Any
if not hasattr(self, Anchor.attrib):
setattr(self, Anchor.attrib, Anchor())
return getattr(self, Anchor.attrib)
def yaml_anchor(self, any=False):
# type: (bool) -> Any
if not hasattr(self, Anchor.attrib):
return None
if any or self.anchor.always_dump:
return self.anchor
return None
def yaml_set_anchor(self, value, always_dump=False):
# type: (Any, bool) -> None
self.anchor.value = value
self.anchor.always_dump = always_dump
class LiteralScalarString(ScalarString):
__slots__ = "comment" # the comment after the | on the first line
style = "|"
def __new__(cls, value, anchor=None):
# type: (Text, Any) -> Any
return ScalarString.__new__(cls, value, anchor=anchor)
PreservedScalarString = LiteralScalarString
class FoldedScalarString(ScalarString):
__slots__ = ("fold_pos", "comment") # the comment after the > on the first line
style = ">"
def __new__(cls, value, anchor=None):
# type: (Text, Any) -> Any
return ScalarString.__new__(cls, value, anchor=anchor)
class SingleQuotedScalarString(ScalarString):
__slots__ = ()
style = "'"
def __new__(cls, value, anchor=None):
# type: (Text, Any) -> Any
return ScalarString.__new__(cls, value, anchor=anchor)
class DoubleQuotedScalarString(ScalarString):
__slots__ = ()
style = '"'
def __new__(cls, value, anchor=None):
# type: (Text, Any) -> Any
return ScalarString.__new__(cls, value, anchor=anchor)
class PlainScalarString(ScalarString):
__slots__ = ()
style = ""
def __new__(cls, value, anchor=None):
# type: (Text, Any) -> Any
return ScalarString.__new__(cls, value, anchor=anchor)
def preserve_literal(s):
# type: (Text) -> Text
return LiteralScalarString(s.replace("\r\n", "\n").replace("\r", "\n"))
def walk_tree(base, map=None):
# type: (Any, Any) -> None
"""
the routine here walks over a simple yaml tree (recursing in
dict values and list items) and converts strings that
have multiple lines to literal scalars
You can also provide an explicit (ordered) mapping for multiple transforms
    (the first matching transform is applied):
map = .compat.ordereddict
map['\n'] = preserve_literal
map[':'] = SingleQuotedScalarString
walk_tree(data, map=map)
"""
from .compat import string_types
from .compat import MutableMapping, MutableSequence # type: ignore
if map is None:
map = {"\n": preserve_literal}
if isinstance(base, MutableMapping):
for k in base:
v = base[k] # type: Text
if isinstance(v, string_types):
for ch in map:
if ch in v:
base[k] = map[ch](v)
break
else:
walk_tree(v)
elif isinstance(base, MutableSequence):
for idx, elem in enumerate(base):
if isinstance(elem, string_types):
for ch in map:
if ch in elem: # type: ignore
base[idx] = map[ch](elem)
break
else:
walk_tree(elem)
srsly-release-v2.5.1/srsly/ruamel_yaml/scanner.py 0000775 0000000 0000000 00000216066 14742310675 0022214 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division, unicode_literals
# Scanner produces tokens of the following types:
# STREAM-START
# STREAM-END
# DIRECTIVE(name, value)
# DOCUMENT-START
# DOCUMENT-END
# BLOCK-SEQUENCE-START
# BLOCK-MAPPING-START
# BLOCK-END
# FLOW-SEQUENCE-START
# FLOW-MAPPING-START
# FLOW-SEQUENCE-END
# FLOW-MAPPING-END
# BLOCK-ENTRY
# FLOW-ENTRY
# KEY
# VALUE
# ALIAS(value)
# ANCHOR(value)
# TAG(value)
# SCALAR(value, plain, style)
#
# RoundTripScanner
# COMMENT(value)
#
# Read comments in the Scanner code for more details.
#
from .error import MarkedYAMLError
from .tokens import * # NOQA
from .compat import utf8, unichr, PY3, check_anchorname_char, nprint # NOQA
if False: # MYPY
from typing import Any, Dict, Optional, List, Union, Text # NOQA
from .compat import VersionType # NOQA
__all__ = ["Scanner", "RoundTripScanner", "ScannerError"]
_THE_END = "\n\0\r\x85\u2028\u2029"
_THE_END_SPACE_TAB = " \n\0\t\r\x85\u2028\u2029"
_SPACE_TAB = " \t"
class ScannerError(MarkedYAMLError):
pass
class SimpleKey(object):
# See below simple keys treatment.
def __init__(self, token_number, required, index, line, column, mark):
# type: (Any, Any, int, int, int, Any) -> None
self.token_number = token_number
self.required = required
self.index = index
self.line = line
self.column = column
self.mark = mark
class Scanner(object):
def __init__(self, loader=None):
# type: (Any) -> None
"""Initialize the scanner."""
# It is assumed that Scanner and Reader will have a common descendant.
        # Reader does the dirty work of checking for BOM and converting the
# input data to Unicode. It also adds NUL to the end.
#
# Reader supports the following methods
# self.peek(i=0) # peek the next i-th character
# self.prefix(l=1) # peek the next l characters
# self.forward(l=1) # read the next l characters and move the pointer
self.loader = loader
if self.loader is not None and getattr(self.loader, "_scanner", None) is None:
self.loader._scanner = self
self.reset_scanner()
self.first_time = False
self.yaml_version = None # type: Any
@property
def flow_level(self):
# type: () -> int
return len(self.flow_context)
def reset_scanner(self):
# type: () -> None
# Had we reached the end of the stream?
self.done = False
# flow_context is an expanding/shrinking list consisting of '{' and '['
# for each unclosed flow context. If empty list that means block context
self.flow_context = [] # type: List[Text]
# List of processed tokens that are not yet emitted.
self.tokens = [] # type: List[Any]
# Add the STREAM-START token.
self.fetch_stream_start()
# Number of tokens that were emitted through the `get_token` method.
self.tokens_taken = 0
# The current indentation level.
self.indent = -1
# Past indentation levels.
self.indents = [] # type: List[int]
# Variables related to simple keys treatment.
# A simple key is a key that is not denoted by the '?' indicator.
# Example of simple keys:
# ---
# block simple key: value
# ? not a simple key:
# : { flow simple key: value }
# We emit the KEY token before all keys, so when we find a potential
# simple key, we try to locate the corresponding ':' indicator.
# Simple keys should be limited to a single line and 1024 characters.
# Can a simple key start at the current position? A simple key may
# start:
# - at the beginning of the line, not counting indentation spaces
# (in block context),
# - after '{', '[', ',' (in the flow context),
# - after '?', ':', '-' (in the block context).
# In the block context, this flag also signifies if a block collection
# may start at the current position.
self.allow_simple_key = True
# Keep track of possible simple keys. This is a dictionary. The key
# is `flow_level`; there can be no more than one possible simple key
# for each level. The value is a SimpleKey record:
# (token_number, required, index, line, column, mark)
# A simple key may start with ALIAS, ANCHOR, TAG, SCALAR(flow),
# '[', or '{' tokens.
self.possible_simple_keys = {} # type: Dict[Any, Any]
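# For example, in
#   foo: 1        # 'foo' is a simple key; its KEY token is inserted
#                 # retroactively once the ':' is seen
#   ? bar         # 'bar' is an explicit (non-simple) key
#   : 2
# only 'foo' goes through the possible_simple_keys bookkeeping above.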
@property
def reader(self):
# type: () -> Any
try:
return self._scanner_reader # type: ignore
except AttributeError:
if hasattr(self.loader, "typ"):
self._scanner_reader = self.loader.reader
else:
self._scanner_reader = self.loader._reader
return self._scanner_reader
@property
def scanner_processing_version(self): # prefix until un-composited
# type: () -> Any
if hasattr(self.loader, "typ"):
return self.loader.resolver.processing_version
return self.loader.processing_version
# Public methods.
def check_token(self, *choices):
# type: (Any) -> bool
# Check if the next token is one of the given types.
while self.need_more_tokens():
self.fetch_more_tokens()
if bool(self.tokens):
if not choices:
return True
for choice in choices:
if isinstance(self.tokens[0], choice):
return True
return False
def peek_token(self):
# type: () -> Any
# Return the next token, but do not delete it from the queue.
while self.need_more_tokens():
self.fetch_more_tokens()
if bool(self.tokens):
return self.tokens[0]
def get_token(self):
# type: () -> Any
# Return the next token.
while self.need_more_tokens():
self.fetch_more_tokens()
if bool(self.tokens):
self.tokens_taken += 1
return self.tokens.pop(0)
# Private methods.
def need_more_tokens(self):
# type: () -> bool
if self.done:
return False
if not self.tokens:
return True
# The current token may be a potential simple key, so we
# need to look further.
self.stale_possible_simple_keys()
if self.next_possible_simple_key() == self.tokens_taken:
return True
return False
def fetch_comment(self, comment):
# type: (Any) -> None
raise NotImplementedError
def fetch_more_tokens(self):
# type: () -> Any
# Eat whitespaces and comments until we reach the next token.
comment = self.scan_to_next_token()
if comment is not None: # never happens for base scanner
return self.fetch_comment(comment)
# Remove obsolete possible simple keys.
self.stale_possible_simple_keys()
# Compare the current indentation and column. It may add some tokens
# and decrease the current indentation level.
self.unwind_indent(self.reader.column)
# Peek the next character.
ch = self.reader.peek()
# Is it the end of stream?
if ch == "\0":
return self.fetch_stream_end()
# Is it a directive?
if ch == "%" and self.check_directive():
return self.fetch_directive()
# Is it the document start?
if ch == "-" and self.check_document_start():
return self.fetch_document_start()
# Is it the document end?
if ch == "." and self.check_document_end():
return self.fetch_document_end()
# TODO: support for BOM within a stream.
# if ch == u'\uFEFF':
# return self.fetch_bom() <-- issue BOMToken
# Note: the order of the following checks is NOT significant.
# Is it the flow sequence start indicator?
if ch == "[":
return self.fetch_flow_sequence_start()
# Is it the flow mapping start indicator?
if ch == "{":
return self.fetch_flow_mapping_start()
# Is it the flow sequence end indicator?
if ch == "]":
return self.fetch_flow_sequence_end()
# Is it the flow mapping end indicator?
if ch == "}":
return self.fetch_flow_mapping_end()
# Is it the flow entry indicator?
if ch == ",":
return self.fetch_flow_entry()
# Is it the block entry indicator?
if ch == "-" and self.check_block_entry():
return self.fetch_block_entry()
# Is it the key indicator?
if ch == "?" and self.check_key():
return self.fetch_key()
# Is it the value indicator?
if ch == ":" and self.check_value():
return self.fetch_value()
# Is it an alias?
if ch == "*":
return self.fetch_alias()
# Is it an anchor?
if ch == "&":
return self.fetch_anchor()
# Is it a tag?
if ch == "!":
return self.fetch_tag()
# Is it a literal scalar?
if ch == "|" and not self.flow_level:
return self.fetch_literal()
# Is it a folded scalar?
if ch == ">" and not self.flow_level:
return self.fetch_folded()
# Is it a single quoted scalar?
if ch == "'":
return self.fetch_single()
# Is it a double quoted scalar?
if ch == '"':
return self.fetch_double()
# It must be a plain scalar then.
if self.check_plain():
return self.fetch_plain()
# No? It's an error. Let's produce a nice error message.
raise ScannerError(
"while scanning for the next token",
None,
"found character %r that cannot start any token" % utf8(ch),
self.reader.get_mark(),
)
# Simple keys treatment.
def next_possible_simple_key(self):
# type: () -> Any
# Return the number of the nearest possible simple key. Actually we
# don't need to loop through the whole dictionary. We may replace it
# with the following code:
# if not self.possible_simple_keys:
# return None
# return self.possible_simple_keys[
# min(self.possible_simple_keys.keys())].token_number
min_token_number = None
for level in self.possible_simple_keys:
key = self.possible_simple_keys[level]
if min_token_number is None or key.token_number < min_token_number:
min_token_number = key.token_number
return min_token_number
def stale_possible_simple_keys(self):
# type: () -> None
# Remove entries that are no longer possible simple keys. According to
# the YAML specification, simple keys
# - should be limited to a single line,
# - should be no longer than 1024 characters.
# Disabling this procedure will allow simple keys of any length and
# height (may cause problems if indentation is broken though).
for level in list(self.possible_simple_keys):
key = self.possible_simple_keys[level]
if key.line != self.reader.line or self.reader.index - key.index > 1024:
if key.required:
raise ScannerError(
"while scanning a simple key",
key.mark,
"could not find expected ':'",
self.reader.get_mark(),
)
del self.possible_simple_keys[level]
def save_possible_simple_key(self):
# type: () -> None
# The next token may start a simple key. We check if it's possible
# and save its position. This function is called for
# ALIAS, ANCHOR, TAG, SCALAR(flow), '[', and '{'.
# Check if a simple key is required at the current position.
required = not self.flow_level and self.indent == self.reader.column
# The next token might be a simple key. Let's save its number and
# position.
if self.allow_simple_key:
self.remove_possible_simple_key()
token_number = self.tokens_taken + len(self.tokens)
key = SimpleKey(
token_number,
required,
self.reader.index,
self.reader.line,
self.reader.column,
self.reader.get_mark(),
)
self.possible_simple_keys[self.flow_level] = key
def remove_possible_simple_key(self):
# type: () -> None
# Remove the saved possible key position at the current flow level.
if self.flow_level in self.possible_simple_keys:
key = self.possible_simple_keys[self.flow_level]
if key.required:
raise ScannerError(
"while scanning a simple key",
key.mark,
"could not find expected ':'",
self.reader.get_mark(),
)
del self.possible_simple_keys[self.flow_level]
# Indentation functions.
def unwind_indent(self, column):
# type: (Any) -> None
# In flow context, tokens should respect indentation.
# Actually the condition should be `self.indent >= column` according to
# the spec. But this condition will prohibit intuitively correct
# constructions such as
# key : {
# }
# ####
# if self.flow_level and self.indent > column:
# raise ScannerError(None, None,
# "invalid intendation or unclosed '[' or '{'",
# self.reader.get_mark())
# In the flow context, indentation is ignored. We make the scanner less
# restrictive than the specification requires.
if bool(self.flow_level):
return
# In block context, we may need to issue the BLOCK-END tokens.
while self.indent > column:
mark = self.reader.get_mark()
self.indent = self.indents.pop()
self.tokens.append(BlockEndToken(mark, mark))
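# Sketch of the indent bookkeeping (illustrative only): while scanning
#   a:
#     b: 1
#   c: 2
# self.indent moves from -1 to 0 (outer mapping) to 2 (inner mapping). When
# the scanner reaches 'c' at column 0, unwind_indent(0) pops back to 0 and
# emits one BLOCK-END for the inner mapping; the outer mapping's BLOCK-END
# is emitted later by unwind_indent(-1) at stream end.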
def add_indent(self, column):
# type: (int) -> bool
# Check if we need to increase indentation.
if self.indent < column:
self.indents.append(self.indent)
self.indent = column
return True
return False
# Fetchers.
def fetch_stream_start(self):
# type: () -> None
# We always add STREAM-START as the first token and STREAM-END as the
# last token.
# Read the token.
mark = self.reader.get_mark()
# Add STREAM-START.
self.tokens.append(StreamStartToken(mark, mark, encoding=self.reader.encoding))
def fetch_stream_end(self):
# type: () -> None
# Set the current indentation to -1.
self.unwind_indent(-1)
# Reset simple keys.
self.remove_possible_simple_key()
self.allow_simple_key = False
self.possible_simple_keys = {}
# Read the token.
mark = self.reader.get_mark()
# Add STREAM-END.
self.tokens.append(StreamEndToken(mark, mark))
# The stream is finished.
self.done = True
def fetch_directive(self):
# type: () -> None
# Set the current indentation to -1.
self.unwind_indent(-1)
# Reset simple keys.
self.remove_possible_simple_key()
self.allow_simple_key = False
# Scan and add DIRECTIVE.
self.tokens.append(self.scan_directive())
def fetch_document_start(self):
# type: () -> None
self.fetch_document_indicator(DocumentStartToken)
def fetch_document_end(self):
# type: () -> None
self.fetch_document_indicator(DocumentEndToken)
def fetch_document_indicator(self, TokenClass):
# type: (Any) -> None
# Set the current indentation to -1.
self.unwind_indent(-1)
# Reset simple keys. Note that there could not be a block collection
# after '---'.
self.remove_possible_simple_key()
self.allow_simple_key = False
# Add DOCUMENT-START or DOCUMENT-END.
start_mark = self.reader.get_mark()
self.reader.forward(3)
end_mark = self.reader.get_mark()
self.tokens.append(TokenClass(start_mark, end_mark))
def fetch_flow_sequence_start(self):
# type: () -> None
self.fetch_flow_collection_start(FlowSequenceStartToken, to_push="[")
def fetch_flow_mapping_start(self):
# type: () -> None
self.fetch_flow_collection_start(FlowMappingStartToken, to_push="{")
def fetch_flow_collection_start(self, TokenClass, to_push):
# type: (Any, Text) -> None
# '[' and '{' may start a simple key.
self.save_possible_simple_key()
# Increase the flow level.
self.flow_context.append(to_push)
# Simple keys are allowed after '[' and '{'.
self.allow_simple_key = True
# Add FLOW-SEQUENCE-START or FLOW-MAPPING-START.
start_mark = self.reader.get_mark()
self.reader.forward()
end_mark = self.reader.get_mark()
self.tokens.append(TokenClass(start_mark, end_mark))
def fetch_flow_sequence_end(self):
# type: () -> None
self.fetch_flow_collection_end(FlowSequenceEndToken)
def fetch_flow_mapping_end(self):
# type: () -> None
self.fetch_flow_collection_end(FlowMappingEndToken)
def fetch_flow_collection_end(self, TokenClass):
# type: (Any) -> None
# Reset possible simple key on the current level.
self.remove_possible_simple_key()
# Decrease the flow level.
try:
popped = self.flow_context.pop() # NOQA
except IndexError:
# We must not be in a list or object.
# Defer error handling to the parser.
pass
# No simple keys after ']' or '}'.
self.allow_simple_key = False
# Add FLOW-SEQUENCE-END or FLOW-MAPPING-END.
start_mark = self.reader.get_mark()
self.reader.forward()
end_mark = self.reader.get_mark()
self.tokens.append(TokenClass(start_mark, end_mark))
def fetch_flow_entry(self):
# type: () -> None
# Simple keys are allowed after ','.
self.allow_simple_key = True
# Reset possible simple key on the current level.
self.remove_possible_simple_key()
# Add FLOW-ENTRY.
start_mark = self.reader.get_mark()
self.reader.forward()
end_mark = self.reader.get_mark()
self.tokens.append(FlowEntryToken(start_mark, end_mark))
def fetch_block_entry(self):
# type: () -> None
# Block context needs additional checks.
if not self.flow_level:
# Are we allowed to start a new entry?
if not self.allow_simple_key:
raise ScannerError(
None,
None,
"sequence entries are not allowed here",
self.reader.get_mark(),
)
# We may need to add BLOCK-SEQUENCE-START.
if self.add_indent(self.reader.column):
mark = self.reader.get_mark()
self.tokens.append(BlockSequenceStartToken(mark, mark))
# It's an error for the block entry to occur in the flow context,
# but we let the parser detect this.
else:
pass
# Simple keys are allowed after '-'.
self.allow_simple_key = True
# Reset possible simple key on the current level.
self.remove_possible_simple_key()
# Add BLOCK-ENTRY.
start_mark = self.reader.get_mark()
self.reader.forward()
end_mark = self.reader.get_mark()
self.tokens.append(BlockEntryToken(start_mark, end_mark))
def fetch_key(self):
# type: () -> None
# Block context needs additional checks.
if not self.flow_level:
# Are we allowed to start a key (not necessarily a simple one)?
if not self.allow_simple_key:
raise ScannerError(
None,
None,
"mapping keys are not allowed here",
self.reader.get_mark(),
)
# We may need to add BLOCK-MAPPING-START.
if self.add_indent(self.reader.column):
mark = self.reader.get_mark()
self.tokens.append(BlockMappingStartToken(mark, mark))
# Simple keys are allowed after '?' in the block context.
self.allow_simple_key = not self.flow_level
# Reset possible simple key on the current level.
self.remove_possible_simple_key()
# Add KEY.
start_mark = self.reader.get_mark()
self.reader.forward()
end_mark = self.reader.get_mark()
self.tokens.append(KeyToken(start_mark, end_mark))
def fetch_value(self):
# type: () -> None
# Do we determine a simple key?
if self.flow_level in self.possible_simple_keys:
# Add KEY.
key = self.possible_simple_keys[self.flow_level]
del self.possible_simple_keys[self.flow_level]
self.tokens.insert(
key.token_number - self.tokens_taken, KeyToken(key.mark, key.mark)
)
# If this key starts a new block mapping, we need to add
# BLOCK-MAPPING-START.
if not self.flow_level:
if self.add_indent(key.column):
self.tokens.insert(
key.token_number - self.tokens_taken,
BlockMappingStartToken(key.mark, key.mark),
)
# There cannot be two simple keys one after another.
self.allow_simple_key = False
# It must be a part of a complex key.
else:
# Block context needs additional checks.
# (Do we really need them? They will be caught by the parser
# anyway.)
if not self.flow_level:
# We are allowed to start a complex value if and only if
# we can start a simple key.
if not self.allow_simple_key:
raise ScannerError(
None,
None,
"mapping values are not allowed here",
self.reader.get_mark(),
)
# If this value starts a new block mapping, we need to add
# BLOCK-MAPPING-START. It will be detected as an error later by
# the parser.
if not self.flow_level:
if self.add_indent(self.reader.column):
mark = self.reader.get_mark()
self.tokens.append(BlockMappingStartToken(mark, mark))
# Simple keys are allowed after ':' in the block context.
self.allow_simple_key = not self.flow_level
# Reset possible simple key on the current level.
self.remove_possible_simple_key()
# Add VALUE.
start_mark = self.reader.get_mark()
self.reader.forward()
end_mark = self.reader.get_mark()
self.tokens.append(ValueToken(start_mark, end_mark))
def fetch_alias(self):
# type: () -> None
# ALIAS could be a simple key.
self.save_possible_simple_key()
# No simple keys after ALIAS.
self.allow_simple_key = False
# Scan and add ALIAS.
self.tokens.append(self.scan_anchor(AliasToken))
def fetch_anchor(self):
# type: () -> None
# ANCHOR could start a simple key.
self.save_possible_simple_key()
# No simple keys after ANCHOR.
self.allow_simple_key = False
# Scan and add ANCHOR.
self.tokens.append(self.scan_anchor(AnchorToken))
def fetch_tag(self):
# type: () -> None
# TAG could start a simple key.
self.save_possible_simple_key()
# No simple keys after TAG.
self.allow_simple_key = False
# Scan and add TAG.
self.tokens.append(self.scan_tag())
def fetch_literal(self):
# type: () -> None
self.fetch_block_scalar(style="|")
def fetch_folded(self):
# type: () -> None
self.fetch_block_scalar(style=">")
def fetch_block_scalar(self, style):
# type: (Any) -> None
# A simple key may follow a block scalar.
self.allow_simple_key = True
# Reset possible simple key on the current level.
self.remove_possible_simple_key()
# Scan and add SCALAR.
self.tokens.append(self.scan_block_scalar(style))
def fetch_single(self):
# type: () -> None
self.fetch_flow_scalar(style="'")
def fetch_double(self):
# type: () -> None
self.fetch_flow_scalar(style='"')
def fetch_flow_scalar(self, style):
# type: (Any) -> None
# A flow scalar could be a simple key.
self.save_possible_simple_key()
# No simple keys after flow scalars.
self.allow_simple_key = False
# Scan and add SCALAR.
self.tokens.append(self.scan_flow_scalar(style))
def fetch_plain(self):
# type: () -> None
# A plain scalar could be a simple key.
self.save_possible_simple_key()
# No simple keys after plain scalars. But note that `scan_plain` will
# change this flag if the scan is finished at the beginning of the
# line.
self.allow_simple_key = False
# Scan and add SCALAR. May change `allow_simple_key`.
self.tokens.append(self.scan_plain())
# Checkers.
def check_directive(self):
# type: () -> Any
# DIRECTIVE: ^ '%' ...
# The '%' indicator is already checked.
if self.reader.column == 0:
return True
return None
def check_document_start(self):
# type: () -> Any
# DOCUMENT-START: ^ '---' (' '|'\n')
if self.reader.column == 0:
if (
self.reader.prefix(3) == "---"
and self.reader.peek(3) in _THE_END_SPACE_TAB
):
return True
return None
def check_document_end(self):
# type: () -> Any
# DOCUMENT-END: ^ '...' (' '|'\n')
if self.reader.column == 0:
if (
self.reader.prefix(3) == "..."
and self.reader.peek(3) in _THE_END_SPACE_TAB
):
return True
return None
def check_block_entry(self):
# type: () -> Any
# BLOCK-ENTRY: '-' (' '|'\n')
return self.reader.peek(1) in _THE_END_SPACE_TAB
def check_key(self):
# type: () -> Any
# KEY(flow context): '?'
if bool(self.flow_level):
return True
# KEY(block context): '?' (' '|'\n')
return self.reader.peek(1) in _THE_END_SPACE_TAB
def check_value(self):
# type: () -> Any
# VALUE(flow context): ':'
if self.scanner_processing_version == (1, 1):
if bool(self.flow_level):
return True
else:
if bool(self.flow_level):
if self.flow_context[-1] == "[":
if self.reader.peek(1) not in _THE_END_SPACE_TAB:
return False
elif self.tokens and isinstance(self.tokens[-1], ValueToken):
# mapping flow context scanning a value token
if self.reader.peek(1) not in _THE_END_SPACE_TAB:
return False
return True
# VALUE(block context): ':' (' '|'\n')
return self.reader.peek(1) in _THE_END_SPACE_TAB
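# Example of the YAML 1.2 refinement above (illustrative only): in the flow
# sequence
#   ["a":1]
# the ':' right after the quoted scalar is not followed by a space, so it
# does not open a VALUE token here; under YAML 1.1 rules any ':' inside a
# flow collection is treated as a value indicator.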
def check_plain(self):
# type: () -> Any
# A plain scalar may start with any non-space character except:
# '-', '?', ':', ',', '[', ']', '{', '}',
# '#', '&', '*', '!', '|', '>', '\'', '\"',
# '%', '@', '`'.
#
# It may also start with
# '-', '?', ':'
# if it is followed by a non-space character.
#
# Note that we limit the last rule to the block context (except the
# '-' character) because we want the flow context to be space
# independent.
srp = self.reader.peek
ch = srp()
if self.scanner_processing_version == (1, 1):
return ch not in "\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>'\"%@`" or (
srp(1) not in _THE_END_SPACE_TAB
and (ch == "-" or (not self.flow_level and ch in "?:"))
)
# YAML 1.2
if ch not in "\0 \t\r\n\x85\u2028\u2029-?:,[]{}#&*!|>'\"%@`":
# ################### ^ ???
return True
ch1 = srp(1)
if ch == "-" and ch1 not in _THE_END_SPACE_TAB:
return True
if ch == ":" and bool(self.flow_level) and ch1 not in _SPACE_TAB:
return True
return srp(1) not in _THE_END_SPACE_TAB and (
ch == "-" or (not self.flow_level and ch in "?:")
)
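# For instance, '-foo' (and, in block context, '?foo') can start a plain
# scalar because the indicator is followed by a non-space character, while
# '- foo' starts a block sequence entry instead.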
# Scanners.
def scan_to_next_token(self):
# type: () -> Any
# We ignore spaces, line breaks and comments.
# If we find a line break in the block context, we set the flag
# `allow_simple_key` on.
# The byte order mark is stripped if it's the first character in the
# stream. We do not yet support BOM inside the stream as the
# specification requires. Any such mark will be considered as a part
# of the document.
#
# TODO: We need to make tab handling rules more sane. A good rule is
# Tabs cannot precede tokens
# BLOCK-SEQUENCE-START, BLOCK-MAPPING-START, BLOCK-END,
# KEY(block), VALUE(block), BLOCK-ENTRY
# So the checking code is
# if <TAB>:
# self.allow_simple_keys = False
# We also need to add the check for `allow_simple_keys == True` to
# `unwind_indent` before issuing BLOCK-END.
# Scanners for block, flow, and plain scalars need to be modified.
srp = self.reader.peek
srf = self.reader.forward
if self.reader.index == 0 and srp() == "\uFEFF":
srf()
found = False
_the_end = _THE_END
while not found:
while srp() == " ":
srf()
if srp() == "#":
while srp() not in _the_end:
srf()
if self.scan_line_break():
if not self.flow_level:
self.allow_simple_key = True
else:
found = True
return None
def scan_directive(self):
# type: () -> Any
# See the specification for details.
srp = self.reader.peek
srf = self.reader.forward
start_mark = self.reader.get_mark()
srf()
name = self.scan_directive_name(start_mark)
value = None
if name == "YAML":
value = self.scan_yaml_directive_value(start_mark)
end_mark = self.reader.get_mark()
elif name == "TAG":
value = self.scan_tag_directive_value(start_mark)
end_mark = self.reader.get_mark()
else:
end_mark = self.reader.get_mark()
while srp() not in _THE_END:
srf()
self.scan_directive_ignored_line(start_mark)
return DirectiveToken(name, value, start_mark, end_mark)
def scan_directive_name(self, start_mark):
# type: (Any) -> Any
# See the specification for details.
length = 0
srp = self.reader.peek
ch = srp(length)
while "0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_:.":
length += 1
ch = srp(length)
if not length:
raise ScannerError(
"while scanning a directive",
start_mark,
"expected alphabetic or numeric character, but found %r" % utf8(ch),
self.reader.get_mark(),
)
value = self.reader.prefix(length)
self.reader.forward(length)
ch = srp()
if ch not in "\0 \r\n\x85\u2028\u2029":
raise ScannerError(
"while scanning a directive",
start_mark,
"expected alphabetic or numeric character, but found %r" % utf8(ch),
self.reader.get_mark(),
)
return value
def scan_yaml_directive_value(self, start_mark):
# type: (Any) -> Any
# See the specification for details.
srp = self.reader.peek
srf = self.reader.forward
while srp() == " ":
srf()
major = self.scan_yaml_directive_number(start_mark)
if srp() != ".":
raise ScannerError(
"while scanning a directive",
start_mark,
"expected a digit or '.', but found %r" % utf8(srp()),
self.reader.get_mark(),
)
srf()
minor = self.scan_yaml_directive_number(start_mark)
if srp() not in "\0 \r\n\x85\u2028\u2029":
raise ScannerError(
"while scanning a directive",
start_mark,
"expected a digit or ' ', but found %r" % utf8(srp()),
self.reader.get_mark(),
)
self.yaml_version = (major, minor)
return self.yaml_version
def scan_yaml_directive_number(self, start_mark):
# type: (Any) -> Any
# See the specification for details.
srp = self.reader.peek
srf = self.reader.forward
ch = srp()
if not ("0" <= ch <= "9"):
raise ScannerError(
"while scanning a directive",
start_mark,
"expected a digit, but found %r" % utf8(ch),
self.reader.get_mark(),
)
length = 0
while "0" <= srp(length) <= "9":
length += 1
value = int(self.reader.prefix(length))
srf(length)
return value
def scan_tag_directive_value(self, start_mark):
# type: (Any) -> Any
# See the specification for details.
srp = self.reader.peek
srf = self.reader.forward
while srp() == " ":
srf()
handle = self.scan_tag_directive_handle(start_mark)
while srp() == " ":
srf()
prefix = self.scan_tag_directive_prefix(start_mark)
return (handle, prefix)
def scan_tag_directive_handle(self, start_mark):
# type: (Any) -> Any
# See the specification for details.
value = self.scan_tag_handle("directive", start_mark)
ch = self.reader.peek()
if ch != " ":
raise ScannerError(
"while scanning a directive",
start_mark,
"expected ' ', but found %r" % utf8(ch),
self.reader.get_mark(),
)
return value
def scan_tag_directive_prefix(self, start_mark):
# type: (Any) -> Any
# See the specification for details.
value = self.scan_tag_uri("directive", start_mark)
ch = self.reader.peek()
if ch not in "\0 \r\n\x85\u2028\u2029":
raise ScannerError(
"while scanning a directive",
start_mark,
"expected ' ', but found %r" % utf8(ch),
self.reader.get_mark(),
)
return value
def scan_directive_ignored_line(self, start_mark):
# type: (Any) -> None
# See the specification for details.
srp = self.reader.peek
srf = self.reader.forward
while srp() == " ":
srf()
if srp() == "#":
while srp() not in _THE_END:
srf()
ch = srp()
if ch not in _THE_END:
raise ScannerError(
"while scanning a directive",
start_mark,
"expected a comment or a line break, but found %r" % utf8(ch),
self.reader.get_mark(),
)
self.scan_line_break()
def scan_anchor(self, TokenClass):
# type: (Any) -> Any
# The specification does not restrict characters for anchors and
# aliases. This may lead to problems, for instance, the document:
# [ *alias, value ]
# can be interpreted in two ways, as
# [ "value" ]
# and
# [ *alias , "value" ]
# Therefore we restrict aliases to numbers and ASCII letters.
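# e.g. '&a1 value' and '*a1' are accepted; in '[ *a, b ]' the ',' (roughly, a
# flow indicator) ends the alias name via check_anchorname_char() below, so
# the sequence unambiguously has two entries.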
srp = self.reader.peek
start_mark = self.reader.get_mark()
indicator = srp()
if indicator == "*":
name = "alias"
else:
name = "anchor"
self.reader.forward()
length = 0
ch = srp(length)
# while u'0' <= ch <= u'9' or u'A' <= ch <= u'Z' or u'a' <= ch <= u'z' \
# or ch in u'-_':
while check_anchorname_char(ch):
length += 1
ch = srp(length)
if not length:
raise ScannerError(
"while scanning an %s" % (name,),
start_mark,
"expected alphabetic or numeric character, but found %r" % utf8(ch),
self.reader.get_mark(),
)
value = self.reader.prefix(length)
self.reader.forward(length)
# ch1 = ch
# ch = srp() # no need to peek, ch is already set
# assert ch1 == ch
if ch not in "\0 \t\r\n\x85\u2028\u2029?:,[]{}%@`":
raise ScannerError(
"while scanning an %s" % (name,),
start_mark,
"expected alphabetic or numeric character, but found %r" % utf8(ch),
self.reader.get_mark(),
)
end_mark = self.reader.get_mark()
return TokenClass(value, start_mark, end_mark)
def scan_tag(self):
# type: () -> Any
# See the specification for details.
srp = self.reader.peek
start_mark = self.reader.get_mark()
ch = srp(1)
if ch == "<":
handle = None
self.reader.forward(2)
suffix = self.scan_tag_uri("tag", start_mark)
if srp() != ">":
raise ScannerError(
"while parsing a tag",
start_mark,
"expected '>', but found %r" % utf8(srp()),
self.reader.get_mark(),
)
self.reader.forward()
elif ch in _THE_END_SPACE_TAB:
handle = None
suffix = "!"
self.reader.forward()
else:
length = 1
use_handle = False
while ch not in "\0 \r\n\x85\u2028\u2029":
if ch == "!":
use_handle = True
break
length += 1
ch = srp(length)
handle = "!"
if use_handle:
handle = self.scan_tag_handle("tag", start_mark)
else:
handle = "!"
self.reader.forward()
suffix = self.scan_tag_uri("tag", start_mark)
ch = srp()
if ch not in "\0 \r\n\x85\u2028\u2029":
raise ScannerError(
"while scanning a tag",
start_mark,
"expected ' ', but found %r" % utf8(ch),
self.reader.get_mark(),
)
value = (handle, suffix)
end_mark = self.reader.get_mark()
return TagToken(value, start_mark, end_mark)
def scan_block_scalar(self, style, rt=False):
# type: (Any, Optional[bool]) -> Any
# See the specification for details.
srp = self.reader.peek
if style == ">":
folded = True
else:
folded = False
chunks = [] # type: List[Any]
start_mark = self.reader.get_mark()
# Scan the header.
self.reader.forward()
chomping, increment = self.scan_block_scalar_indicators(start_mark)
# block scalar comment e.g. : |+ # comment text
block_scalar_comment = self.scan_block_scalar_ignored_line(start_mark)
# Determine the indentation level and go to the first non-empty line.
min_indent = self.indent + 1
if increment is None:
# no increment and top level, min_indent could be 0
if min_indent < 1 and (
style not in "|>"
or (self.scanner_processing_version == (1, 1))
and getattr(
self.loader,
"top_level_block_style_scalar_no_indent_error_1_1",
False,
)
):
min_indent = 1
breaks, max_indent, end_mark = self.scan_block_scalar_indentation()
indent = max(min_indent, max_indent)
else:
if min_indent < 1:
min_indent = 1
indent = min_indent + increment - 1
breaks, end_mark = self.scan_block_scalar_breaks(indent)
line_break = ""
# Scan the inner part of the block scalar.
while self.reader.column == indent and srp() != "\0":
chunks.extend(breaks)
leading_non_space = srp() not in " \t"
length = 0
while srp(length) not in _THE_END:
length += 1
chunks.append(self.reader.prefix(length))
self.reader.forward(length)
line_break = self.scan_line_break()
breaks, end_mark = self.scan_block_scalar_breaks(indent)
if style in "|>" and min_indent == 0:
# at the beginning of a line, if in block style see if
# end of document/start_new_document
if self.check_document_start() or self.check_document_end():
break
if self.reader.column == indent and srp() != "\0":
# Unfortunately, folding rules are ambiguous.
#
# This is the folding according to the specification:
if rt and folded and line_break == "\n":
chunks.append("\a")
if (
folded
and line_break == "\n"
and leading_non_space
and srp() not in " \t"
):
if not breaks:
chunks.append(" ")
else:
chunks.append(line_break)
# This is Clark Evans's interpretation (also in the spec
# examples):
#
# if folded and line_break == u'\n':
# if not breaks:
# if srp() not in ' \t':
# chunks.append(u' ')
# else:
# chunks.append(line_break)
# else:
# chunks.append(line_break)
else:
break
# Process trailing line breaks. The 'chomping' setting determines
# whether they are included in the value.
trailing = [] # type: List[Any]
if chomping in [None, True]:
chunks.append(line_break)
if chomping is True:
chunks.extend(breaks)
elif chomping in [None, False]:
trailing.extend(breaks)
# We are done.
token = ScalarToken("".join(chunks), False, start_mark, end_mark, style)
if block_scalar_comment is not None:
token.add_pre_comments([block_scalar_comment])
if len(trailing) > 0:
# nprint('trailing 1', trailing) # XXXXX
# Eat whitespaces and comments until we reach the next token.
comment = self.scan_to_next_token()
while comment:
trailing.append(" " * comment[1].column + comment[0])
comment = self.scan_to_next_token()
# Keep track of the trailing whitespace and following comments
# as a comment token, if it isn't all included in the actual value.
comment_end_mark = self.reader.get_mark()
comment = CommentToken("".join(trailing), end_mark, comment_end_mark)
token.add_post_comment(comment)
return token
def scan_block_scalar_indicators(self, start_mark):
# type: (Any) -> Any
# See the specification for details.
srp = self.reader.peek
chomping = None
increment = None
ch = srp()
if ch in "+-":
if ch == "+":
chomping = True
else:
chomping = False
self.reader.forward()
ch = srp()
if ch in "0123456789":
increment = int(ch)
if increment == 0:
raise ScannerError(
"while scanning a block scalar",
start_mark,
"expected indentation indicator in the range 1-9, "
"but found 0",
self.reader.get_mark(),
)
self.reader.forward()
elif ch in "0123456789":
increment = int(ch)
if increment == 0:
raise ScannerError(
"while scanning a block scalar",
start_mark,
"expected indentation indicator in the range 1-9, " "but found 0",
self.reader.get_mark(),
)
self.reader.forward()
ch = srp()
if ch in "+-":
if ch == "+":
chomping = True
else:
chomping = False
self.reader.forward()
ch = srp()
if ch not in "\0 \r\n\x85\u2028\u2029":
raise ScannerError(
"while scanning a block scalar",
start_mark,
"expected chomping or indentation indicators, but found %r" % utf8(ch),
self.reader.get_mark(),
)
return chomping, increment
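# Header examples (illustrative only), as returned by this method:
#   '|'   -> (None, None)   clip chomping, indentation auto-detected
#   '|+2' -> (True, 2)      keep trailing newlines, indentation indicator 2
#   '>-'  -> (False, None)  folded style, strip trailing newlines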
def scan_block_scalar_ignored_line(self, start_mark):
# type: (Any) -> Any
# See the specification for details.
srp = self.reader.peek
srf = self.reader.forward
prefix = ""
comment = None
while srp() == " ":
prefix += srp()
srf()
if srp() == "#":
comment = prefix
while srp() not in _THE_END:
comment += srp()
srf()
ch = srp()
if ch not in _THE_END:
raise ScannerError(
"while scanning a block scalar",
start_mark,
"expected a comment or a line break, but found %r" % utf8(ch),
self.reader.get_mark(),
)
self.scan_line_break()
return comment
def scan_block_scalar_indentation(self):
# type: () -> Any
# See the specification for details.
srp = self.reader.peek
srf = self.reader.forward
chunks = []
max_indent = 0
end_mark = self.reader.get_mark()
while srp() in " \r\n\x85\u2028\u2029":
if srp() != " ":
chunks.append(self.scan_line_break())
end_mark = self.reader.get_mark()
else:
srf()
if self.reader.column > max_indent:
max_indent = self.reader.column
return chunks, max_indent, end_mark
def scan_block_scalar_breaks(self, indent):
# type: (int) -> Any
# See the specification for details.
chunks = []
srp = self.reader.peek
srf = self.reader.forward
end_mark = self.reader.get_mark()
while self.reader.column < indent and srp() == " ":
srf()
while srp() in "\r\n\x85\u2028\u2029":
chunks.append(self.scan_line_break())
end_mark = self.reader.get_mark()
while self.reader.column < indent and srp() == " ":
srf()
return chunks, end_mark
def scan_flow_scalar(self, style):
# type: (Any) -> Any
# See the specification for details.
# Note that we relax the indentation rules for quoted scalars. Quoted
# scalars don't need to adhere to indentation because " and ' clearly
# mark the beginning and the end of them. Therefore we are less
# restrictive than the specification requires. We only need to check
# that document separators are not included in scalars.
if style == '"':
double = True
else:
double = False
srp = self.reader.peek
chunks = [] # type: List[Any]
start_mark = self.reader.get_mark()
quote = srp()
self.reader.forward()
chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark))
while srp() != quote:
chunks.extend(self.scan_flow_scalar_spaces(double, start_mark))
chunks.extend(self.scan_flow_scalar_non_spaces(double, start_mark))
self.reader.forward()
end_mark = self.reader.get_mark()
return ScalarToken("".join(chunks), False, start_mark, end_mark, style)
ESCAPE_REPLACEMENTS = {
"0": "\0",
"a": "\x07",
"b": "\x08",
"t": "\x09",
"\t": "\x09",
"n": "\x0A",
"v": "\x0B",
"f": "\x0C",
"r": "\x0D",
"e": "\x1B",
" ": "\x20",
'"': '"',
"/": "/", # as per http://www.json.org/
"\\": "\\",
"N": "\x85",
"_": "\xA0",
"L": "\u2028",
"P": "\u2029",
}
ESCAPE_CODES = {"x": 2, "u": 4, "U": 8}
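# Illustrative examples of the double-quoted escapes handled below:
#   "\t"         -> a horizontal tab           (ESCAPE_REPLACEMENTS)
#   "\x41"       -> 'A'                        (2 hex digits)
#   "\u263A"     -> U+263A                     (4 hex digits)
#   "\U0001F600" -> U+1F600                    (8 hex digits)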
def scan_flow_scalar_non_spaces(self, double, start_mark):
# type: (Any, Any) -> Any
# See the specification for details.
chunks = [] # type: List[Any]
srp = self.reader.peek
srf = self.reader.forward
while True:
length = 0
while srp(length) not in " \n'\"\\\0\t\r\x85\u2028\u2029":
length += 1
if length != 0:
chunks.append(self.reader.prefix(length))
srf(length)
ch = srp()
if not double and ch == "'" and srp(1) == "'":
chunks.append("'")
srf(2)
elif (double and ch == "'") or (not double and ch in '"\\'):
chunks.append(ch)
srf()
elif double and ch == "\\":
srf()
ch = srp()
if ch in self.ESCAPE_REPLACEMENTS:
chunks.append(self.ESCAPE_REPLACEMENTS[ch])
srf()
elif ch in self.ESCAPE_CODES:
length = self.ESCAPE_CODES[ch]
srf()
for k in range(length):
if srp(k) not in "0123456789ABCDEFabcdef":
raise ScannerError(
"while scanning a double-quoted scalar",
start_mark,
"expected escape sequence of %d hexdecimal "
"numbers, but found %r" % (length, utf8(srp(k))),
self.reader.get_mark(),
)
code = int(self.reader.prefix(length), 16)
chunks.append(unichr(code))
srf(length)
elif ch in "\n\r\x85\u2028\u2029":
self.scan_line_break()
chunks.extend(self.scan_flow_scalar_breaks(double, start_mark))
else:
raise ScannerError(
"while scanning a double-quoted scalar",
start_mark,
"found unknown escape character %r" % utf8(ch),
self.reader.get_mark(),
)
else:
return chunks
def scan_flow_scalar_spaces(self, double, start_mark):
# type: (Any, Any) -> Any
# See the specification for details.
srp = self.reader.peek
chunks = []
length = 0
while srp(length) in " \t":
length += 1
whitespaces = self.reader.prefix(length)
self.reader.forward(length)
ch = srp()
if ch == "\0":
raise ScannerError(
"while scanning a quoted scalar",
start_mark,
"found unexpected end of stream",
self.reader.get_mark(),
)
elif ch in "\r\n\x85\u2028\u2029":
line_break = self.scan_line_break()
breaks = self.scan_flow_scalar_breaks(double, start_mark)
if line_break != "\n":
chunks.append(line_break)
elif not breaks:
chunks.append(" ")
chunks.extend(breaks)
else:
chunks.append(whitespaces)
return chunks
def scan_flow_scalar_breaks(self, double, start_mark):
# type: (Any, Any) -> Any
# See the specification for details.
chunks = [] # type: List[Any]
srp = self.reader.peek
srf = self.reader.forward
while True:
# Instead of checking indentation, we check for document
# separators.
prefix = self.reader.prefix(3)
if (prefix == "---" or prefix == "...") and srp(3) in _THE_END_SPACE_TAB:
raise ScannerError(
"while scanning a quoted scalar",
start_mark,
"found unexpected document separator",
self.reader.get_mark(),
)
while srp() in " \t":
srf()
if srp() in "\r\n\x85\u2028\u2029":
chunks.append(self.scan_line_break())
else:
return chunks
def scan_plain(self):
# type: () -> Any
# See the specification for details.
# We add an additional restriction for the flow context:
# plain scalars in the flow context cannot contain ',', ': ' and '?'.
# We also keep track of the `allow_simple_key` flag here.
# Indentation rules are relaxed for the flow context.
srp = self.reader.peek
srf = self.reader.forward
chunks = [] # type: List[Any]
start_mark = self.reader.get_mark()
end_mark = start_mark
indent = self.indent + 1
# We allow zero indentation for scalars, but then we need to check for
# document separators at the beginning of the line.
# if indent == 0:
# indent = 1
spaces = [] # type: List[Any]
while True:
length = 0
if srp() == "#":
break
while True:
ch = srp(length)
if ch == ":" and srp(length + 1) not in _THE_END_SPACE_TAB:
pass
elif ch == "?" and self.scanner_processing_version != (1, 1):
pass
elif (
ch in _THE_END_SPACE_TAB
or (
not self.flow_level
and ch == ":"
and srp(length + 1) in _THE_END_SPACE_TAB
)
or (self.flow_level and ch in ",:?[]{}")
):
break
length += 1
# It's not clear what we should do with ':' in the flow context.
if (
self.flow_level
and ch == ":"
and srp(length + 1) not in "\0 \t\r\n\x85\u2028\u2029,[]{}"
):
srf(length)
raise ScannerError(
"while scanning a plain scalar",
start_mark,
"found unexpected ':'",
self.reader.get_mark(),
"Please check "
"http://pyyaml.org/wiki/YAMLColonInFlowContext "
"for details.",
)
if length == 0:
break
self.allow_simple_key = False
chunks.extend(spaces)
chunks.append(self.reader.prefix(length))
srf(length)
end_mark = self.reader.get_mark()
spaces = self.scan_plain_spaces(indent, start_mark)
if (
not spaces
or srp() == "#"
or (not self.flow_level and self.reader.column < indent)
):
break
token = ScalarToken("".join(chunks), True, start_mark, end_mark)
if spaces and spaces[0] == "\n":
# Create a comment token to preserve the trailing line breaks.
comment = CommentToken("".join(spaces) + "\n", start_mark, end_mark)
token.add_post_comment(comment)
return token
def scan_plain_spaces(self, indent, start_mark):
# type: (Any, Any) -> Any
# See the specification for details.
# The specification is really confusing about tabs in plain scalars.
# We just forbid them completely. Do not use tabs in YAML!
srp = self.reader.peek
srf = self.reader.forward
chunks = []
length = 0
while srp(length) in " ":
length += 1
whitespaces = self.reader.prefix(length)
self.reader.forward(length)
ch = srp()
if ch in "\r\n\x85\u2028\u2029":
line_break = self.scan_line_break()
self.allow_simple_key = True
prefix = self.reader.prefix(3)
if (prefix == "---" or prefix == "...") and srp(3) in _THE_END_SPACE_TAB:
return
breaks = []
while srp() in " \r\n\x85\u2028\u2029":
if srp() == " ":
srf()
else:
breaks.append(self.scan_line_break())
prefix = self.reader.prefix(3)
if (prefix == "---" or prefix == "...") and srp(
3
) in _THE_END_SPACE_TAB:
return
if line_break != "\n":
chunks.append(line_break)
elif not breaks:
chunks.append(" ")
chunks.extend(breaks)
elif whitespaces:
chunks.append(whitespaces)
return chunks
def scan_tag_handle(self, name, start_mark):
# type: (Any, Any) -> Any
# See the specification for details.
# For some strange reason, the specification does not allow '_' in
# tag handles. I have allowed it anyway.
srp = self.reader.peek
ch = srp()
if ch != "!":
raise ScannerError(
"while scanning a %s" % (name,),
start_mark,
"expected '!', but found %r" % utf8(ch),
self.reader.get_mark(),
)
length = 1
ch = srp(length)
if ch != " ":
while (
"0" <= ch <= "9" or "A" <= ch <= "Z" or "a" <= ch <= "z" or ch in "-_"
):
length += 1
ch = srp(length)
if ch != "!":
self.reader.forward(length)
raise ScannerError(
"while scanning a %s" % (name,),
start_mark,
"expected '!', but found %r" % utf8(ch),
self.reader.get_mark(),
)
length += 1
value = self.reader.prefix(length)
self.reader.forward(length)
return value
def scan_tag_uri(self, name, start_mark):
# type: (Any, Any) -> Any
# See the specification for details.
# Note: we do not check if URI is well-formed.
srp = self.reader.peek
chunks = []
length = 0
ch = srp(length)
while (
"0" <= ch <= "9"
or "A" <= ch <= "Z"
or "a" <= ch <= "z"
or ch in "-;/?:@&=+$,_.!~*'()[]%"
or ((self.scanner_processing_version > (1, 1)) and ch == "#")
):
if ch == "%":
chunks.append(self.reader.prefix(length))
self.reader.forward(length)
length = 0
chunks.append(self.scan_uri_escapes(name, start_mark))
else:
length += 1
ch = srp(length)
if length != 0:
chunks.append(self.reader.prefix(length))
self.reader.forward(length)
length = 0
if not chunks:
raise ScannerError(
"while parsing a %s" % (name,),
start_mark,
"expected URI, but found %r" % utf8(ch),
self.reader.get_mark(),
)
return "".join(chunks)
def scan_uri_escapes(self, name, start_mark):
# type: (Any, Any) -> Any
# See the specification for details.
srp = self.reader.peek
srf = self.reader.forward
code_bytes = [] # type: List[Any]
mark = self.reader.get_mark()
while srp() == "%":
srf()
for k in range(2):
if srp(k) not in "0123456789ABCDEFabcdef":
raise ScannerError(
"while scanning a %s" % (name,),
start_mark,
"expected URI escape sequence of 2 hexdecimal numbers,"
" but found %r" % utf8(srp(k)),
self.reader.get_mark(),
)
if PY3:
code_bytes.append(int(self.reader.prefix(2), 16))
else:
code_bytes.append(chr(int(self.reader.prefix(2), 16)))
srf(2)
try:
if PY3:
value = bytes(code_bytes).decode("utf-8")
else:
value = unicode(b"".join(code_bytes), "utf-8")
except UnicodeDecodeError as exc:
raise ScannerError(
"while scanning a %s" % (name,), start_mark, str(exc), mark
)
return value
def scan_line_break(self):
# type: () -> Any
# Transforms:
# '\r\n' : '\n'
# '\r' : '\n'
# '\n' : '\n'
# '\x85' : '\n'
# '\u2028' : '\u2028'
# '\u2029' : '\u2029'
# default : ''
ch = self.reader.peek()
if ch in "\r\n\x85":
if self.reader.prefix(2) == "\r\n":
self.reader.forward(2)
else:
self.reader.forward()
return "\n"
elif ch in "\u2028\u2029":
self.reader.forward()
return ch
return ""
class RoundTripScanner(Scanner):
def check_token(self, *choices):
# type: (Any) -> bool
# Check if the next token is one of the given types.
while self.need_more_tokens():
self.fetch_more_tokens()
self._gather_comments()
if bool(self.tokens):
if not choices:
return True
for choice in choices:
if isinstance(self.tokens[0], choice):
return True
return False
def peek_token(self):
# type: () -> Any
# Return the next token, but do not delete it from the queue.
while self.need_more_tokens():
self.fetch_more_tokens()
self._gather_comments()
if bool(self.tokens):
return self.tokens[0]
return None
def _gather_comments(self):
# type: () -> Any
"""combine multiple comment lines"""
comments = [] # type: List[Any]
if not self.tokens:
return comments
if isinstance(self.tokens[0], CommentToken):
comment = self.tokens.pop(0)
self.tokens_taken += 1
comments.append(comment)
while self.need_more_tokens():
self.fetch_more_tokens()
if not self.tokens:
return comments
if isinstance(self.tokens[0], CommentToken):
self.tokens_taken += 1
comment = self.tokens.pop(0)
# nprint('dropping2', comment)
comments.append(comment)
if len(comments) >= 1:
self.tokens[0].add_pre_comments(comments)
# pull in post comment on e.g. ':'
if not self.done and len(self.tokens) < 2:
self.fetch_more_tokens()
def get_token(self):
# type: () -> Any
# Return the next token.
while self.need_more_tokens():
self.fetch_more_tokens()
self._gather_comments()
if bool(self.tokens):
# nprint('tk', self.tokens)
# Only add a post comment to single-line tokens: scalar, value,
# FlowSequenceEnd/FlowMappingEnd. Otherwise hidden stream tokens could
# pick them up (leave them and they will become pre comments for the
# next map/seq).
if (
len(self.tokens) > 1
and isinstance(
self.tokens[0],
(
ScalarToken,
ValueToken,
FlowSequenceEndToken,
FlowMappingEndToken,
),
)
and isinstance(self.tokens[1], CommentToken)
and self.tokens[0].end_mark.line == self.tokens[1].start_mark.line
):
self.tokens_taken += 1
c = self.tokens.pop(1)
self.fetch_more_tokens()
while len(self.tokens) > 1 and isinstance(self.tokens[1], CommentToken):
self.tokens_taken += 1
c1 = self.tokens.pop(1)
c.value = c.value + (" " * c1.start_mark.column) + c1.value
self.fetch_more_tokens()
self.tokens[0].add_post_comment(c)
elif (
len(self.tokens) > 1
and isinstance(self.tokens[0], ScalarToken)
and isinstance(self.tokens[1], CommentToken)
and self.tokens[0].end_mark.line != self.tokens[1].start_mark.line
):
self.tokens_taken += 1
c = self.tokens.pop(1)
c.value = (
"\n" * (c.start_mark.line - self.tokens[0].end_mark.line)
+ (" " * c.start_mark.column)
+ c.value
)
self.tokens[0].add_post_comment(c)
self.fetch_more_tokens()
while len(self.tokens) > 1 and isinstance(self.tokens[1], CommentToken):
self.tokens_taken += 1
c1 = self.tokens.pop(1)
c.value = c.value + (" " * c1.start_mark.column) + c1.value
self.fetch_more_tokens()
self.tokens_taken += 1
return self.tokens.pop(0)
return None
def fetch_comment(self, comment):
# type: (Any) -> None
value, start_mark, end_mark = comment
while value and value[-1] == " ":
# empty line within indented key context
# no need to update end-mark, that is not used
value = value[:-1]
self.tokens.append(CommentToken(value, start_mark, end_mark))
# scanner
def scan_to_next_token(self):
# type: () -> Any
# We ignore spaces, line breaks and comments.
# If we find a line break in the block context, we set the flag
# `allow_simple_key` on.
# The byte order mark is stripped if it's the first character in the
# stream. We do not yet support BOM inside the stream as the
# specification requires. Any such mark will be considered as a part
# of the document.
#
# TODO: We need to make tab handling rules more sane. A good rule is
# Tabs cannot precede tokens
# BLOCK-SEQUENCE-START, BLOCK-MAPPING-START, BLOCK-END,
# KEY(block), VALUE(block), BLOCK-ENTRY
# So the checking code is
# if <TAB>:
# self.allow_simple_keys = False
# We also need to add the check for `allow_simple_keys == True` to
# `unwind_indent` before issuing BLOCK-END.
# Scanners for block, flow, and plain scalars need to be modified.
srp = self.reader.peek
srf = self.reader.forward
if self.reader.index == 0 and srp() == "\uFEFF":
srf()
found = False
while not found:
while srp() == " ":
srf()
ch = srp()
if ch == "#":
start_mark = self.reader.get_mark()
comment = ch
srf()
while ch not in _THE_END:
ch = srp()
if ch == "\0": # don't gobble the end-of-stream character
# but add an explicit newline as "YAML processors should terminate
# the stream with an explicit line break
# https://yaml.org/spec/1.2/spec.html#id2780069
comment += "\n"
break
comment += ch
srf()
# gather any blank lines following the comment too
ch = self.scan_line_break()
while len(ch) > 0:
comment += ch
ch = self.scan_line_break()
end_mark = self.reader.get_mark()
if not self.flow_level:
self.allow_simple_key = True
return comment, start_mark, end_mark
if bool(self.scan_line_break()):
start_mark = self.reader.get_mark()
if not self.flow_level:
self.allow_simple_key = True
ch = srp()
if ch == "\n": # empty toplevel lines
start_mark = self.reader.get_mark()
comment = ""
while ch:
ch = self.scan_line_break(empty_line=True)
comment += ch
if srp() == "#":
# empty line followed by indented real comment
comment = comment.rsplit("\n", 1)[0] + "\n"
end_mark = self.reader.get_mark()
return comment, start_mark, end_mark
else:
found = True
return None
def scan_line_break(self, empty_line=False):
# type: (bool) -> Text
# Transforms:
# '\r\n' : '\n'
# '\r' : '\n'
# '\n' : '\n'
# '\x85' : '\n'
# '\u2028' : '\u2028'
# '\u2029' : '\u2029'
# default : ''
ch = self.reader.peek() # type: Text
if ch in "\r\n\x85":
if self.reader.prefix(2) == "\r\n":
self.reader.forward(2)
else:
self.reader.forward()
return "\n"
elif ch in "\u2028\u2029":
self.reader.forward()
return ch
elif empty_line and ch in "\t ":
self.reader.forward()
return ch
return ""
def scan_block_scalar(self, style, rt=True):
# type: (Any, Optional[bool]) -> Any
return Scanner.scan_block_scalar(self, style, rt=rt)
# try:
# import psyco
# psyco.bind(Scanner)
# except ImportError:
# pass
srsly-release-v2.5.1/srsly/ruamel_yaml/serializer.py 0000775 0000000 0000000 00000020523 14742310675 0022723 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
from .error import YAMLError
from .compat import nprint, DBG_NODE, dbg, string_types, nprintf # NOQA
from .util import RegExp
from .events import (
StreamStartEvent,
StreamEndEvent,
MappingStartEvent,
MappingEndEvent,
SequenceStartEvent,
SequenceEndEvent,
AliasEvent,
ScalarEvent,
DocumentStartEvent,
DocumentEndEvent,
)
from .nodes import MappingNode, ScalarNode, SequenceNode
if False: # MYPY
from typing import Any, Dict, Union, Text, Optional # NOQA
from .compat import VersionType # NOQA
__all__ = ["Serializer", "SerializerError"]
class SerializerError(YAMLError):
pass
class Serializer(object):
# 'id' and 3+ numbers, but not 000
ANCHOR_TEMPLATE = u"id%03d"
ANCHOR_RE = RegExp(u"id(?!000$)\\d{3,}")
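# Generated anchors therefore look like 'id001', 'id002', ...; ANCHOR_RE is
# used (via templated_id below) to recognise such auto-generated names, and
# 'id000' alone is excluded by the (?!000$) look-ahead.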
def __init__(
self,
encoding=None,
explicit_start=None,
explicit_end=None,
version=None,
tags=None,
dumper=None,
):
# type: (Any, Optional[bool], Optional[bool], Optional[VersionType], Any, Any) -> None # NOQA
self.dumper = dumper
if self.dumper is not None:
self.dumper._serializer = self
self.use_encoding = encoding
self.use_explicit_start = explicit_start
self.use_explicit_end = explicit_end
if isinstance(version, string_types):
self.use_version = tuple(map(int, version.split(".")))
else:
self.use_version = version # type: ignore
self.use_tags = tags
self.serialized_nodes = {} # type: Dict[Any, Any]
self.anchors = {} # type: Dict[Any, Any]
self.last_anchor_id = 0
self.closed = None # type: Optional[bool]
self._templated_id = None
@property
def emitter(self):
# type: () -> Any
if hasattr(self.dumper, "typ"):
return self.dumper.emitter
return self.dumper._emitter
@property
def resolver(self):
# type: () -> Any
if hasattr(self.dumper, "typ"):
return self.dumper.resolver
return self.dumper._resolver
def open(self):
# type: () -> None
if self.closed is None:
self.emitter.emit(StreamStartEvent(encoding=self.use_encoding))
self.closed = False
elif self.closed:
raise SerializerError("serializer is closed")
else:
raise SerializerError("serializer is already opened")
def close(self):
# type: () -> None
if self.closed is None:
raise SerializerError("serializer is not opened")
elif not self.closed:
self.emitter.emit(StreamEndEvent())
self.closed = True
# def __del__(self):
# self.close()
def serialize(self, node):
# type: (Any) -> None
if dbg(DBG_NODE):
nprint("Serializing nodes")
node.dump()
if self.closed is None:
raise SerializerError("serializer is not opened")
elif self.closed:
raise SerializerError("serializer is closed")
self.emitter.emit(
DocumentStartEvent(
explicit=self.use_explicit_start,
version=self.use_version,
tags=self.use_tags,
)
)
self.anchor_node(node)
self.serialize_node(node, None, None)
self.emitter.emit(DocumentEndEvent(explicit=self.use_explicit_end))
self.serialized_nodes = {}
self.anchors = {}
self.last_anchor_id = 0
def anchor_node(self, node):
# type: (Any) -> None
if node in self.anchors:
if self.anchors[node] is None:
self.anchors[node] = self.generate_anchor(node)
else:
anchor = None
try:
if node.anchor.always_dump:
anchor = node.anchor.value
except: # NOQA
pass
self.anchors[node] = anchor
if isinstance(node, SequenceNode):
for item in node.value:
self.anchor_node(item)
elif isinstance(node, MappingNode):
for key, value in node.value:
self.anchor_node(key)
self.anchor_node(value)
def generate_anchor(self, node):
# type: (Any) -> Any
try:
anchor = node.anchor.value
except: # NOQA
anchor = None
if anchor is None:
self.last_anchor_id += 1
return self.ANCHOR_TEMPLATE % self.last_anchor_id
return anchor
def serialize_node(self, node, parent, index):
# type: (Any, Any, Any) -> None
alias = self.anchors[node]
if node in self.serialized_nodes:
self.emitter.emit(AliasEvent(alias))
else:
self.serialized_nodes[node] = True
self.resolver.descend_resolver(parent, index)
if isinstance(node, ScalarNode):
# Here we check whether node.tag equals the tag that would result from
# parsing; if they differ, quoting is necessary for strings.
detected_tag = self.resolver.resolve(
ScalarNode, node.value, (True, False)
)
default_tag = self.resolver.resolve(
ScalarNode, node.value, (False, True)
)
implicit = (
(node.tag == detected_tag),
(node.tag == default_tag),
node.tag.startswith("tag:yaml.org,2002:"),
)
self.emitter.emit(
ScalarEvent(
alias,
node.tag,
implicit,
node.value,
style=node.style,
comment=node.comment,
)
)
elif isinstance(node, SequenceNode):
implicit = node.tag == self.resolver.resolve(
SequenceNode, node.value, True
)
comment = node.comment
end_comment = None
seq_comment = None
if node.flow_style is True:
if comment: # eol comment on flow style sequence
seq_comment = comment[0]
# comment[0] = None
if comment and len(comment) > 2:
end_comment = comment[2]
else:
end_comment = None
self.emitter.emit(
SequenceStartEvent(
alias,
node.tag,
implicit,
flow_style=node.flow_style,
comment=node.comment,
)
)
index = 0
for item in node.value:
self.serialize_node(item, node, index)
index += 1
self.emitter.emit(SequenceEndEvent(comment=[seq_comment, end_comment]))
elif isinstance(node, MappingNode):
implicit = node.tag == self.resolver.resolve(
MappingNode, node.value, True
)
comment = node.comment
end_comment = None
map_comment = None
if node.flow_style is True:
if comment: # eol comment on flow style mapping
map_comment = comment[0]
# comment[0] = None
if comment and len(comment) > 2:
end_comment = comment[2]
self.emitter.emit(
MappingStartEvent(
alias,
node.tag,
implicit,
flow_style=node.flow_style,
comment=node.comment,
nr_items=len(node.value),
)
)
for key, value in node.value:
self.serialize_node(key, node, None)
self.serialize_node(value, node, key)
self.emitter.emit(MappingEndEvent(comment=[map_comment, end_comment]))
self.resolver.ascend_resolver()
def templated_id(s):
# type: (Text) -> Any
return Serializer.ANCHOR_RE.match(s)
srsly-release-v2.5.1/srsly/ruamel_yaml/timestamp.py 0000775 0000000 0000000 00000001653 14742310675 0022560 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division, unicode_literals
import datetime
import copy
# ToDo: at least on PY3 you could probably attach the tzinfo correctly to the object
# a more complete datetime might be used by safe loading as well
if False: # MYPY
from typing import Any, Dict, Optional, List # NOQA
class TimeStamp(datetime.datetime):
def __init__(self, *args, **kw):
# type: (Any, Any) -> None
self._yaml = dict(t=False, tz=None, delta=0) # type: Dict[Any, Any]
def __new__(cls, *args, **kw): # datetime is immutable
# type: (Any, Any) -> Any
return datetime.datetime.__new__(cls, *args, **kw) # type: ignore
def __deepcopy__(self, memo):
# type: (Any) -> Any
ts = TimeStamp(self.year, self.month, self.day, self.hour, self.minute, self.second)
ts._yaml = copy.deepcopy(self._yaml)
return ts
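# Rough usage sketch: TimeStamp(2019, 1, 1, 10, 30, 0) behaves like a plain
# datetime.datetime; the extra self._yaml dict lets the round-trip code
# remember whether a 't' separator, a timezone or a delta was present in
# the original scalar.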
srsly-release-v2.5.1/srsly/ruamel_yaml/tokens.py 0000775 0000000 0000000 00000016457 14742310675 0022070 0 ustar 00root root 0000000 0000000 # # header
# coding: utf-8
from __future__ import unicode_literals
if False: # MYPY
from typing import Text, Any, Dict, Optional, List # NOQA
from .error import StreamMark # NOQA
SHOWLINES = True
class Token(object):
__slots__ = 'start_mark', 'end_mark', '_comment'
def __init__(self, start_mark, end_mark):
# type: (StreamMark, StreamMark) -> None
self.start_mark = start_mark
self.end_mark = end_mark
def __repr__(self):
# type: () -> Any
# attributes = [key for key in self.__slots__ if not key.endswith('_mark') and
# hasattr('self', key)]
attributes = [key for key in self.__slots__ if not key.endswith('_mark')]
attributes.sort()
arguments = ', '.join(['%s=%r' % (key, getattr(self, key)) for key in attributes])
if SHOWLINES:
try:
arguments += ', line: ' + str(self.start_mark.line)
except: # NOQA
pass
try:
arguments += ', comment: ' + str(self._comment)
except: # NOQA
pass
return '{}({})'.format(self.__class__.__name__, arguments)
def add_post_comment(self, comment):
# type: (Any) -> None
if not hasattr(self, '_comment'):
self._comment = [None, None]
self._comment[0] = comment
def add_pre_comments(self, comments):
# type: (Any) -> None
if not hasattr(self, '_comment'):
self._comment = [None, None]
assert self._comment[1] is None
self._comment[1] = comments
def get_comment(self):
# type: () -> Any
return getattr(self, '_comment', None)
@property
def comment(self):
# type: () -> Any
return getattr(self, '_comment', None)
def move_comment(self, target, empty=False):
# type: (Any, bool) -> Any
"""move a comment from this token to target (normally next token)
used to combine e.g. comments before a BlockEntryToken to the
ScalarToken that follows it
empty is a special case for empty values -> comment after key
"""
c = self.comment
if c is None:
return
# don't push beyond last element
if isinstance(target, (StreamEndToken, DocumentStartToken)):
return
delattr(self, '_comment')
tc = target.comment
if not tc: # target comment, just insert
# special for empty value in key: value issue 25
if empty:
c = [c[0], c[1], None, None, c[0]]
target._comment = c
# nprint('mco2:', self, target, target.comment, empty)
return self
if c[0] and tc[0] or c[1] and tc[1]:
raise NotImplementedError('overlap in comment %r %r' % (c, tc))
if c[0]:
tc[0] = c[0]
if c[1]:
tc[1] = c[1]
return self
def split_comment(self):
# type: () -> Any
""" split the post part of a comment, and return it
as comment to be added. Delete second part if [None, None]
abc: # this goes to sequence
# this goes to first element
- first element
"""
comment = self.comment
if comment is None or comment[0] is None:
return None # nothing to do
ret_val = [comment[0], None]
if comment[1] is None:
delattr(self, '_comment')
return ret_val
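# --- Hedged usage sketch (added illustration, not part of the original file) ---
# A small demonstration of the comment bookkeeping described above: a token
# stores its comments in a two-slot list, [post_comment, pre_comments]. The
# helper name is hypothetical and the marks are None purely for illustration.
def _comment_attachment_example():
    tok = Token(None, None)
    tok.add_post_comment("# trailing comment")
    tok.add_pre_comments(["# leading comment"])
    assert tok.comment == ["# trailing comment", ["# leading comment"]]
    return tok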
# class BOMToken(Token):
# id = ''
class DirectiveToken(Token):
__slots__ = 'name', 'value'
id = ''
def __init__(self, name, value, start_mark, end_mark):
# type: (Any, Any, Any, Any) -> None
Token.__init__(self, start_mark, end_mark)
self.name = name
self.value = value
class DocumentStartToken(Token):
__slots__ = ()
id = ''
class DocumentEndToken(Token):
__slots__ = ()
id = ''
class StreamStartToken(Token):
__slots__ = ('encoding',)
id = ''
def __init__(self, start_mark=None, end_mark=None, encoding=None):
# type: (Any, Any, Any) -> None
Token.__init__(self, start_mark, end_mark)
self.encoding = encoding
class StreamEndToken(Token):
__slots__ = ()
id = ''
class BlockSequenceStartToken(Token):
__slots__ = ()
id = ''
class BlockMappingStartToken(Token):
__slots__ = ()
id = ''
class BlockEndToken(Token):
__slots__ = ()
id = ''
class FlowSequenceStartToken(Token):
__slots__ = ()
id = '['
class FlowMappingStartToken(Token):
__slots__ = ()
id = '{'
class FlowSequenceEndToken(Token):
__slots__ = ()
id = ']'
class FlowMappingEndToken(Token):
__slots__ = ()
id = '}'
class KeyToken(Token):
__slots__ = ()
id = '?'
# def x__repr__(self):
# return 'KeyToken({})'.format(
# self.start_mark.buffer[self.start_mark.index:].split(None, 1)[0])
class ValueToken(Token):
__slots__ = ()
id = ':'
class BlockEntryToken(Token):
__slots__ = ()
id = '-'
class FlowEntryToken(Token):
__slots__ = ()
id = ','
class AliasToken(Token):
__slots__ = ('value',)
id = ''
def __init__(self, value, start_mark, end_mark):
# type: (Any, Any, Any) -> None
Token.__init__(self, start_mark, end_mark)
self.value = value
class AnchorToken(Token):
__slots__ = ('value',)
id = ''
def __init__(self, value, start_mark, end_mark):
# type: (Any, Any, Any) -> None
Token.__init__(self, start_mark, end_mark)
self.value = value
class TagToken(Token):
__slots__ = ('value',)
id = ''
def __init__(self, value, start_mark, end_mark):
# type: (Any, Any, Any) -> None
Token.__init__(self, start_mark, end_mark)
self.value = value
class ScalarToken(Token):
__slots__ = 'value', 'plain', 'style'
id = ''
def __init__(self, value, plain, start_mark, end_mark, style=None):
# type: (Any, Any, Any, Any, Any) -> None
Token.__init__(self, start_mark, end_mark)
self.value = value
self.plain = plain
self.style = style
class CommentToken(Token):
__slots__ = 'value', 'pre_done'
id = ''
def __init__(self, value, start_mark, end_mark):
# type: (Any, Any, Any) -> None
Token.__init__(self, start_mark, end_mark)
self.value = value
def reset(self):
# type: () -> None
if hasattr(self, 'pre_done'):
delattr(self, 'pre_done')
def __repr__(self):
# type: () -> Any
v = '{!r}'.format(self.value)
if SHOWLINES:
try:
v += ', line: ' + str(self.start_mark.line)
v += ', col: ' + str(self.start_mark.column)
except: # NOQA
pass
return 'CommentToken({})'.format(v)
def __eq__(self, other):
# type: (Any) -> bool
if self.start_mark != other.start_mark:
return False
if self.end_mark != other.end_mark:
return False
if self.value != other.value:
return False
return True
def __ne__(self, other):
# type: (Any) -> bool
return not self.__eq__(other)
srsly-release-v2.5.1/srsly/ruamel_yaml/util.py 0000775 0000000 0000000 00000013722 14742310675 0021532 0 ustar 00root root 0000000 0000000 # coding: utf-8
"""
some helper functions that might be generally useful
"""
from __future__ import absolute_import, print_function
from functools import partial
import re
from .compat import text_type, binary_type
if False: # MYPY
from typing import Any, Dict, Optional, List, Text # NOQA
from .compat import StreamTextType # NOQA
class LazyEval(object):
"""
Lightweight wrapper around lazily evaluated func(*args, **kwargs).
func is only evaluated when any attribute of its return value is accessed.
Every attribute access is passed through to the wrapped value.
(This only excludes special cases like method-wrappers, e.g., __hash__.)
The sole additional attribute is the lazy_self function which holds the
return value (or, prior to evaluation, func and arguments), in its closure.
"""
def __init__(self, func, *args, **kwargs):
# type: (Any, Any, Any) -> None
def lazy_self():
# type: () -> Any
return_value = func(*args, **kwargs)
object.__setattr__(self, 'lazy_self', lambda: return_value)
return return_value
object.__setattr__(self, 'lazy_self', lazy_self)
def __getattribute__(self, name):
# type: (Any) -> Any
lazy_self = object.__getattribute__(self, 'lazy_self')
if name == 'lazy_self':
return lazy_self
return getattr(lazy_self(), name)
def __setattr__(self, name, value):
# type: (Any, Any) -> None
setattr(self.lazy_self(), name, value)
RegExp = partial(LazyEval, re.compile)
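# --- Hedged usage sketch (added illustration, not part of the original file) ---
# RegExp defers re.compile() until the wrapped pattern is first touched; every
# attribute access on the LazyEval wrapper is forwarded to the compiled object.
# The helper name below is hypothetical and the function is not called anywhere.
def _lazy_regexp_example():
    pattern = RegExp(r"\d+")
    # re.compile has not run yet; the attribute access below triggers it.
    assert pattern.match("42 items").group() == "42"
    return pattern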
# originally as comment
# https://github.com/pre-commit/pre-commit/pull/211#issuecomment-186466605
# if you use this in your code, I suggest adding a test in your test suite
# that checks this routine's output against a known piece of your YAML
# before upgrades to this code break your round-tripped YAML
def load_yaml_guess_indent(stream, **kw):
# type: (StreamTextType, Any) -> Any
"""guess the indent and block sequence indent of yaml stream/string
returns round_trip_loaded stream, indent level, block sequence indent
- block sequence indent is the number of spaces before a dash relative to previous indent
- if there are no block sequences, indent is taken from nested mappings, block sequence
indent is unset (None) in that case
"""
from .main import round_trip_load
# load a yaml file guess the indentation, if you use TABs ...
def leading_spaces(l):
# type: (Any) -> int
idx = 0
while idx < len(l) and l[idx] == ' ':
idx += 1
return idx
if isinstance(stream, text_type):
yaml_str = stream # type: Any
elif isinstance(stream, binary_type):
# most likely, but the Reader checks BOM for this
yaml_str = stream.decode('utf-8')
else:
yaml_str = stream.read()
map_indent = None
indent = None # default if not found for some reason
block_seq_indent = None
prev_line_key_only = None
key_indent = 0
for line in yaml_str.splitlines():
rline = line.rstrip()
lline = rline.lstrip()
if lline.startswith('- '):
l_s = leading_spaces(line)
block_seq_indent = l_s - key_indent
idx = l_s + 1
while line[idx] == ' ': # this will end as we rstripped
idx += 1
if line[idx] == '#': # comment after -
continue
indent = idx - key_indent
break
if map_indent is None and prev_line_key_only is not None and rline:
idx = 0
while line[idx] in ' -':
idx += 1
if idx > prev_line_key_only:
map_indent = idx - prev_line_key_only
if rline.endswith(':'):
key_indent = leading_spaces(line)
idx = 0
while line[idx] == ' ': # this will end on ':'
idx += 1
prev_line_key_only = idx
continue
prev_line_key_only = None
if indent is None and map_indent is not None:
indent = map_indent
return round_trip_load(yaml_str, **kw), indent, block_seq_indent
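# --- Hedged usage sketch (added illustration, not part of the original file) ---
# For a small document the helper returns the loaded data plus the detected
# scalar indent and the dash offset, which can then be fed back into a dumper
# to preserve the original layout. The helper name below is hypothetical and
# the function is not called anywhere.
def _guess_indent_example():
    doc = "root:\n  items:\n    - a\n    - b\n"
    data, indent, block_seq_indent = load_yaml_guess_indent(doc)
    assert data["root"]["items"] == ["a", "b"]
    assert indent == 4 and block_seq_indent == 2
    return data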
def configobj_walker(cfg):
# type: (Any) -> Any
"""
walks over a ConfigObj (INI file with comments) generating
corresponding YAML output (including comments)
"""
from configobj import ConfigObj # type: ignore
assert isinstance(cfg, ConfigObj)
for c in cfg.initial_comment:
if c.strip():
yield c
for s in _walk_section(cfg):
if s.strip():
yield s
for c in cfg.final_comment:
if c.strip():
yield c
def _walk_section(s, level=0):
# type: (Any, int) -> Any
from configobj import Section
assert isinstance(s, Section)
indent = u' ' * level
for name in s.scalars:
for c in s.comments[name]:
yield indent + c.strip()
x = s[name]
if u'\n' in x:
i = indent + u' '
x = u'|\n' + i + x.strip().replace(u'\n', u'\n' + i)
elif ':' in x:
x = u"'" + x.replace(u"'", u"''") + u"'"
line = u'{0}{1}: {2}'.format(indent, name, x)
c = s.inline_comments[name]
if c:
line += u' ' + c
yield line
for name in s.sections:
for c in s.comments[name]:
yield indent + c.strip()
line = u'{0}{1}:'.format(indent, name)
c = s.inline_comments[name]
if c:
line += u' ' + c
yield line
for val in _walk_section(s[name], level=level + 1):
yield val
# def config_obj_2_rt_yaml(cfg):
# from .comments import CommentedMap, CommentedSeq
# from configobj import ConfigObj
# assert isinstance(cfg, ConfigObj)
# #for c in cfg.initial_comment:
# # if c.strip():
# # pass
# cm = CommentedMap()
# for name in s.sections:
# cm[name] = d = CommentedMap()
#
#
# #for c in cfg.final_comment:
# # if c.strip():
# # yield c
# return cm
srsly-release-v2.5.1/srsly/tests/ 0000775 0000000 0000000 00000000000 14742310675 0017026 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/__init__.py 0000664 0000000 0000000 00000000000 14742310675 0021125 0 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/cloudpickle/ 0000775 0000000 0000000 00000000000 14742310675 0021324 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/cloudpickle/__init__.py 0000664 0000000 0000000 00000000000 14742310675 0023423 0 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/cloudpickle/cloudpickle_file_test.py 0000664 0000000 0000000 00000006254 14742310675 0026241 0 ustar 00root root 0000000 0000000 import os
import shutil
import sys
import tempfile
import unittest
import pytest
import srsly.cloudpickle as cloudpickle
from srsly.cloudpickle.compat import pickle
class CloudPickleFileTests(unittest.TestCase):
"""In Cloudpickle, expected behaviour when pickling an opened file
is to send its contents over the wire and seek to the same position."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.tmpfilepath = os.path.join(self.tmpdir, 'testfile')
self.teststring = 'Hello world!'
def tearDown(self):
shutil.rmtree(self.tmpdir)
def test_empty_file(self):
# Empty file
open(self.tmpfilepath, 'w').close()
with open(self.tmpfilepath, 'r') as f:
self.assertEqual('', pickle.loads(cloudpickle.dumps(f)).read())
os.remove(self.tmpfilepath)
def test_closed_file(self):
# Write & close
with open(self.tmpfilepath, 'w') as f:
f.write(self.teststring)
with pytest.raises(pickle.PicklingError) as excinfo:
cloudpickle.dumps(f)
assert "Cannot pickle closed files" in str(excinfo.value)
os.remove(self.tmpfilepath)
def test_r_mode(self):
# Write & close
with open(self.tmpfilepath, 'w') as f:
f.write(self.teststring)
# Open for reading
with open(self.tmpfilepath, 'r') as f:
new_f = pickle.loads(cloudpickle.dumps(f))
self.assertEqual(self.teststring, new_f.read())
os.remove(self.tmpfilepath)
def test_w_mode(self):
with open(self.tmpfilepath, 'w') as f:
f.write(self.teststring)
f.seek(0)
self.assertRaises(pickle.PicklingError,
lambda: cloudpickle.dumps(f))
os.remove(self.tmpfilepath)
def test_plus_mode(self):
# Write, then seek to 0
with open(self.tmpfilepath, 'w+') as f:
f.write(self.teststring)
f.seek(0)
new_f = pickle.loads(cloudpickle.dumps(f))
self.assertEqual(self.teststring, new_f.read())
os.remove(self.tmpfilepath)
def test_seek(self):
# Write, then seek to arbitrary position
with open(self.tmpfilepath, 'w+') as f:
f.write(self.teststring)
f.seek(4)
unpickled = pickle.loads(cloudpickle.dumps(f))
# unpickled StringIO is at position 4
self.assertEqual(4, unpickled.tell())
self.assertEqual(self.teststring[4:], unpickled.read())
# but unpickled StringIO also contained the start
unpickled.seek(0)
self.assertEqual(self.teststring, unpickled.read())
os.remove(self.tmpfilepath)
@pytest.mark.skip(reason="Requires pytest -s to pass")
def test_pickling_special_file_handles(self):
# Warning: if you want to run your tests with nose, add -s option
for out in sys.stdout, sys.stderr: # Regression test for SPARK-3415
self.assertEqual(out, pickle.loads(cloudpickle.dumps(out)))
self.assertRaises(pickle.PicklingError,
lambda: cloudpickle.dumps(sys.stdin))
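# --- Hedged sketch (added illustration, not exercised by the test suite) ---
# The behaviour described in the class docstring in one standalone round trip:
# the open file's contents travel with the pickle and the clone is left at the
# same offset. The helper name and its tmpdir argument are hypothetical.
def _illustrative_open_file_roundtrip(tmpdir):
    path = os.path.join(tmpdir, "roundtrip.txt")
    with open(path, "w+") as f:
        f.write("Hello world!")
        f.seek(6)
        clone = pickle.loads(cloudpickle.dumps(f))
    assert clone.tell() == 6
    assert clone.read() == "world!"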
if __name__ == '__main__':
unittest.main()
srsly-release-v2.5.1/srsly/tests/cloudpickle/cloudpickle_test.py 0000664 0000000 0000000 00000317473 14742310675 0025252 0 ustar 00root root 0000000 0000000 import _collections_abc
import abc
import collections
import base64
import functools
import io
import itertools
import logging
import math
import multiprocessing
from operator import itemgetter, attrgetter
import pickletools
import platform
import random
import re
import shutil
import subprocess
import sys
import tempfile
import textwrap
import types
import unittest
import weakref
import os
import enum
import typing
from functools import wraps
import pytest
try:
# try importing numpy and scipy. These are not hard dependencies and
# tests should be skipped if these modules are not available
import numpy as np
import scipy.special as spp
except (ImportError, RuntimeError):
np = None
spp = None
try:
# Ditto for Tornado
import tornado
except ImportError:
tornado = None
import srsly.cloudpickle as cloudpickle
from srsly.cloudpickle.compat import pickle
from srsly.cloudpickle import register_pickle_by_value
from srsly.cloudpickle import unregister_pickle_by_value
from srsly.cloudpickle import list_registry_pickle_by_value
from srsly.cloudpickle.cloudpickle import _should_pickle_by_reference
from srsly.cloudpickle.cloudpickle import _make_empty_cell, cell_set
from srsly.cloudpickle.cloudpickle import _extract_class_dict, _whichmodule
from srsly.cloudpickle.cloudpickle import _lookup_module_and_qualname
from .testutils import subprocess_pickle_echo
from .testutils import subprocess_pickle_string
from .testutils import assert_run_python_script
from .testutils import subprocess_worker
_TEST_GLOBAL_VARIABLE = "default_value"
_TEST_GLOBAL_VARIABLE2 = "another_value"
class RaiserOnPickle:
def __init__(self, exc):
self.exc = exc
def __reduce__(self):
raise self.exc
def pickle_depickle(obj, protocol=cloudpickle.DEFAULT_PROTOCOL):
"""Helper function to test whether object pickled with cloudpickle can be
depickled with pickle
"""
return pickle.loads(cloudpickle.dumps(obj, protocol=protocol))
def _escape(raw_filepath):
# Ugly hack to embed filepaths in code templates for windows
return raw_filepath.replace("\\", r"\\\\")
def _maybe_remove(list_, item):
try:
list_.remove(item)
except ValueError:
pass
return list_
def test_extract_class_dict():
class A(int):
"""A docstring"""
def method(self):
return "a"
class B:
"""B docstring"""
B_CONSTANT = 42
def method(self):
return "b"
class C(A, B):
C_CONSTANT = 43
def method_c(self):
return "c"
clsdict = _extract_class_dict(C)
assert sorted(clsdict.keys()) == ["C_CONSTANT", "__doc__", "method_c"]
assert clsdict["C_CONSTANT"] == 43
assert clsdict["__doc__"] is None
assert clsdict["method_c"](C()) == C().method_c()
class CloudPickleTest(unittest.TestCase):
protocol = cloudpickle.DEFAULT_PROTOCOL
def setUp(self):
self.tmpdir = tempfile.mkdtemp(prefix="tmp_cloudpickle_test_")
def tearDown(self):
shutil.rmtree(self.tmpdir)
@pytest.mark.skipif(
platform.python_implementation() != "CPython" or
(sys.version_info >= (3, 8, 0) and sys.version_info < (3, 8, 2)),
reason="Underlying bug fixed upstream starting Python 3.8.2")
def test_reducer_override_reference_cycle(self):
# Early versions of Python 3.8 introduced a reference cycle between a
# Pickler and its reducer_override method. Because a Pickler
# object references every object it has pickled through its memo, this
# cycle prevented the garbage-collection of those external pickled
# objects. See #327 as well as https://bugs.python.org/issue39492
# This bug was fixed in Python 3.8.2, but is still present using
# cloudpickle and Python 3.8.0/1, hence the skipif directive.
class MyClass:
pass
my_object = MyClass()
wr = weakref.ref(my_object)
cloudpickle.dumps(my_object)
del my_object
assert wr() is None, "'del'-ed my_object has not been collected"
def test_itemgetter(self):
d = range(10)
getter = itemgetter(1)
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))
getter = itemgetter(0, 3)
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))
def test_attrgetter(self):
class C:
def __getattr__(self, item):
return item
d = C()
getter = attrgetter("a")
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))
getter = attrgetter("a", "b")
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))
d.e = C()
getter = attrgetter("e.a")
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))
getter = attrgetter("e.a", "e.b")
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))
# Regression test for SPARK-3415
@pytest.mark.skip(reason="Requires pytest -s to pass")
def test_pickling_file_handles(self):
out1 = sys.stderr
out2 = pickle.loads(cloudpickle.dumps(out1, protocol=self.protocol))
self.assertEqual(out1, out2)
def test_func_globals(self):
class Unpicklable:
def __reduce__(self):
raise Exception("not picklable")
global exit
exit = Unpicklable()
self.assertRaises(Exception, lambda: cloudpickle.dumps(
exit, protocol=self.protocol))
def foo():
sys.exit(0)
self.assertTrue("exit" in foo.__code__.co_names)
cloudpickle.dumps(foo)
def test_buffer(self):
try:
buffer_obj = buffer("Hello")
buffer_clone = pickle_depickle(buffer_obj, protocol=self.protocol)
self.assertEqual(buffer_clone, str(buffer_obj))
buffer_obj = buffer("Hello", 2, 3)
buffer_clone = pickle_depickle(buffer_obj, protocol=self.protocol)
self.assertEqual(buffer_clone, str(buffer_obj))
except NameError: # Python 3 does no longer support buffers
pass
def test_memoryview(self):
buffer_obj = memoryview(b"Hello")
self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),
buffer_obj.tobytes())
def test_dict_keys(self):
keys = {"a": 1, "b": 2}.keys()
results = pickle_depickle(keys)
self.assertEqual(results, keys)
assert isinstance(results, _collections_abc.dict_keys)
def test_dict_values(self):
values = {"a": 1, "b": 2}.values()
results = pickle_depickle(values)
self.assertEqual(sorted(results), sorted(values))
assert isinstance(results, _collections_abc.dict_values)
def test_dict_items(self):
items = {"a": 1, "b": 2}.items()
results = pickle_depickle(items)
self.assertEqual(results, items)
assert isinstance(results, _collections_abc.dict_items)
def test_odict_keys(self):
keys = collections.OrderedDict([("a", 1), ("b", 2)]).keys()
results = pickle_depickle(keys)
self.assertEqual(results, keys)
assert type(keys) == type(results)
def test_odict_values(self):
values = collections.OrderedDict([("a", 1), ("b", 2)]).values()
results = pickle_depickle(values)
self.assertEqual(list(results), list(values))
assert type(values) == type(results)
def test_odict_items(self):
items = collections.OrderedDict([("a", 1), ("b", 2)]).items()
results = pickle_depickle(items)
self.assertEqual(results, items)
assert type(items) == type(results)
def test_sliced_and_non_contiguous_memoryview(self):
buffer_obj = memoryview(b"Hello!" * 3)[2:15:2]
self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),
buffer_obj.tobytes())
def test_large_memoryview(self):
buffer_obj = memoryview(b"Hello!" * int(1e7))
self.assertEqual(pickle_depickle(buffer_obj, protocol=self.protocol),
buffer_obj.tobytes())
def test_lambda(self):
self.assertEqual(
pickle_depickle(lambda: 1, protocol=self.protocol)(), 1)
def test_nested_lambdas(self):
a, b = 1, 2
f1 = lambda x: x + a
f2 = lambda x: f1(x) // b
self.assertEqual(pickle_depickle(f2, protocol=self.protocol)(1), 1)
def test_recursive_closure(self):
def f1():
def g():
return g
return g
def f2(base):
def g(n):
return base if n <= 1 else n * g(n - 1)
return g
g1 = pickle_depickle(f1(), protocol=self.protocol)
self.assertEqual(g1(), g1)
g2 = pickle_depickle(f2(2), protocol=self.protocol)
self.assertEqual(g2(5), 240)
def test_closure_none_is_preserved(self):
def f():
"""a function with no closure cells
"""
self.assertTrue(
f.__closure__ is None,
msg='f actually has closure cells!',
)
g = pickle_depickle(f, protocol=self.protocol)
self.assertTrue(
g.__closure__ is None,
msg='g now has closure cells even though f does not',
)
def test_empty_cell_preserved(self):
def f():
if False: # pragma: no cover
cell = None
def g():
cell # NameError, unbound free variable
return g
g1 = f()
with pytest.raises(NameError):
g1()
g2 = pickle_depickle(g1, protocol=self.protocol)
with pytest.raises(NameError):
g2()
def test_unhashable_closure(self):
def f():
s = {1, 2} # mutable set is unhashable
def g():
return len(s)
return g
g = pickle_depickle(f(), protocol=self.protocol)
self.assertEqual(g(), 2)
def test_dynamically_generated_class_that_uses_super(self):
class Base:
def method(self):
return 1
class Derived(Base):
"Derived Docstring"
def method(self):
return super().method() + 1
self.assertEqual(Derived().method(), 2)
# Pickle and unpickle the class.
UnpickledDerived = pickle_depickle(Derived, protocol=self.protocol)
self.assertEqual(UnpickledDerived().method(), 2)
# We have special logic for handling __doc__ because it's a readonly
# attribute on PyPy.
self.assertEqual(UnpickledDerived.__doc__, "Derived Docstring")
# Pickle and unpickle an instance.
orig_d = Derived()
d = pickle_depickle(orig_d, protocol=self.protocol)
self.assertEqual(d.method(), 2)
def test_cycle_in_classdict_globals(self):
class C:
def it_works(self):
return "woohoo!"
C.C_again = C
C.instance_of_C = C()
depickled_C = pickle_depickle(C, protocol=self.protocol)
depickled_instance = pickle_depickle(C())
# Test instance of depickled class.
self.assertEqual(depickled_C().it_works(), "woohoo!")
self.assertEqual(depickled_C.C_again().it_works(), "woohoo!")
self.assertEqual(depickled_C.instance_of_C.it_works(), "woohoo!")
self.assertEqual(depickled_instance.it_works(), "woohoo!")
def test_locally_defined_function_and_class(self):
LOCAL_CONSTANT = 42
def some_function(x, y):
# Make sure the __builtins__ are not broken (see #211)
sum(range(10))
return (x + y) / LOCAL_CONSTANT
# pickle the function definition
self.assertEqual(pickle_depickle(some_function, protocol=self.protocol)(41, 1), 1)
self.assertEqual(pickle_depickle(some_function, protocol=self.protocol)(81, 3), 2)
hidden_constant = lambda: LOCAL_CONSTANT
class SomeClass:
"""Overly complicated class with nested references to symbols"""
def __init__(self, value):
self.value = value
def one(self):
return LOCAL_CONSTANT / hidden_constant()
def some_method(self, x):
return self.one() + some_function(x, 1) + self.value
# pickle the class definition
clone_class = pickle_depickle(SomeClass, protocol=self.protocol)
self.assertEqual(clone_class(1).one(), 1)
self.assertEqual(clone_class(5).some_method(41), 7)
clone_class = subprocess_pickle_echo(SomeClass, protocol=self.protocol)
self.assertEqual(clone_class(5).some_method(41), 7)
# pickle the class instances
self.assertEqual(pickle_depickle(SomeClass(1)).one(), 1)
self.assertEqual(pickle_depickle(SomeClass(5)).some_method(41), 7)
new_instance = subprocess_pickle_echo(SomeClass(5),
protocol=self.protocol)
self.assertEqual(new_instance.some_method(41), 7)
# pickle the method instances
self.assertEqual(pickle_depickle(SomeClass(1).one)(), 1)
self.assertEqual(pickle_depickle(SomeClass(5).some_method)(41), 7)
new_method = subprocess_pickle_echo(SomeClass(5).some_method,
protocol=self.protocol)
self.assertEqual(new_method(41), 7)
def test_partial(self):
partial_obj = functools.partial(min, 1)
partial_clone = pickle_depickle(partial_obj, protocol=self.protocol)
self.assertEqual(partial_clone(4), 1)
@pytest.mark.skipif(platform.python_implementation() == 'PyPy',
reason="Skip numpy and scipy tests on PyPy")
def test_ufunc(self):
# test a numpy ufunc (universal function), which is a C-based function
# that is applied on a numpy array
if np:
# simple ufunc: np.add
self.assertEqual(pickle_depickle(np.add, protocol=self.protocol),
np.add)
else: # skip if numpy is not available
pass
if spp:
# custom ufunc: scipy.special.iv
self.assertEqual(pickle_depickle(spp.iv, protocol=self.protocol),
spp.iv)
else: # skip if scipy is not available
pass
def test_loads_namespace(self):
obj = 1, 2, 3, 4
returned_obj = cloudpickle.loads(cloudpickle.dumps(
obj, protocol=self.protocol))
self.assertEqual(obj, returned_obj)
def test_load_namespace(self):
obj = 1, 2, 3, 4
bio = io.BytesIO()
cloudpickle.dump(obj, bio)
bio.seek(0)
returned_obj = cloudpickle.load(bio)
self.assertEqual(obj, returned_obj)
def test_generator(self):
def some_generator(cnt):
for i in range(cnt):
yield i
gen2 = pickle_depickle(some_generator, protocol=self.protocol)
assert type(gen2(3)) == type(some_generator(3))
assert list(gen2(3)) == list(range(3))
def test_classmethod(self):
class A:
@staticmethod
def test_sm():
return "sm"
@classmethod
def test_cm(cls):
return "cm"
sm = A.__dict__["test_sm"]
cm = A.__dict__["test_cm"]
A.test_sm = pickle_depickle(sm, protocol=self.protocol)
A.test_cm = pickle_depickle(cm, protocol=self.protocol)
self.assertEqual(A.test_sm(), "sm")
self.assertEqual(A.test_cm(), "cm")
def test_bound_classmethod(self):
class A:
@classmethod
def test_cm(cls):
return "cm"
A.test_cm = pickle_depickle(A.test_cm, protocol=self.protocol)
self.assertEqual(A.test_cm(), "cm")
def test_method_descriptors(self):
f = pickle_depickle(str.upper)
self.assertEqual(f('abc'), 'ABC')
def test_instancemethods_without_self(self):
class F:
def f(self, x):
return x + 1
g = pickle_depickle(F.f, protocol=self.protocol)
self.assertEqual(g.__name__, F.f.__name__)
# self.assertEqual(g(F(), 1), 2) # still fails
def test_module(self):
pickle_clone = pickle_depickle(pickle, protocol=self.protocol)
self.assertEqual(pickle, pickle_clone)
def test_dynamic_module(self):
mod = types.ModuleType('mod')
code = '''
x = 1
def f(y):
return x + y
class Foo:
def method(self, x):
return f(x)
'''
exec(textwrap.dedent(code), mod.__dict__)
mod2 = pickle_depickle(mod, protocol=self.protocol)
self.assertEqual(mod.x, mod2.x)
self.assertEqual(mod.f(5), mod2.f(5))
self.assertEqual(mod.Foo().method(5), mod2.Foo().method(5))
if platform.python_implementation() != 'PyPy':
# XXX: this fails with excessive recursion on PyPy.
mod3 = subprocess_pickle_echo(mod, protocol=self.protocol)
self.assertEqual(mod.x, mod3.x)
self.assertEqual(mod.f(5), mod3.f(5))
self.assertEqual(mod.Foo().method(5), mod3.Foo().method(5))
# Test dynamic modules when imported back are singletons
mod1, mod2 = pickle_depickle([mod, mod])
self.assertEqual(id(mod1), id(mod2))
# Ensure proper pickling of mod's functions when module "looks" like a
# file-backed module even though it is not:
try:
sys.modules['mod'] = mod
depickled_f = pickle_depickle(mod.f, protocol=self.protocol)
self.assertEqual(mod.f(5), depickled_f(5))
finally:
sys.modules.pop('mod', None)
def test_module_locals_behavior(self):
# Makes sure that a local function defined in another module is
# correctly serialized. This notably checks that the globals are
# accessible and that there is no issue with the builtins (see #211)
pickled_func_path = os.path.join(self.tmpdir, 'local_func_g.pkl')
child_process_script = '''
from srsly.cloudpickle.compat import pickle
import gc
with open("{pickled_func_path}", 'rb') as f:
func = pickle.load(f)
assert func(range(10)) == 45
'''
child_process_script = child_process_script.format(
pickled_func_path=_escape(pickled_func_path))
try:
from srsly.tests.cloudpickle.testutils import make_local_function
g = make_local_function()
with open(pickled_func_path, 'wb') as f:
cloudpickle.dump(g, f, protocol=self.protocol)
assert_run_python_script(textwrap.dedent(child_process_script))
finally:
os.unlink(pickled_func_path)
def test_dynamic_module_with_unpicklable_builtin(self):
# Reproducer of https://github.com/cloudpipe/cloudpickle/issues/316
# Some modules such as scipy inject some unpicklable objects into the
# __builtins__ module, which appears in every module's __dict__ under
# the '__builtins__' key. In such cases, cloudpickle used to fail
# when pickling dynamic modules.
class UnpickleableObject:
def __reduce__(self):
raise ValueError('Unpicklable object')
mod = types.ModuleType("mod")
exec('f = lambda x: abs(x)', mod.__dict__)
assert mod.f(-1) == 1
assert '__builtins__' in mod.__dict__
unpicklable_obj = UnpickleableObject()
with pytest.raises(ValueError):
cloudpickle.dumps(unpicklable_obj)
# Emulate the behavior of scipy by injecting an unpickleable object
# into mod's builtins.
# The __builtins__ entry of mod's __dict__ can either be the
# __builtins__ module, or the __builtins__ module's __dict__. #316
# happens only in the latter case.
if isinstance(mod.__dict__['__builtins__'], dict):
mod.__dict__['__builtins__']['unpickleable_obj'] = unpicklable_obj
elif isinstance(mod.__dict__['__builtins__'], types.ModuleType):
mod.__dict__['__builtins__'].unpickleable_obj = unpicklable_obj
depickled_mod = pickle_depickle(mod, protocol=self.protocol)
assert '__builtins__' in depickled_mod.__dict__
if isinstance(depickled_mod.__dict__['__builtins__'], dict):
assert "abs" in depickled_mod.__builtins__
elif isinstance(
depickled_mod.__dict__['__builtins__'], types.ModuleType):
assert hasattr(depickled_mod.__builtins__, "abs")
assert depickled_mod.f(-1) == 1
# Additional check testing that the issue #425 is fixed: without the
# fix for #425, `mod.f` would not have access to `__builtins__`, and
# thus calling `mod.f(-1)` (which relies on the `abs` builtin) would
# fail.
assert mod.f(-1) == 1
def test_load_dynamic_module_in_grandchild_process(self):
# Make sure that when loaded, a dynamic module preserves its dynamic
# property. Otherwise, this will lead to an ImportError if pickled in
# the child process and reloaded in another one.
# We create a new dynamic module
mod = types.ModuleType('mod')
code = '''
x = 1
'''
exec(textwrap.dedent(code), mod.__dict__)
# This script will be run in a separate child process. It will import
# the pickled dynamic module, and then re-pickle it under a new name.
# Finally, it will create a child process that will load the re-pickled
# dynamic module.
parent_process_module_file = os.path.join(
self.tmpdir, 'dynamic_module_from_parent_process.pkl')
child_process_module_file = os.path.join(
self.tmpdir, 'dynamic_module_from_child_process.pkl')
child_process_script = '''
from srsly.cloudpickle.compat import pickle
import textwrap
import srsly.cloudpickle as cloudpickle
from srsly.tests.cloudpickle.testutils import assert_run_python_script
child_of_child_process_script = {child_of_child_process_script}
with open('{parent_process_module_file}', 'rb') as f:
mod = pickle.load(f)
with open('{child_process_module_file}', 'wb') as f:
cloudpickle.dump(mod, f, protocol={protocol})
assert_run_python_script(textwrap.dedent(child_of_child_process_script))
'''
# The script run by the process created by the child process
child_of_child_process_script = """ '''
from srsly.cloudpickle.compat import pickle
with open('{child_process_module_file}','rb') as fid:
mod = pickle.load(fid)
''' """
# Filling the two scripts with the pickled modules filepaths and,
# for the first child process, the script to be executed by its
# own child process.
child_of_child_process_script = child_of_child_process_script.format(
child_process_module_file=child_process_module_file)
child_process_script = child_process_script.format(
parent_process_module_file=_escape(parent_process_module_file),
child_process_module_file=_escape(child_process_module_file),
child_of_child_process_script=_escape(child_of_child_process_script),
protocol=self.protocol)
try:
with open(parent_process_module_file, 'wb') as fid:
cloudpickle.dump(mod, fid, protocol=self.protocol)
assert_run_python_script(textwrap.dedent(child_process_script))
finally:
# Remove temporary created files
if os.path.exists(parent_process_module_file):
os.unlink(parent_process_module_file)
if os.path.exists(child_process_module_file):
os.unlink(child_process_module_file)
def test_correct_globals_import(self):
def nested_function(x):
return x + 1
def unwanted_function(x):
return math.exp(x)
def my_small_function(x, y):
return nested_function(x) + y
b = cloudpickle.dumps(my_small_function, protocol=self.protocol)
# Make sure that the pickle byte string only includes the definition
# of my_small_function and its dependency nested_function while
# extra functions and modules such as unwanted_function and the math
# module are not included so as to keep the pickle payload as
# lightweight as possible.
assert b'my_small_function' in b
assert b'nested_function' in b
assert b'unwanted_function' not in b
assert b'math' not in b
def test_module_importability(self):
pytest.importorskip("_cloudpickle_testpkg")
from srsly.cloudpickle.compat import pickle
import os.path
import collections
import collections.abc
assert _should_pickle_by_reference(pickle)
assert _should_pickle_by_reference(os.path) # fake (aliased) module
assert _should_pickle_by_reference(collections) # package
assert _should_pickle_by_reference(collections.abc) # module in package
dynamic_module = types.ModuleType('dynamic_module')
assert not _should_pickle_by_reference(dynamic_module)
if platform.python_implementation() == 'PyPy':
import _codecs
assert _should_pickle_by_reference(_codecs)
# #354: Check that modules created dynamically during the import of
# their parent modules are considered importable by cloudpickle.
# See the mod_with_dynamic_submodule documentation for more
# details of this use case.
import _cloudpickle_testpkg.mod.dynamic_submodule as m
assert _should_pickle_by_reference(m)
assert pickle_depickle(m, protocol=self.protocol) is m
# Check for similar behavior for a module that cannot be imported by
# attribute lookup.
from _cloudpickle_testpkg.mod import dynamic_submodule_two as m2
# Note: import _cloudpickle_testpkg.mod.dynamic_submodule_two as m2
# works only for Python 3.7+
assert _should_pickle_by_reference(m2)
assert pickle_depickle(m2, protocol=self.protocol) is m2
# Submodule_three is a dynamic module only importable via module lookup
with pytest.raises(ImportError):
import _cloudpickle_testpkg.mod.submodule_three # noqa
from _cloudpickle_testpkg.mod import submodule_three as m3
assert not _should_pickle_by_reference(m3)
# This module cannot be pickled using attribute lookup (as it does not
# have a `__module__` attribute like classes and functions).
assert not hasattr(m3, '__module__')
depickled_m3 = pickle_depickle(m3, protocol=self.protocol)
assert depickled_m3 is not m3
assert m3.f(1) == depickled_m3.f(1)
# Do the same for an importable dynamic submodule inside a dynamic
# module inside a file-backed module.
import _cloudpickle_testpkg.mod.dynamic_submodule.dynamic_subsubmodule as sm # noqa
assert _should_pickle_by_reference(sm)
assert pickle_depickle(sm, protocol=self.protocol) is sm
expected = "cannot check importability of object instances"
with pytest.raises(TypeError, match=expected):
_should_pickle_by_reference(object())
def test_Ellipsis(self):
self.assertEqual(Ellipsis,
pickle_depickle(Ellipsis, protocol=self.protocol))
def test_NotImplemented(self):
ExcClone = pickle_depickle(NotImplemented, protocol=self.protocol)
self.assertEqual(NotImplemented, ExcClone)
def test_NoneType(self):
res = pickle_depickle(type(None), protocol=self.protocol)
self.assertEqual(type(None), res)
def test_EllipsisType(self):
res = pickle_depickle(type(Ellipsis), protocol=self.protocol)
self.assertEqual(type(Ellipsis), res)
def test_NotImplementedType(self):
res = pickle_depickle(type(NotImplemented), protocol=self.protocol)
self.assertEqual(type(NotImplemented), res)
def test_builtin_function(self):
# Note that builtin_function_or_method objects are special-cased by
# cloudpickle only in python2.
# builtin function from the __builtin__ module
assert pickle_depickle(zip, protocol=self.protocol) is zip
from os import mkdir
# builtin function from a "regular" module
assert pickle_depickle(mkdir, protocol=self.protocol) is mkdir
def test_builtin_type_constructor(self):
# This test makes sure that cloudpickling builtin-type
# constructors works for all python versions/implementation.
# pickle_depickle some builtin methods of the __builtin__ module
for t in list, tuple, set, frozenset, dict, object:
cloned_new = pickle_depickle(t.__new__, protocol=self.protocol)
assert isinstance(cloned_new(t), t)
# The next 4 tests cover all cases in which builtin python methods can
# appear.
# There are 4 kinds of method: 'classic' methods, classmethods,
# staticmethods and slotmethods. They will appear under different types
# depending on whether they are called from the __dict__ of their
# class, their class itself, or an instance of their class. This makes
# 12 total combinations.
# This discussion and the following tests are relevant for the CPython
# implementation only. In PyPy, there are no builtin method or builtin
# function types/flavours. The only way in which a builtin method can be
# identified is via its builtin-code __code__ attribute.
def test_builtin_classicmethod(self):
obj = 1.5 # float object
bound_classicmethod = obj.hex # builtin_function_or_method
unbound_classicmethod = type(obj).hex # method_descriptor
clsdict_classicmethod = type(obj).__dict__['hex'] # method_descriptor
assert unbound_classicmethod is clsdict_classicmethod
depickled_bound_meth = pickle_depickle(
bound_classicmethod, protocol=self.protocol)
depickled_unbound_meth = pickle_depickle(
unbound_classicmethod, protocol=self.protocol)
depickled_clsdict_meth = pickle_depickle(
clsdict_classicmethod, protocol=self.protocol)
# No identity check on the bound methods: they are bound to different
# float instances
assert depickled_bound_meth() == bound_classicmethod()
assert depickled_unbound_meth is unbound_classicmethod
assert depickled_clsdict_meth is clsdict_classicmethod
@pytest.mark.skipif(
(platform.machine() == "aarch64" and sys.version_info[:2] >= (3, 10))
or platform.python_implementation() == "PyPy"
or (sys.version_info[:2] == (3, 10) and sys.version_info >= (3, 10, 8))
# Skipping tests on 3.11 due to https://github.com/cloudpipe/cloudpickle/pull/486.
or sys.version_info[:2] >= (3, 11),
reason="Fails on aarch64 + python 3.10+ in cibuildwheel, currently unable to replicate failure elsewhere; fails sometimes for pypy on conda-forge; fails for python 3.10.8+ and 3.11+")
def test_builtin_classmethod(self):
obj = 1.5 # float object
bound_clsmethod = obj.fromhex # builtin_function_or_method
unbound_clsmethod = type(obj).fromhex # builtin_function_or_method
clsdict_clsmethod = type(
obj).__dict__['fromhex'] # classmethod_descriptor
depickled_bound_meth = pickle_depickle(
bound_clsmethod, protocol=self.protocol)
depickled_unbound_meth = pickle_depickle(
unbound_clsmethod, protocol=self.protocol)
depickled_clsdict_meth = pickle_depickle(
clsdict_clsmethod, protocol=self.protocol)
# float.fromhex takes a string as input.
arg = "0x1"
# Identity on both the bound and the unbound methods cannot be
# tested: the bound methods are bound to different objects, and the
# unbound methods are actually recreated at each call.
assert depickled_bound_meth(arg) == bound_clsmethod(arg)
assert depickled_unbound_meth(arg) == unbound_clsmethod(arg)
if platform.python_implementation() == 'CPython':
# Roundtripping a classmethod_descriptor results in a
# builtin_function_or_method (CPython upstream issue).
assert depickled_clsdict_meth(arg) == clsdict_clsmethod(float, arg)
if platform.python_implementation() == 'PyPy':
# builtin-classmethods are simple classmethod in PyPy (not
# callable). We test equality of types and the functionality of the
# __func__ attribute instead. We do not test the identity of
# the functions as __func__ attributes of classmethods are not
# pickleable and must be reconstructed at depickling time.
assert type(depickled_clsdict_meth) == type(clsdict_clsmethod)
assert depickled_clsdict_meth.__func__(
float, arg) == clsdict_clsmethod.__func__(float, arg)
def test_builtin_slotmethod(self):
obj = 1.5 # float object
bound_slotmethod = obj.__repr__ # method-wrapper
unbound_slotmethod = type(obj).__repr__ # wrapper_descriptor
clsdict_slotmethod = type(obj).__dict__['__repr__'] # ditto
depickled_bound_meth = pickle_depickle(
bound_slotmethod, protocol=self.protocol)
depickled_unbound_meth = pickle_depickle(
unbound_slotmethod, protocol=self.protocol)
depickled_clsdict_meth = pickle_depickle(
clsdict_slotmethod, protocol=self.protocol)
# No identity tests on the bound slotmethod as they are bound to
# different float instances
assert depickled_bound_meth() == bound_slotmethod()
assert depickled_unbound_meth is unbound_slotmethod
assert depickled_clsdict_meth is clsdict_slotmethod
@pytest.mark.skipif(
platform.python_implementation() == "PyPy",
reason="No known staticmethod example in the pypy stdlib")
def test_builtin_staticmethod(self):
obj = "foo" # str object
bound_staticmethod = obj.maketrans # builtin_function_or_method
unbound_staticmethod = type(obj).maketrans # ditto
clsdict_staticmethod = type(obj).__dict__['maketrans'] # staticmethod
assert bound_staticmethod is unbound_staticmethod
depickled_bound_meth = pickle_depickle(
bound_staticmethod, protocol=self.protocol)
depickled_unbound_meth = pickle_depickle(
unbound_staticmethod, protocol=self.protocol)
depickled_clsdict_meth = pickle_depickle(
clsdict_staticmethod, protocol=self.protocol)
assert depickled_bound_meth is bound_staticmethod
assert depickled_unbound_meth is unbound_staticmethod
# staticmethod objects are recreated at depickling time, but the
# underlying __func__ object is pickled by attribute.
assert depickled_clsdict_meth.__func__ is clsdict_staticmethod.__func__
assert type(depickled_clsdict_meth) is type(clsdict_staticmethod)
@pytest.mark.skipif(tornado is None,
reason="test needs Tornado installed")
def test_tornado_coroutine(self):
# Pickling a locally defined coroutine function
from tornado import gen, ioloop
@gen.coroutine
def f(x, y):
yield gen.sleep(x)
raise gen.Return(y + 1)
@gen.coroutine
def g(y):
res = yield f(0.01, y)
raise gen.Return(res + 1)
data = cloudpickle.dumps([g, g], protocol=self.protocol)
f = g = None
g2, g3 = pickle.loads(data)
self.assertTrue(g2 is g3)
loop = ioloop.IOLoop.current()
res = loop.run_sync(functools.partial(g2, 5))
self.assertEqual(res, 7)
@pytest.mark.skipif(
(3, 11, 0, 'beta') <= sys.version_info < (3, 11, 0, 'beta', 4),
reason="https://github.com/python/cpython/issues/92932"
)
def test_extended_arg(self):
# Functions with more than 65535 global vars prefix some global
# variable references with the EXTENDED_ARG opcode.
nvars = 65537 + 258
names = ['g%d' % i for i in range(1, nvars)]
r = random.Random(42)
d = {name: r.randrange(100) for name in names}
# def f(x):
# x = g1, g2, ...
# return zlib.crc32(bytes(bytearray(x)))
code = """
import zlib
def f():
x = {tup}
return zlib.crc32(bytes(bytearray(x)))
""".format(tup=', '.join(names))
exec(textwrap.dedent(code), d, d)
f = d['f']
res = f()
data = cloudpickle.dumps([f, f], protocol=self.protocol)
d = f = None
f2, f3 = pickle.loads(data)
self.assertTrue(f2 is f3)
self.assertEqual(f2(), res)
def test_submodule(self):
# Function that refers (by attribute) to a sub-module of a package.
# Choose any module NOT imported by __init__ of its parent package
# examples in standard library include:
# - http.cookies, unittest.mock, curses.textpad, xml.etree.ElementTree
global xml # imitate performing this import at top of file
import xml.etree.ElementTree
def example():
x = xml.etree.ElementTree.Comment # potential AttributeError
s = cloudpickle.dumps(example, protocol=self.protocol)
# refresh the environment, i.e., unimport the dependency
del xml
for item in list(sys.modules):
if item.split('.')[0] == 'xml':
del sys.modules[item]
# deserialise
f = pickle.loads(s)
f() # perform test for error
def test_submodule_closure(self):
# Same as test_submodule except the package is not a global
def scope():
import xml.etree.ElementTree
def example():
x = xml.etree.ElementTree.Comment # potential AttributeError
return example
example = scope()
s = cloudpickle.dumps(example, protocol=self.protocol)
# refresh the environment (unimport dependency)
for item in list(sys.modules):
if item.split('.')[0] == 'xml':
del sys.modules[item]
f = cloudpickle.loads(s)
f() # test
def test_multiprocess(self):
# running a function pickled by another process (a la dask.distributed)
def scope():
def example():
x = xml.etree.ElementTree.Comment
return example
global xml
import xml.etree.ElementTree
example = scope()
s = cloudpickle.dumps(example, protocol=self.protocol)
# choose "subprocess" rather than "multiprocessing" because the latter
# library uses fork to preserve the parent environment.
command = ("import base64; "
"from srsly.cloudpickle.compat import pickle; "
"pickle.loads(base64.b32decode('" +
base64.b32encode(s).decode('ascii') +
"'))()")
assert not subprocess.call([sys.executable, '-c', command])
def test_import(self):
# like test_multiprocess except subpackage modules referenced directly
# (unlike test_submodule)
global etree
def scope():
import xml.etree as foobar
def example():
x = etree.Comment
x = foobar.ElementTree
return example
example = scope()
import xml.etree.ElementTree as etree
s = cloudpickle.dumps(example, protocol=self.protocol)
command = ("import base64; "
"from srsly.cloudpickle.compat import pickle; "
"pickle.loads(base64.b32decode('" +
base64.b32encode(s).decode('ascii') +
"'))()")
assert not subprocess.call([sys.executable, '-c', command])
def test_multiprocessing_lock_raises(self):
lock = multiprocessing.Lock()
with pytest.raises(RuntimeError, match="only be shared between processes through inheritance"):
cloudpickle.dumps(lock)
def test_cell_manipulation(self):
cell = _make_empty_cell()
with pytest.raises(ValueError):
cell.cell_contents
ob = object()
cell_set(cell, ob)
self.assertTrue(
cell.cell_contents is ob,
msg='cell contents not set correctly',
)
def check_logger(self, name):
logger = logging.getLogger(name)
pickled = pickle_depickle(logger, protocol=self.protocol)
self.assertTrue(pickled is logger, (pickled, logger))
dumped = cloudpickle.dumps(logger)
code = """if 1:
import base64, srsly.cloudpickle as cloudpickle, logging
logging.basicConfig(level=logging.INFO)
logger = cloudpickle.loads(base64.b32decode(b'{}'))
logger.info('hello')
""".format(base64.b32encode(dumped).decode('ascii'))
proc = subprocess.Popen([sys.executable, "-W ignore", "-c", code],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
out, _ = proc.communicate()
self.assertEqual(proc.wait(), 0)
self.assertEqual(out.strip().decode(),
f'INFO:{logger.name}:hello')
def test_logger(self):
# logging.RootLogger object
self.check_logger(None)
# logging.Logger object
self.check_logger('cloudpickle.dummy_test_logger')
def test_getset_descriptor(self):
assert isinstance(float.real, types.GetSetDescriptorType)
depickled_descriptor = pickle_depickle(float.real)
self.assertIs(depickled_descriptor, float.real)
def test_abc_cache_not_pickled(self):
# cloudpickle issue #302: make sure that cloudpickle does not pickle
# the caches populated during instance/subclass checks of abc.ABCMeta
# instances.
MyClass = abc.ABCMeta('MyClass', (), {})
class MyUnrelatedClass:
pass
class MyRelatedClass:
pass
MyClass.register(MyRelatedClass)
assert not issubclass(MyUnrelatedClass, MyClass)
assert issubclass(MyRelatedClass, MyClass)
s = cloudpickle.dumps(MyClass)
assert b"MyUnrelatedClass" not in s
assert b"MyRelatedClass" in s
depickled_class = cloudpickle.loads(s)
assert not issubclass(MyUnrelatedClass, depickled_class)
assert issubclass(MyRelatedClass, depickled_class)
def test_abc(self):
class AbstractClass(abc.ABC):
@abc.abstractmethod
def some_method(self):
"""A method"""
@classmethod
@abc.abstractmethod
def some_classmethod(cls):
"""A classmethod"""
@staticmethod
@abc.abstractmethod
def some_staticmethod():
"""A staticmethod"""
@property
@abc.abstractmethod
def some_property():
"""A property"""
class ConcreteClass(AbstractClass):
def some_method(self):
return 'it works!'
@classmethod
def some_classmethod(cls):
assert cls == ConcreteClass
return 'it works!'
@staticmethod
def some_staticmethod():
return 'it works!'
@property
def some_property(self):
return 'it works!'
# This abstract class is locally defined so we can safely register
# tuple in it to verify the unpickled class also register tuple.
AbstractClass.register(tuple)
concrete_instance = ConcreteClass()
depickled_base = pickle_depickle(AbstractClass, protocol=self.protocol)
depickled_class = pickle_depickle(ConcreteClass,
protocol=self.protocol)
depickled_instance = pickle_depickle(concrete_instance)
assert issubclass(tuple, AbstractClass)
assert issubclass(tuple, depickled_base)
self.assertEqual(depickled_class().some_method(), 'it works!')
self.assertEqual(depickled_instance.some_method(), 'it works!')
self.assertEqual(depickled_class.some_classmethod(), 'it works!')
self.assertEqual(depickled_instance.some_classmethod(), 'it works!')
self.assertEqual(depickled_class().some_staticmethod(), 'it works!')
self.assertEqual(depickled_instance.some_staticmethod(), 'it works!')
self.assertEqual(depickled_class().some_property, 'it works!')
self.assertEqual(depickled_instance.some_property, 'it works!')
self.assertRaises(TypeError, depickled_base)
class DepickledBaseSubclass(depickled_base):
def some_method(self):
return 'it works for realz!'
@classmethod
def some_classmethod(cls):
assert cls == DepickledBaseSubclass
return 'it works for realz!'
@staticmethod
def some_staticmethod():
return 'it works for realz!'
@property
def some_property():
return 'it works for realz!'
self.assertEqual(DepickledBaseSubclass().some_method(),
'it works for realz!')
class IncompleteBaseSubclass(depickled_base):
def some_method(self):
return 'this class lacks some concrete methods'
self.assertRaises(TypeError, IncompleteBaseSubclass)
def test_abstracts(self):
# Same as `test_abc` but using deprecated `abc.abstract*` methods.
# See https://github.com/cloudpipe/cloudpickle/issues/367
class AbstractClass(abc.ABC):
@abc.abstractmethod
def some_method(self):
"""A method"""
@abc.abstractclassmethod
def some_classmethod(cls):
"""A classmethod"""
@abc.abstractstaticmethod
def some_staticmethod():
"""A staticmethod"""
@abc.abstractproperty
def some_property(self):
"""A property"""
class ConcreteClass(AbstractClass):
def some_method(self):
return 'it works!'
@classmethod
def some_classmethod(cls):
assert cls == ConcreteClass
return 'it works!'
@staticmethod
def some_staticmethod():
return 'it works!'
@property
def some_property(self):
return 'it works!'
# This abstract class is locally defined so we can safely register
# tuple in it to verify the unpickled class also register tuple.
AbstractClass.register(tuple)
concrete_instance = ConcreteClass()
depickled_base = pickle_depickle(AbstractClass, protocol=self.protocol)
depickled_class = pickle_depickle(ConcreteClass,
protocol=self.protocol)
depickled_instance = pickle_depickle(concrete_instance)
assert issubclass(tuple, AbstractClass)
assert issubclass(tuple, depickled_base)
self.assertEqual(depickled_class().some_method(), 'it works!')
self.assertEqual(depickled_instance.some_method(), 'it works!')
self.assertEqual(depickled_class.some_classmethod(), 'it works!')
self.assertEqual(depickled_instance.some_classmethod(), 'it works!')
self.assertEqual(depickled_class().some_staticmethod(), 'it works!')
self.assertEqual(depickled_instance.some_staticmethod(), 'it works!')
self.assertEqual(depickled_class().some_property, 'it works!')
self.assertEqual(depickled_instance.some_property, 'it works!')
self.assertRaises(TypeError, depickled_base)
class DepickledBaseSubclass(depickled_base):
def some_method(self):
return 'it works for realz!'
@classmethod
def some_classmethod(cls):
assert cls == DepickledBaseSubclass
return 'it works for realz!'
@staticmethod
def some_staticmethod():
return 'it works for realz!'
@property
def some_property(self):
return 'it works for realz!'
self.assertEqual(DepickledBaseSubclass().some_method(),
'it works for realz!')
class IncompleteBaseSubclass(depickled_base):
def some_method(self):
return 'this class lacks some concrete methods'
self.assertRaises(TypeError, IncompleteBaseSubclass)
def test_weakset_identity_preservation(self):
# Test that weaksets don't lose all their inhabitants if they're
# pickled in a larger data structure that includes other references to
# their inhabitants.
class SomeClass:
def __init__(self, x):
self.x = x
obj1, obj2, obj3 = SomeClass(1), SomeClass(2), SomeClass(3)
things = [weakref.WeakSet([obj1, obj2]), obj1, obj2, obj3]
result = pickle_depickle(things, protocol=self.protocol)
weakset, depickled1, depickled2, depickled3 = result
self.assertEqual(depickled1.x, 1)
self.assertEqual(depickled2.x, 2)
self.assertEqual(depickled3.x, 3)
self.assertEqual(len(weakset), 2)
self.assertEqual(set(weakset), {depickled1, depickled2})
def test_non_module_object_passing_whichmodule_test(self):
# https://github.com/cloudpipe/cloudpickle/pull/326: cloudpickle should
# not try to introspect non-module objects when trying to discover the
# module of a function/class. This happened because codecov injects
# tuples (and not modules) into sys.modules, but type-checks were not
# carried out on the entries of sys.modules, causing cloudpickle to
# then error out in unexpected ways
def func(x):
return x ** 2
# Trigger a loop during the execution of whichmodule(func) by
# explicitly setting the function's module to None
func.__module__ = None
class NonModuleObject:
def __ini__(self):
self.some_attr = None
def __getattr__(self, name):
# We whitelist func so that a _whichmodule(func, None) call
# returns the NonModuleObject instance if a type check on the
# entries of sys.modules is not carried out, but manipulating
# this instance thinking it really is a module later on in the
# pickling process of func errors out
if name == 'func':
return func
else:
raise AttributeError
non_module_object = NonModuleObject()
assert func(2) == 4
assert func is non_module_object.func
# Any manipulation of non_module_object relying on attribute access
# will raise an Exception
with pytest.raises(AttributeError):
_ = non_module_object.some_attr
try:
sys.modules['NonModuleObject'] = non_module_object
func_module_name = _whichmodule(func, None)
assert func_module_name != 'NonModuleObject'
assert func_module_name is None
depickled_func = pickle_depickle(func, protocol=self.protocol)
assert depickled_func(2) == 4
finally:
sys.modules.pop('NonModuleObject')
def test_unrelated_faulty_module(self):
# Check that pickling a dynamically defined function or class does not
# fail when introspecting the currently loaded modules in sys.modules
# as long as those faulty modules are unrelated to the class or
# function we are currently pickling.
for base_class in (object, types.ModuleType):
for module_name in ['_missing_module', None]:
class FaultyModule(base_class):
def __getattr__(self, name):
# This throws an exception while looking up within
# pickle.whichmodule or getattr(module, name, None)
raise Exception()
class Foo:
__module__ = module_name
def foo(self):
return "it works!"
def foo():
return "it works!"
foo.__module__ = module_name
if base_class is types.ModuleType: # noqa
faulty_module = FaultyModule('_faulty_module')
else:
faulty_module = FaultyModule()
sys.modules["_faulty_module"] = faulty_module
try:
# Test whichmodule in save_global.
self.assertEqual(pickle_depickle(Foo()).foo(), "it works!")
# Test whichmodule in save_function.
cloned = pickle_depickle(foo, protocol=self.protocol)
self.assertEqual(cloned(), "it works!")
finally:
sys.modules.pop("_faulty_module", None)
@pytest.mark.skip(reason="fails for pytest v7.2.0")
def test_dynamic_pytest_module(self):
# Test case for pull request https://github.com/cloudpipe/cloudpickle/pull/116
import py
def f():
s = py.builtin.set([1])
return s.pop()
# Some setup is required to allow pytest API modules to be correctly
# serializable.
from srsly.cloudpickle import CloudPickler
from srsly.cloudpickle import cloudpickle_fast as cp_fast
CloudPickler.dispatch_table[type(py.builtin)] = cp_fast._module_reduce
g = cloudpickle.loads(cloudpickle.dumps(f, protocol=self.protocol))
result = g()
self.assertEqual(1, result)
def test_function_module_name(self):
func = lambda x: x
cloned = pickle_depickle(func, protocol=self.protocol)
self.assertEqual(cloned.__module__, func.__module__)
def test_function_qualname(self):
def func(x):
return x
# Default __qualname__ attribute (Python 3 only)
if hasattr(func, '__qualname__'):
cloned = pickle_depickle(func, protocol=self.protocol)
self.assertEqual(cloned.__qualname__, func.__qualname__)
# Mutated __qualname__ attribute
func.__qualname__ = ''
cloned = pickle_depickle(func, protocol=self.protocol)
self.assertEqual(cloned.__qualname__, func.__qualname__)
def test_property(self):
# Note that the @property decorator only has an effect on new-style
# classes.
class MyObject:
_read_only_value = 1
_read_write_value = 1
@property
def read_only_value(self):
"A read-only attribute"
return self._read_only_value
@property
def read_write_value(self):
return self._read_write_value
@read_write_value.setter
def read_write_value(self, value):
self._read_write_value = value
my_object = MyObject()
assert my_object.read_only_value == 1
assert MyObject.read_only_value.__doc__ == "A read-only attribute"
with pytest.raises(AttributeError):
my_object.read_only_value = 2
my_object.read_write_value = 2
depickled_obj = pickle_depickle(my_object)
assert depickled_obj.read_only_value == 1
assert depickled_obj.read_write_value == 2
# make sure the depickled read_only_value attribute is still read-only
with pytest.raises(AttributeError):
depickled_obj.read_only_value = 2
# make sure the depickled read_write_value attribute is writeable
depickled_obj.read_write_value = 3
assert depickled_obj.read_write_value == 3
assert type(depickled_obj).read_only_value.__doc__ == "A read-only attribute"
def test_namedtuple(self):
MyTuple = collections.namedtuple('MyTuple', ['a', 'b', 'c'])
t1 = MyTuple(1, 2, 3)
t2 = MyTuple(3, 2, 1)
depickled_t1, depickled_MyTuple, depickled_t2 = pickle_depickle(
[t1, MyTuple, t2], protocol=self.protocol)
assert isinstance(depickled_t1, MyTuple)
assert depickled_t1 == t1
assert depickled_MyTuple is MyTuple
assert isinstance(depickled_t2, MyTuple)
assert depickled_t2 == t2
@pytest.mark.skipif(platform.python_implementation() == "PyPy",
reason="fails sometimes for pypy on conda-forge")
def test_interactively_defined_function(self):
# Check that callables defined in the __main__ module of a Python
# script (or jupyter kernel) can be pickled / unpickled / executed.
code = """\
from srsly.tests.cloudpickle.testutils import subprocess_pickle_echo
CONSTANT = 42
class Foo(object):
def method(self, x):
return x
foo = Foo()
def f0(x):
return x ** 2
def f1():
return Foo
def f2(x):
return Foo().method(x)
def f3():
return Foo().method(CONSTANT)
def f4(x):
return foo.method(x)
def f5(x):
# Recursive call to a dynamically defined function.
if x <= 0:
return f4(x)
return f5(x - 1) + 1
cloned = subprocess_pickle_echo(lambda x: x**2, protocol={protocol})
assert cloned(3) == 9
cloned = subprocess_pickle_echo(f0, protocol={protocol})
assert cloned(3) == 9
cloned = subprocess_pickle_echo(Foo, protocol={protocol})
assert cloned().method(2) == Foo().method(2)
cloned = subprocess_pickle_echo(Foo(), protocol={protocol})
assert cloned.method(2) == Foo().method(2)
cloned = subprocess_pickle_echo(f1, protocol={protocol})
assert cloned()().method('a') == f1()().method('a')
cloned = subprocess_pickle_echo(f2, protocol={protocol})
assert cloned(2) == f2(2)
cloned = subprocess_pickle_echo(f3, protocol={protocol})
assert cloned() == f3()
cloned = subprocess_pickle_echo(f4, protocol={protocol})
assert cloned(2) == f4(2)
cloned = subprocess_pickle_echo(f5, protocol={protocol})
assert cloned(7) == f5(7) == 7
""".format(protocol=self.protocol)
assert_run_python_script(textwrap.dedent(code))
def test_interactively_defined_global_variable(self):
# Check that callables defined in the __main__ module of a Python
# script (or jupyter kernel) correctly retrieve global variables.
code_template = """\
from srsly.tests.cloudpickle.testutils import subprocess_pickle_echo
from srsly.cloudpickle import dumps, loads
def local_clone(obj, protocol=None):
return loads(dumps(obj, protocol=protocol))
VARIABLE = "default_value"
def f0():
global VARIABLE
VARIABLE = "changed_by_f0"
def f1():
return VARIABLE
assert f0.__globals__ is f1.__globals__
# pickle f0 and f1 inside the same pickle_string
cloned_f0, cloned_f1 = {clone_func}([f0, f1], protocol={protocol})
# cloned_f0 and cloned_f1 now share a global namespace that is isolated
# from any previously existing namespace
assert cloned_f0.__globals__ is cloned_f1.__globals__
assert cloned_f0.__globals__ is not f0.__globals__
# pickle f1 another time, but in a new pickle string
pickled_f1 = dumps(f1, protocol={protocol})
# Change the value of the global variable in f0's new global namespace
cloned_f0()
# thanks to cloudpickle isolation, depickling and calling f0 and f1
# should not affect the globals of already existing modules
assert VARIABLE == "default_value", VARIABLE
# Ensure that cloned_f1 and cloned_f0 share the same globals, as f1 and
# f0 shared the same globals at pickling time, and cloned_f1 was
# depickled from the same pickle string as cloned_f0
shared_global_var = cloned_f1()
assert shared_global_var == "changed_by_f0", shared_global_var
# f1 is unpickled another time, but because it comes from another
# pickle string than pickled_f1 and pickled_f0, it will not share the
# same globals as the latter two.
new_cloned_f1 = loads(pickled_f1)
assert new_cloned_f1.__globals__ is not cloned_f1.__globals__
assert new_cloned_f1.__globals__ is not f1.__globals__
# get the value of new_cloned_f1's VARIABLE
new_global_var = new_cloned_f1()
assert new_global_var == "default_value", new_global_var
"""
for clone_func in ['local_clone', 'subprocess_pickle_echo']:
code = code_template.format(protocol=self.protocol,
clone_func=clone_func)
assert_run_python_script(textwrap.dedent(code))
def test_closure_interacting_with_a_global_variable(self):
global _TEST_GLOBAL_VARIABLE
assert _TEST_GLOBAL_VARIABLE == "default_value"
orig_value = _TEST_GLOBAL_VARIABLE
try:
def f0():
global _TEST_GLOBAL_VARIABLE
_TEST_GLOBAL_VARIABLE = "changed_by_f0"
def f1():
return _TEST_GLOBAL_VARIABLE
# pickle f0 and f1 inside the same pickle_string
cloned_f0, cloned_f1 = pickle_depickle([f0, f1],
protocol=self.protocol)
# cloned_f0 and cloned_f1 now share a global namespace that is
# isolated from any previously existing namespace
assert cloned_f0.__globals__ is cloned_f1.__globals__
assert cloned_f0.__globals__ is not f0.__globals__
# pickle f1 another time, but in a new pickle string
pickled_f1 = cloudpickle.dumps(f1, protocol=self.protocol)
# Change the global variable's value in f0's new global namespace
cloned_f0()
# depickling f0 and f1 should not affect the globals of already
# existing modules
assert _TEST_GLOBAL_VARIABLE == "default_value"
# Ensure that cloned_f1 and cloned_f0 share the same globals, as f1
# and f0 shared the same globals at pickling time, and cloned_f1
# was depickled from the same pickle string as cloned_f0
shared_global_var = cloned_f1()
assert shared_global_var == "changed_by_f0", shared_global_var
# f1 is unpickled another time, but because it comes from another
# pickle string than pickled_f1 and pickled_f0, it will not share
# the same globals as the latter two.
new_cloned_f1 = pickle.loads(pickled_f1)
assert new_cloned_f1.__globals__ is not cloned_f1.__globals__
assert new_cloned_f1.__globals__ is not f1.__globals__
# get the value of new_cloned_f1's VARIABLE
new_global_var = new_cloned_f1()
assert new_global_var == "default_value", new_global_var
finally:
_TEST_GLOBAL_VARIABLE = orig_value
def test_interactive_remote_function_calls(self):
code = """if __name__ == "__main__":
from srsly.tests.cloudpickle.testutils import subprocess_worker
def interactive_function(x):
return x + 1
with subprocess_worker(protocol={protocol}) as w:
assert w.run(interactive_function, 41) == 42
# Define a new function that will call an updated version of
# the previously called function:
def wrapper_func(x):
return interactive_function(x)
def interactive_function(x):
return x - 1
# The change in the definition of interactive_function in the main
# module of the main process should be reflected transparently
# in the worker process: the worker process does not recall the
# previous definition of `interactive_function`:
assert w.run(wrapper_func, 41) == 40
""".format(protocol=self.protocol)
assert_run_python_script(code)
def test_interactive_remote_function_calls_no_side_effect(self):
code = """if __name__ == "__main__":
from srsly.tests.cloudpickle.testutils import subprocess_worker
import sys
with subprocess_worker(protocol={protocol}) as w:
GLOBAL_VARIABLE = 0
class CustomClass(object):
def mutate_globals(self):
global GLOBAL_VARIABLE
GLOBAL_VARIABLE += 1
return GLOBAL_VARIABLE
custom_object = CustomClass()
assert w.run(custom_object.mutate_globals) == 1
# The caller global variable is unchanged in the main process.
assert GLOBAL_VARIABLE == 0
# Calling the same function again starts again from zero. The
# worker process is stateless: it has no memory of the past call:
assert w.run(custom_object.mutate_globals) == 1
# The symbols defined in the main process __main__ module are
# not set in the worker process main module to leave the worker
# as stateless as possible:
def is_in_main(name):
return hasattr(sys.modules["__main__"], name)
assert is_in_main("CustomClass")
assert not w.run(is_in_main, "CustomClass")
assert is_in_main("GLOBAL_VARIABLE")
assert not w.run(is_in_main, "GLOBAL_VARIABLE")
""".format(protocol=self.protocol)
assert_run_python_script(code)
def test_interactive_dynamic_type_and_remote_instances(self):
code = """if __name__ == "__main__":
from srsly.tests.cloudpickle.testutils import subprocess_worker
with subprocess_worker(protocol={protocol}) as w:
class CustomCounter:
def __init__(self):
self.count = 0
def increment(self):
self.count += 1
return self
counter = CustomCounter().increment()
assert counter.count == 1
returned_counter = w.run(counter.increment)
assert returned_counter.count == 2, returned_counter.count
# Check that the class definition of the returned instance was
# matched back to the original class definition living in __main__.
assert isinstance(returned_counter, CustomCounter)
# Check that memoization does not break provenance tracking:
def echo(*args):
return args
C1, C2, c1, c2 = w.run(echo, CustomCounter, CustomCounter,
CustomCounter(), returned_counter)
assert C1 is CustomCounter
assert C2 is CustomCounter
assert isinstance(c1, CustomCounter)
assert isinstance(c2, CustomCounter)
""".format(protocol=self.protocol)
assert_run_python_script(code)
def test_interactive_dynamic_type_and_stored_remote_instances(self):
"""Simulate objects stored on workers to check isinstance semantics
Such instances stored in the memory of running worker processes are
similar to dask-distributed futures for instance.
"""
code = """if __name__ == "__main__":
import srsly.cloudpickle as cloudpickle, uuid
from srsly.tests.cloudpickle.testutils import subprocess_worker
with subprocess_worker(protocol={protocol}) as w:
class A:
'''Original class definition'''
pass
def store(x):
storage = getattr(cloudpickle, "_test_storage", None)
if storage is None:
storage = cloudpickle._test_storage = dict()
obj_id = uuid.uuid4().hex
storage[obj_id] = x
return obj_id
def lookup(obj_id):
return cloudpickle._test_storage[obj_id]
id1 = w.run(store, A())
# The stored object on the worker is matched to a singleton class
# definition thanks to provenance tracking:
assert w.run(lambda obj_id: isinstance(lookup(obj_id), A), id1)
# Retrieving the object from the worker yields a local copy that
# is matched back to the local class definition this instance
# originally stems from.
assert isinstance(w.run(lookup, id1), A)
# Changing the local class definition should be taken into account
# in all subsequent calls. In particular the old instances on the
# worker do not map back to the new class definition, neither on
# the worker itself, nor locally on the main program when the old
# instance is retrieved:
class A:
'''Updated class definition'''
pass
assert not w.run(lambda obj_id: isinstance(lookup(obj_id), A), id1)
retrieved1 = w.run(lookup, id1)
assert not isinstance(retrieved1, A)
assert retrieved1.__class__ is not A
assert retrieved1.__class__.__doc__ == "Original class definition"
# New instances on the other hand are proper instances of the new
# class definition everywhere:
a = A()
id2 = w.run(store, a)
assert w.run(lambda obj_id: isinstance(lookup(obj_id), A), id2)
assert isinstance(w.run(lookup, id2), A)
# Monkeypatch the class definition in the main process to add a new
# class method:
A.echo = lambda cls, x: x
# Calling this method on an instance will automatically update
# the remote class definition on the worker to propagate the monkey
# patch dynamically.
assert w.run(a.echo, 42) == 42
# The stored instance can therefore also access the new class
# method:
assert w.run(lambda obj_id: lookup(obj_id).echo(43), id2) == 43
""".format(protocol=self.protocol)
assert_run_python_script(code)
@pytest.mark.skip(reason="Seems to have issues outside of linux and CPython")
def test_interactive_remote_function_calls_no_memory_leak(self):
code = """if __name__ == "__main__":
from srsly.tests.cloudpickle.testutils import subprocess_worker
import struct
with subprocess_worker(protocol={protocol}) as w:
reference_size = w.memsize()
assert reference_size > 0
def make_big_closure(i):
# Generate a byte string of size 1MB
itemsize = len(struct.pack("l", 1))
data = struct.pack("l", i) * (int(1e6) // itemsize)
def process_data():
return len(data)
return process_data
for i in range(100):
func = make_big_closure(i)
result = w.run(func)
assert result == int(1e6), result
import gc
w.run(gc.collect)
# By this time the worker process has processed 100MB worth of data
# passed in the closures. The worker memory size should not have
# grown by more than a few MB as closures are garbage collected at
# the end of each remote function call.
growth = w.memsize() - reference_size
# For some reason, the memory growth after processing 100MB of
# data is ~10MB on MacOS, and ~1MB on Linux, so the upper bound on
# memory growth we use is only tight for MacOS. However,
# - 10MB is still 10x lower than the expected memory growth in case
# of a leak (which would be the total size of the processed data,
# 100MB)
# - the memory usage growth does not increase if using 10000
# iterations instead of 100 as used now (100x more data)
assert growth < 1.5e7, growth
""".format(protocol=self.protocol)
assert_run_python_script(code)
def test_pickle_reraise(self):
for exc_type in [Exception, ValueError, TypeError, RuntimeError]:
obj = RaiserOnPickle(exc_type("foo"))
with pytest.raises((exc_type, pickle.PicklingError)):
cloudpickle.dumps(obj, protocol=self.protocol)
def test_unhashable_function(self):
d = {'a': 1}
depickled_method = pickle_depickle(d.get, protocol=self.protocol)
self.assertEqual(depickled_method('a'), 1)
self.assertEqual(depickled_method('b'), None)
@pytest.mark.skipif(sys.version_info >= (3, 12), reason="Deprecation warning in python 3.12 about future deprecation in python 3.14")
def test_itertools_count(self):
counter = itertools.count(1, step=2)
# advance the counter a bit
next(counter)
next(counter)
new_counter = pickle_depickle(counter, protocol=self.protocol)
self.assertTrue(counter is not new_counter)
for _ in range(10):
self.assertEqual(next(counter), next(new_counter))
def test_wraps_preserves_function_name(self):
from functools import wraps
def f():
pass
@wraps(f)
def g():
f()
f2 = pickle_depickle(g, protocol=self.protocol)
self.assertEqual(f2.__name__, f.__name__)
def test_wraps_preserves_function_doc(self):
from functools import wraps
def f():
"""42"""
pass
@wraps(f)
def g():
f()
f2 = pickle_depickle(g, protocol=self.protocol)
self.assertEqual(f2.__doc__, f.__doc__)
def test_wraps_preserves_function_annotations(self):
def f(x):
pass
f.__annotations__ = {'x': 1, 'return': float}
@wraps(f)
def g(x):
f(x)
f2 = pickle_depickle(g, protocol=self.protocol)
self.assertEqual(f2.__annotations__, f.__annotations__)
def test_type_hint(self):
t = typing.Union[list, int]
assert pickle_depickle(t) == t
def test_instance_with_slots(self):
for slots in [["registered_attribute"], "registered_attribute"]:
class ClassWithSlots:
__slots__ = slots
def __init__(self):
self.registered_attribute = 42
initial_obj = ClassWithSlots()
depickled_obj = pickle_depickle(
initial_obj, protocol=self.protocol)
for obj in [initial_obj, depickled_obj]:
self.assertEqual(obj.registered_attribute, 42)
with pytest.raises(AttributeError):
obj.non_registered_attribute = 1
class SubclassWithSlots(ClassWithSlots):
def __init__(self):
self.unregistered_attribute = 1
obj = SubclassWithSlots()
s = cloudpickle.dumps(obj, protocol=self.protocol)
del SubclassWithSlots
depickled_obj = cloudpickle.loads(s)
assert depickled_obj.unregistered_attribute == 1
@unittest.skipIf(not hasattr(types, "MappingProxyType"),
"Old versions of Python do not have this type.")
def test_mappingproxy(self):
mp = types.MappingProxyType({"some_key": "some value"})
assert mp == pickle_depickle(mp, protocol=self.protocol)
def test_dataclass(self):
dataclasses = pytest.importorskip("dataclasses")
DataClass = dataclasses.make_dataclass('DataClass', [('x', int)])
data = DataClass(x=42)
pickle_depickle(DataClass, protocol=self.protocol)
assert data.x == pickle_depickle(data, protocol=self.protocol).x == 42
def test_locally_defined_enum(self):
class StringEnum(str, enum.Enum):
"""Enum when all members are also (and must be) strings"""
class Color(StringEnum):
"""3-element color space"""
RED = "1"
GREEN = "2"
BLUE = "3"
def is_green(self):
return self is Color.GREEN
green1, green2, ClonedColor = pickle_depickle(
[Color.GREEN, Color.GREEN, Color], protocol=self.protocol)
assert green1 is green2
assert green1 is ClonedColor.GREEN
assert green1 is not ClonedColor.BLUE
assert isinstance(green1, str)
assert green1.is_green()
# cloudpickle systematically tracks provenance of class definitions
# and ensures reconciliation in case of round trips:
assert green1 is Color.GREEN
assert ClonedColor is Color
green3 = pickle_depickle(Color.GREEN, protocol=self.protocol)
assert green3 is Color.GREEN
def test_locally_defined_intenum(self):
# Try again with a IntEnum defined with the functional API
DynamicColor = enum.IntEnum("Color", {"RED": 1, "GREEN": 2, "BLUE": 3})
green1, green2, ClonedDynamicColor = pickle_depickle(
[DynamicColor.GREEN, DynamicColor.GREEN, DynamicColor],
protocol=self.protocol)
assert green1 is green2
assert green1 is ClonedDynamicColor.GREEN
assert green1 is not ClonedDynamicColor.BLUE
assert ClonedDynamicColor is DynamicColor
def test_interactively_defined_enum(self):
code = """if __name__ == "__main__":
from enum import Enum
from srsly.tests.cloudpickle.testutils import subprocess_worker
with subprocess_worker(protocol={protocol}) as w:
class Color(Enum):
RED = 1
GREEN = 2
def check_positive(x):
return Color.GREEN if x >= 0 else Color.RED
result = w.run(check_positive, 1)
# Check that the returned enum instance is reconciled with the
# locally defined Color enum type definition:
assert result is Color.GREEN
# Check that changing the definition of the Enum class is taken
# into account on the worker for subsequent calls:
class Color(Enum):
RED = 1
BLUE = 2
def check_positive(x):
return Color.BLUE if x >= 0 else Color.RED
result = w.run(check_positive, 1)
assert result is Color.BLUE
""".format(protocol=self.protocol)
assert_run_python_script(code)
def test_relative_import_inside_function(self):
pytest.importorskip("_cloudpickle_testpkg")
# Make sure relative imports inside round-tripped functions are not
# broken. This was a bug in cloudpickle versions <= 0.5.3 and was
# re-introduced in 0.8.0.
from _cloudpickle_testpkg import relative_imports_factory
f, g = relative_imports_factory()
for func, source in zip([f, g], ["module", "package"]):
# Make sure relative imports are initially working
assert func() == f"hello from a {source}!"
# Make sure relative imports still work after round-tripping
cloned_func = pickle_depickle(func, protocol=self.protocol)
assert cloned_func() == f"hello from a {source}!"
def test_interactively_defined_func_with_keyword_only_argument(self):
# fixes https://github.com/cloudpipe/cloudpickle/issues/263
def f(a, *, b=1):
return a + b
depickled_f = pickle_depickle(f, protocol=self.protocol)
for func in (f, depickled_f):
assert func(2) == 3
assert func.__kwdefaults__ == {'b': 1}
@pytest.mark.skipif(not hasattr(types.CodeType, "co_posonlyargcount"),
reason="Requires positional-only argument syntax")
def test_interactively_defined_func_with_positional_only_argument(self):
# Fixes https://github.com/cloudpipe/cloudpickle/issues/266
# The source code of this test is bundled in a string and is run from
# the __main__ module of a subprocess in order to avoid a SyntaxError
# in versions of python that do not support positional-only argument
# syntax.
code = """
import pytest
from srsly.cloudpickle import loads, dumps
def f(a, /, b=1):
return a + b
depickled_f = loads(dumps(f, protocol={protocol}))
for func in (f, depickled_f):
assert func(2) == 3
assert func.__code__.co_posonlyargcount == 1
with pytest.raises(TypeError):
func(a=2)
""".format(protocol=self.protocol)
assert_run_python_script(textwrap.dedent(code))
def test___reduce___returns_string(self):
# Non-regression test for objects with a __reduce__ method returning a
# string, meaning "save by attribute using save_global".
pytest.importorskip("_cloudpickle_testpkg")
from _cloudpickle_testpkg import some_singleton
assert some_singleton.__reduce__() == "some_singleton"
depickled_singleton = pickle_depickle(
some_singleton, protocol=self.protocol)
assert depickled_singleton is some_singleton
def test_cloudpickle_extract_nested_globals(self):
def function_factory():
def inner_function():
global _TEST_GLOBAL_VARIABLE
return _TEST_GLOBAL_VARIABLE
return inner_function
globals_ = set(cloudpickle.cloudpickle._extract_code_globals(
function_factory.__code__).keys())
assert globals_ == {'_TEST_GLOBAL_VARIABLE'}
depickled_factory = pickle_depickle(function_factory,
protocol=self.protocol)
inner_func = depickled_factory()
assert inner_func() == _TEST_GLOBAL_VARIABLE
def test_recursion_during_pickling(self):
class A:
def __getattribute__(self, name):
return getattr(self, name)
a = A()
with pytest.raises(pickle.PicklingError, match='recursion'):
cloudpickle.dumps(a)
def test_out_of_band_buffers(self):
if self.protocol < 5:
pytest.skip("Need Pickle Protocol 5 or later")
np = pytest.importorskip("numpy")
class LocallyDefinedClass:
data = np.zeros(10)
data_instance = LocallyDefinedClass()
buffers = []
pickle_bytes = cloudpickle.dumps(data_instance, protocol=self.protocol,
buffer_callback=buffers.append)
assert len(buffers) == 1
reconstructed = pickle.loads(pickle_bytes, buffers=buffers)
np.testing.assert_allclose(reconstructed.data, data_instance.data)
def test_pickle_dynamic_typevar(self):
T = typing.TypeVar('T')
depickled_T = pickle_depickle(T, protocol=self.protocol)
attr_list = [
"__name__", "__bound__", "__constraints__", "__covariant__",
"__contravariant__"
]
for attr in attr_list:
assert getattr(T, attr) == getattr(depickled_T, attr)
def test_pickle_dynamic_typevar_tracking(self):
T = typing.TypeVar("T")
T2 = subprocess_pickle_echo(T, protocol=self.protocol)
assert T is T2
def test_pickle_dynamic_typevar_memoization(self):
T = typing.TypeVar('T')
depickled_T1, depickled_T2 = pickle_depickle((T, T),
protocol=self.protocol)
assert depickled_T1 is depickled_T2
def test_pickle_importable_typevar(self):
pytest.importorskip("_cloudpickle_testpkg")
from _cloudpickle_testpkg import T
T1 = pickle_depickle(T, protocol=self.protocol)
assert T1 is T
# Standard Library TypeVar
from typing import AnyStr
assert AnyStr is pickle_depickle(AnyStr, protocol=self.protocol)
def test_generic_type(self):
T = typing.TypeVar('T')
class C(typing.Generic[T]):
pass
assert pickle_depickle(C, protocol=self.protocol) is C
# Identity is not part of the typing contract: only test for
# equality instead.
assert pickle_depickle(C[int], protocol=self.protocol) == C[int]
with subprocess_worker(protocol=self.protocol) as worker:
def check_generic(generic, origin, type_value, use_args):
assert generic.__origin__ is origin
assert len(origin.__orig_bases__) == 1
ob = origin.__orig_bases__[0]
assert ob.__origin__ is typing.Generic
if use_args:
assert len(generic.__args__) == 1
assert generic.__args__[0] is type_value
else:
assert len(generic.__parameters__) == 1
assert generic.__parameters__[0] is type_value
assert len(ob.__parameters__) == 1
return "ok"
# backward-compat for old Python 3.5 versions that sometimes rely
# on __parameters__
use_args = getattr(C[int], '__args__', ()) != ()
assert check_generic(C[int], C, int, use_args) == "ok"
assert worker.run(check_generic, C[int], C, int, use_args) == "ok"
def test_generic_subclass(self):
T = typing.TypeVar('T')
class Base(typing.Generic[T]):
pass
class DerivedAny(Base):
pass
class LeafAny(DerivedAny):
pass
class DerivedInt(Base[int]):
pass
class LeafInt(DerivedInt):
pass
class DerivedT(Base[T]):
pass
class LeafT(DerivedT[T]):
pass
klasses = [
Base, DerivedAny, LeafAny, DerivedInt, LeafInt, DerivedT, LeafT
]
for klass in klasses:
assert pickle_depickle(klass, protocol=self.protocol) is klass
with subprocess_worker(protocol=self.protocol) as worker:
def check_mro(klass, expected_mro):
assert klass.mro() == expected_mro
return "ok"
for klass in klasses:
mro = klass.mro()
assert check_mro(klass, mro)
assert worker.run(check_mro, klass, mro) == "ok"
def test_locally_defined_class_with_type_hints(self):
with subprocess_worker(protocol=self.protocol) as worker:
for type_ in _all_types_to_test():
class MyClass:
def method(self, arg: type_) -> type_:
return arg
MyClass.__annotations__ = {'attribute': type_}
def check_annotations(obj, expected_type, expected_type_str):
assert obj.__annotations__["attribute"] == expected_type
assert (
obj.method.__annotations__["arg"] == expected_type
)
assert (
obj.method.__annotations__["return"]
== expected_type
)
return "ok"
obj = MyClass()
assert check_annotations(obj, type_, "type_") == "ok"
assert (
worker.run(check_annotations, obj, type_, "type_") == "ok"
)
def test_generic_extensions_literal(self):
typing_extensions = pytest.importorskip('typing_extensions')
for obj in [typing_extensions.Literal, typing_extensions.Literal['a']]:
depickled_obj = pickle_depickle(obj, protocol=self.protocol)
assert depickled_obj == obj
def test_generic_extensions_final(self):
typing_extensions = pytest.importorskip('typing_extensions')
for obj in [typing_extensions.Final, typing_extensions.Final[int]]:
depickled_obj = pickle_depickle(obj, protocol=self.protocol)
assert depickled_obj == obj
def test_class_annotations(self):
class C:
pass
C.__annotations__ = {'a': int}
C1 = pickle_depickle(C, protocol=self.protocol)
assert C1.__annotations__ == C.__annotations__
def test_function_annotations(self):
def f(a: int) -> str:
pass
f1 = pickle_depickle(f, protocol=self.protocol)
assert f1.__annotations__ == f.__annotations__
def test_always_use_up_to_date_copyreg(self):
# test that updates of copyreg.dispatch_table are taken in account by
# cloudpickle
import copyreg
try:
class MyClass:
pass
def reduce_myclass(x):
return MyClass, (), {'custom_reduce': True}
copyreg.dispatch_table[MyClass] = reduce_myclass
my_obj = MyClass()
depickled_myobj = pickle_depickle(my_obj, protocol=self.protocol)
assert hasattr(depickled_myobj, 'custom_reduce')
finally:
copyreg.dispatch_table.pop(MyClass)
def test_literal_misdetection(self):
# see https://github.com/cloudpipe/cloudpickle/issues/403
class MyClass:
@property
def __values__(self):
return ()
o = MyClass()
pickle_depickle(o, protocol=self.protocol)
def test_final_or_classvar_misdetection(self):
# see https://github.com/cloudpipe/cloudpickle/issues/403
class MyClass:
@property
def __type__(self):
return int
o = MyClass()
pickle_depickle(o, protocol=self.protocol)
@pytest.mark.skip(reason="Requires pytest -s to pass")
def test_pickle_constructs_from_module_registered_for_pickling_by_value(self): # noqa
_prev_sys_path = sys.path.copy()
try:
# We simulate an interactive session that:
# - starts from the /path/to/cloudpickle/tests directory, where a
# local .py file (mock_local_folder) is located.
# - uses constructs from mock_local_folder in remote workers that do
# not have access to this file. This situation is the justification
# behind the (un)register_pickle_by_value(module) API that
# cloudpickle exposes.
_mock_interactive_session_cwd = os.path.dirname(__file__)
# First, remove sys.path entries that could point to
# /path/to/cloudpickle/tests and be inherited by the worker
_maybe_remove(sys.path, '')
_maybe_remove(sys.path, _mock_interactive_session_cwd)
# Add the desired session working directory
sys.path.insert(0, _mock_interactive_session_cwd)
with subprocess_worker(protocol=self.protocol) as w:
# Make the module unavailable in the remote worker
w.run(
lambda p: sys.path.remove(p), _mock_interactive_session_cwd
)
# Import the actual file after starting the worker since the
# worker is started using fork on Linux, which inherits
# the parent sys.modules. On Python > 3.6, the worker can be
# started using spawn with mp_context in ProcessPoolExecutor.
# TODO Once Python 3.6 reaches end of life, rely on mp_context
# instead.
import mock_local_folder.mod as mod
# The constructs whose pickling mechanism is changed using
# register_pickle_by_value are functions, classes, TypeVar and
# modules.
from mock_local_folder.mod import (
local_function, LocalT, LocalClass
)
# Make sure the module/constructs are unimportable in the
# worker.
with pytest.raises(ImportError):
w.run(lambda: __import__("mock_local_folder.mod"))
with pytest.raises(ImportError):
w.run(
lambda: __import__("mock_local_folder.subfolder.mod")
)
for o in [mod, local_function, LocalT, LocalClass]:
with pytest.raises(ImportError):
w.run(lambda: o)
register_pickle_by_value(mod)
# function
assert w.run(lambda: local_function()) == local_function()
# typevar
assert w.run(lambda: LocalT.__name__) == LocalT.__name__
# classes
assert (
w.run(lambda: LocalClass().method())
== LocalClass().method()
)
# modules
assert (
w.run(lambda: mod.local_function()) == local_function()
)
# Constructs from modules inside subfolders should be pickled
# by value if a namespace module pointing to some parent folder
# was registered for pickling by value. A "mock_local_folder"
# namespace module falls into that category, but a
# "mock_local_folder.mod" one does not.
from mock_local_folder.subfolder.submod import (
LocalSubmodClass, LocalSubmodT, local_submod_function
)
# Shorter aliases to comply with line-length limits
_t, _func, _class = (
LocalSubmodT, local_submod_function, LocalSubmodClass
)
with pytest.raises(ImportError):
w.run(
lambda: __import__("mock_local_folder.subfolder.mod")
)
with pytest.raises(ImportError):
w.run(lambda: local_submod_function)
unregister_pickle_by_value(mod)
with pytest.raises(ImportError):
w.run(lambda: local_function)
with pytest.raises(ImportError):
w.run(lambda: __import__("mock_local_folder.mod"))
# Test the namespace folder case
import mock_local_folder
register_pickle_by_value(mock_local_folder)
assert w.run(lambda: local_function()) == local_function()
assert w.run(lambda: _func()) == _func()
unregister_pickle_by_value(mock_local_folder)
with pytest.raises(ImportError):
w.run(lambda: local_function)
with pytest.raises(ImportError):
w.run(lambda: local_submod_function)
# Test the case of registering a single module inside a
# subfolder.
import mock_local_folder.subfolder.submod
register_pickle_by_value(mock_local_folder.subfolder.submod)
assert w.run(lambda: _func()) == _func()
assert w.run(lambda: _t.__name__) == _t.__name__
assert w.run(lambda: _class().method()) == _class().method()
# Registering a module from a subfolder for pickling by value
# should not make constructs from modules from the parent
# folder pickleable
with pytest.raises(ImportError):
w.run(lambda: local_function)
with pytest.raises(ImportError):
w.run(lambda: __import__("mock_local_folder.mod"))
unregister_pickle_by_value(
mock_local_folder.subfolder.submod
)
with pytest.raises(ImportError):
w.run(lambda: local_submod_function)
# Test the subfolder namespace module case
import mock_local_folder.subfolder
register_pickle_by_value(mock_local_folder.subfolder)
assert w.run(lambda: _func()) == _func()
assert w.run(lambda: _t.__name__) == _t.__name__
assert w.run(lambda: _class().method()) == _class().method()
unregister_pickle_by_value(mock_local_folder.subfolder)
finally:
_fname = "mock_local_folder"
sys.path = _prev_sys_path
for m in [_fname, f"{_fname}.mod", f"{_fname}.subfolder",
f"{_fname}.subfolder.submod"]:
mod = sys.modules.pop(m, None)
if mod and mod.__name__ in list_registry_pickle_by_value():
unregister_pickle_by_value(mod)
def test_pickle_constructs_from_installed_packages_registered_for_pickling_by_value( # noqa
self
):
pytest.importorskip("_cloudpickle_testpkg")
for package_or_module in ["package", "module"]:
if package_or_module == "package":
import _cloudpickle_testpkg as m
f = m.package_function_with_global
_original_global = m.global_variable
elif package_or_module == "module":
import _cloudpickle_testpkg.mod as m
f = m.module_function_with_global
_original_global = m.global_variable
try:
with subprocess_worker(protocol=self.protocol) as w:
assert w.run(lambda: f()) == _original_global
# Test that f is pickled by value by modifying a global
# variable that f uses, and making sure that this
# modification shows up when calling the function remotely
register_pickle_by_value(m)
assert w.run(lambda: f()) == _original_global
m.global_variable = "modified global"
assert m.global_variable != _original_global
assert w.run(lambda: f()) == "modified global"
unregister_pickle_by_value(m)
finally:
m.global_variable = _original_global
if m.__name__ in list_registry_pickle_by_value():
unregister_pickle_by_value(m)
def test_pickle_various_versions_of_the_same_function_with_different_pickling_method( # noqa
self
):
pytest.importorskip("_cloudpickle_testpkg")
# Make sure that different versions of the same function (possibly
# pickled in a different way - by value and/or by reference) can
# peacefully co-exist (e.g. without globals interaction) in a remote
# worker.
import _cloudpickle_testpkg
from _cloudpickle_testpkg import package_function_with_global as f
_original_global = _cloudpickle_testpkg.global_variable
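# The helpers below stash a registry on the worker's __main__ module so
# that successive w.run() calls, each pickled and sent independently,
# can share state inside the single long-lived worker process.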
def _create_registry():
_main = __import__("sys").modules["__main__"]
_main._cloudpickle_registry = {}
# global _cloudpickle_registry
def _add_to_registry(v, k):
_main = __import__("sys").modules["__main__"]
_main._cloudpickle_registry[k] = v
def _call_from_registry(k):
_main = __import__("sys").modules["__main__"]
return _main._cloudpickle_registry[k]()
try:
with subprocess_worker(protocol=self.protocol) as w:
w.run(_create_registry)
w.run(_add_to_registry, f, "f_by_ref")
register_pickle_by_value(_cloudpickle_testpkg)
_cloudpickle_testpkg.global_variable = "modified global"
w.run(_add_to_registry, f, "f_by_val")
assert (
w.run(_call_from_registry, "f_by_ref") == _original_global
)
assert (
w.run(_call_from_registry, "f_by_val") == "modified global"
)
finally:
_cloudpickle_testpkg.global_variable = _original_global
if "_cloudpickle_testpkg" in list_registry_pickle_by_value():
unregister_pickle_by_value(_cloudpickle_testpkg)
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="Determinism can only be guaranteed for Python 3.7+"
)
def test_deterministic_pickle_bytes_for_function(self):
# Ensure that functions with references to several global names are
# pickled to fixed bytes that do not depend on the PYTHONHASHSEED of
# the Python process.
vals = set()
def func_with_globals():
return _TEST_GLOBAL_VARIABLE + _TEST_GLOBAL_VARIABLE2
for i in range(5):
vals.add(
subprocess_pickle_string(func_with_globals,
protocol=self.protocol,
add_env={"PYTHONHASHSEED": str(i)}))
if len(vals) > 1:
# Print additional debug info on stdout with dis:
for val in vals:
pickletools.dis(val)
pytest.fail(
"Expected a single deterministic payload, got %d/5" % len(vals)
)
class Protocol2CloudPickleTest(CloudPickleTest):
protocol = 2
def test_lookup_module_and_qualname_dynamic_typevar():
T = typing.TypeVar('T')
module_and_name = _lookup_module_and_qualname(T, name=T.__name__)
assert module_and_name is None
def test_lookup_module_and_qualname_importable_typevar():
pytest.importorskip("_cloudpickle_testpkg")
import _cloudpickle_testpkg
T = _cloudpickle_testpkg.T
module_and_name = _lookup_module_and_qualname(T, name=T.__name__)
assert module_and_name is not None
module, name = module_and_name
assert module is _cloudpickle_testpkg
assert name == 'T'
def test_lookup_module_and_qualname_stdlib_typevar():
module_and_name = _lookup_module_and_qualname(typing.AnyStr,
name=typing.AnyStr.__name__)
assert module_and_name is not None
module, name = module_and_name
assert module is typing
assert name == 'AnyStr'
def test_register_pickle_by_value():
pytest.importorskip("_cloudpickle_testpkg")
import _cloudpickle_testpkg as pkg
import _cloudpickle_testpkg.mod as mod
assert list_registry_pickle_by_value() == set()
register_pickle_by_value(pkg)
assert list_registry_pickle_by_value() == {pkg.__name__}
register_pickle_by_value(mod)
assert list_registry_pickle_by_value() == {pkg.__name__, mod.__name__}
unregister_pickle_by_value(mod)
assert list_registry_pickle_by_value() == {pkg.__name__}
msg = f"Input should be a module object, got {pkg.__name__} instead"
with pytest.raises(ValueError, match=msg):
unregister_pickle_by_value(pkg.__name__)
unregister_pickle_by_value(pkg)
assert list_registry_pickle_by_value() == set()
msg = f"{pkg} is not registered for pickle by value"
with pytest.raises(ValueError, match=re.escape(msg)):
unregister_pickle_by_value(pkg)
msg = f"Input should be a module object, got {pkg.__name__} instead"
with pytest.raises(ValueError, match=msg):
register_pickle_by_value(pkg.__name__)
dynamic_mod = types.ModuleType('dynamic_mod')
msg = (
f"{dynamic_mod} was not imported correctly, have you used an "
f"`import` statement to access it?"
)
with pytest.raises(ValueError, match=re.escape(msg)):
register_pickle_by_value(dynamic_mod)
def _all_types_to_test():
T = typing.TypeVar('T')
class C(typing.Generic[T]):
pass
types_to_test = [
C, C[int],
T, typing.Any, typing.Optional,
typing.Generic, typing.Union,
typing.Optional[int],
typing.Generic[T],
typing.Callable[[int], typing.Any],
typing.Callable[..., typing.Any],
typing.Callable[[], typing.Any],
typing.Tuple[int, ...],
typing.Tuple[int, C[int]],
typing.List[int],
typing.Dict[int, str],
typing.ClassVar,
typing.ClassVar[C[int]],
typing.NoReturn,
]
return types_to_test
def test_module_level_pickler():
# #366: cloudpickle should expose its pickle.Pickler subclass as
# cloudpickle.Pickler
assert hasattr(cloudpickle, "Pickler")
assert cloudpickle.Pickler is cloudpickle.CloudPickler
if __name__ == '__main__':
unittest.main()
srsly-release-v2.5.1/srsly/tests/cloudpickle/mock_local_folder/ 0000775 0000000 0000000 00000000000 14742310675 0024762 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/cloudpickle/mock_local_folder/mod.py 0000664 0000000 0000000 00000001136 14742310675 0026114 0 ustar 00root root 0000000 0000000 """
In the distributed computing setting, this file plays the role of a "local
development" file, e.g. a file that is importable locally, but unimportable in
remote workers. Constructs defined in this file and usually pickled by
reference should instead be flagged to cloudpickle for pickling by value: this is
done using the register_pickle_by_value API exposed by cloudpickle.
"""
import typing
def local_function():
return "hello from a function importable locally!"
class LocalClass:
def method(self):
return "hello from a class importable locally"
LocalT = typing.TypeVar("LocalT")
srsly-release-v2.5.1/srsly/tests/cloudpickle/mock_local_folder/subfolder/ 0000775 0000000 0000000 00000000000 14742310675 0026747 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/cloudpickle/mock_local_folder/subfolder/submod.py 0000664 0000000 0000000 00000000446 14742310675 0030616 0 ustar 00root root 0000000 0000000 import typing
def local_submod_function():
return "hello from a file located in a locally-importable subfolder!"
class LocalSubmodClass:
def method(self):
return "hello from a class located in a locally-importable subfolder!"
LocalSubmodT = typing.TypeVar("LocalSubmodT")
srsly-release-v2.5.1/srsly/tests/cloudpickle/testutils.py 0000664 0000000 0000000 00000016452 14742310675 0023746 0 ustar 00root root 0000000 0000000 import sys
import os
import os.path as op
import tempfile
from subprocess import Popen, check_output, PIPE, STDOUT, CalledProcessError
from srsly.cloudpickle.compat import pickle
from contextlib import contextmanager
from concurrent.futures import ProcessPoolExecutor
import psutil
from srsly.cloudpickle import dumps
from subprocess import TimeoutExpired
loads = pickle.loads
TIMEOUT = 60
TEST_GLOBALS = "a test value"
def make_local_function():
def g(x):
# this function checks that the globals are correctly handled and that
# the builtins are available
assert TEST_GLOBALS == "a test value"
return sum(range(10))
return g
def _make_cwd_env():
"""Helper to prepare environment for the child processes"""
cloudpickle_repo_folder = op.normpath(
op.join(op.dirname(__file__), '..'))
env = os.environ.copy()
pythonpath = "{src}{sep}tests{pathsep}{src}".format(
src=cloudpickle_repo_folder, sep=os.sep, pathsep=os.pathsep)
env['PYTHONPATH'] = pythonpath
return cloudpickle_repo_folder, env
def subprocess_pickle_string(input_data, protocol=None, timeout=TIMEOUT,
add_env=None):
"""Retrieve pickle string of an object generated by a child Python process
Pickle the input data into a buffer, send it to a subprocess via
stdin, expect the subprocess to unpickle, re-pickle that data back
and send it back to the parent process via stdout; the re-pickled bytes are
returned unmodified.
>>> testutils.subprocess_pickle_string([1, 'a', None], protocol=2)
b'\x80\x02]q\x00(K\x01X\x01\x00\x00\x00aq\x01Ne.'
"""
# Run pickle_echo(protocol=protocol) in the subprocess's __main__.
# Suppress warnings so that stderr stays empty: we assume an error
# happened whenever stderr is non-empty. A concrete example is pytest
# using the imp module, which is deprecated in Python 3.8.
cmd = [sys.executable, '-W ignore', __file__, "--protocol", str(protocol)]
cwd, env = _make_cwd_env()
if add_env:
env.update(add_env)
proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, cwd=cwd, env=env,
bufsize=4096)
pickle_string = dumps(input_data, protocol=protocol)
try:
comm_kwargs = {}
comm_kwargs['timeout'] = timeout
out, err = proc.communicate(pickle_string, **comm_kwargs)
if proc.returncode != 0 or len(err):
message = "Subprocess returned %d: " % proc.returncode
message += err.decode('utf-8')
raise RuntimeError(message)
return out
except TimeoutExpired as e:
proc.kill()
out, err = proc.communicate()
message = "\n".join([out.decode('utf-8'), err.decode('utf-8')])
raise RuntimeError(message) from e
def subprocess_pickle_echo(input_data, protocol=None, timeout=TIMEOUT,
add_env=None):
"""Echo function with a child Python process
Pickle the input data into a buffer, send it to a subprocess via
stdin, expect the subprocess to unpickle, re-pickle that data back
and send it back to the parent process via stdout for final unpickling.
>>> subprocess_pickle_echo([1, 'a', None])
[1, 'a', None]
"""
out = subprocess_pickle_string(input_data,
protocol=protocol,
timeout=timeout,
add_env=add_env)
return loads(out)
def _read_all_bytes(stream_in, chunk_size=4096):
all_data = b""
while True:
data = stream_in.read(chunk_size)
all_data += data
if len(data) < chunk_size:
break
return all_data
def pickle_echo(stream_in=None, stream_out=None, protocol=None):
"""Read a pickle from stdin and pickle it back to stdout"""
if stream_in is None:
stream_in = sys.stdin
if stream_out is None:
stream_out = sys.stdout
# Force the use of bytes streams under Python 3
if hasattr(stream_in, 'buffer'):
stream_in = stream_in.buffer
if hasattr(stream_out, 'buffer'):
stream_out = stream_out.buffer
input_bytes = _read_all_bytes(stream_in)
stream_in.close()
obj = loads(input_bytes)
repickled_bytes = dumps(obj, protocol=protocol)
stream_out.write(repickled_bytes)
stream_out.close()
def call_func(payload, protocol):
"""Remote function call that uses cloudpickle to transport everthing"""
func, args, kwargs = loads(payload)
try:
result = func(*args, **kwargs)
except BaseException as e:
result = e
return dumps(result, protocol=protocol)
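# call_func receives a cloudpickle payload of (func, args, kwargs) and returns
# a cloudpickle payload of either the result or the raised exception, so both
# directions of the _Worker round trip cross the process boundary as pickles.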
class _Worker:
def __init__(self, protocol=None):
self.protocol = protocol
self.pool = ProcessPoolExecutor(max_workers=1)
self.pool.submit(id, 42).result() # start the worker process
def run(self, func, *args, **kwargs):
"""Synchronous remote function call"""
input_payload = dumps((func, args, kwargs), protocol=self.protocol)
result_payload = self.pool.submit(
call_func, input_payload, self.protocol).result()
result = loads(result_payload)
if isinstance(result, BaseException):
raise result
return result
def memsize(self):
workers_pids = [p.pid if hasattr(p, "pid") else p
for p in list(self.pool._processes)]
num_workers = len(workers_pids)
if num_workers == 0:
return 0
elif num_workers > 1:
raise RuntimeError("Unexpected number of workers: %d"
% num_workers)
return psutil.Process(workers_pids[0]).memory_info().rss
def close(self):
self.pool.shutdown(wait=True)
@contextmanager
def subprocess_worker(protocol=None):
worker = _Worker(protocol=protocol)
yield worker
worker.close()
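# Typical usage, mirroring the cloudpickle tests above (a sketch):
#
#   with subprocess_worker(protocol=2) as w:
#       assert w.run(lambda x: x + 1, 41) == 42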
def assert_run_python_script(source_code, timeout=TIMEOUT):
"""Utility to help check pickleability of objects defined in __main__
The script provided in the source code should return 0 and not print
anything on stderr or stdout.
"""
fd, source_file = tempfile.mkstemp(suffix='_src_test_cloudpickle.py')
os.close(fd)
try:
with open(source_file, 'wb') as f:
f.write(source_code.encode('utf-8'))
cmd = [sys.executable, '-W ignore', source_file]
cwd, env = _make_cwd_env()
kwargs = {
'cwd': cwd,
'stderr': STDOUT,
'env': env,
}
# If coverage is running, pass the config file to the subprocess
coverage_rc = os.environ.get("COVERAGE_PROCESS_START")
if coverage_rc:
kwargs['env']['COVERAGE_PROCESS_START'] = coverage_rc
kwargs['timeout'] = timeout
try:
try:
out = check_output(cmd, **kwargs)
except CalledProcessError as e:
raise RuntimeError("script errored with output:\n%s"
% e.output.decode('utf-8')) from e
if out != b"":
raise AssertionError(out.decode('utf-8'))
except TimeoutExpired as e:
raise RuntimeError("script timeout, output so far:\n%s"
% e.output.decode('utf-8')) from e
finally:
os.unlink(source_file)
if __name__ == '__main__':
protocol = int(sys.argv[sys.argv.index('--protocol') + 1])
pickle_echo(protocol=protocol)
srsly-release-v2.5.1/srsly/tests/msgpack/ 0000775 0000000 0000000 00000000000 14742310675 0020453 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/msgpack/__init__.py 0000664 0000000 0000000 00000000000 14742310675 0022552 0 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/msgpack/test_buffer.py 0000664 0000000 0000000 00000001305 14742310675 0023334 0 ustar 00root root 0000000 0000000 from srsly.msgpack import packb, unpackb
def test_unpack_buffer():
from array import array
buf = array("b")
buf.frombytes(packb((b"foo", b"bar")))
obj = unpackb(buf, use_list=1)
assert [b"foo", b"bar"] == obj
def test_unpack_bytearray():
buf = bytearray(packb(("foo", "bar")))
obj = unpackb(buf, use_list=1)
assert [b"foo", b"bar"] == obj
expected_type = bytes
assert all(type(s) == expected_type for s in obj)
def test_unpack_memoryview():
buf = bytearray(packb(("foo", "bar")))
view = memoryview(buf)
obj = unpackb(view, use_list=1)
assert [b"foo", b"bar"] == obj
expected_type = bytes
assert all(type(s) == expected_type for s in obj)
srsly-release-v2.5.1/srsly/tests/msgpack/test_case.py 0000664 0000000 0000000 00000005436 14742310675 0023007 0 ustar 00root root 0000000 0000000 from srsly.msgpack import packb, unpackb
def check(length, obj):
v = packb(obj)
assert len(v) == length, "%r length should be %r but got %r" % (obj, length, len(v))
assert unpackb(v, use_list=0) == obj
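# The expected lengths below follow the msgpack format: nil, booleans and
# fixints fit in a single byte; int8/uint8 take 2 bytes, int16/uint16 take 3,
# int32/uint32 take 5 and int64/uint64/float64 take 9. Raw (str/bin) and
# array headers add 1, 3 or 5 bytes on top of the payload, depending on size.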
def test_1():
for o in [
None,
True,
False,
0,
1,
(1 << 6),
(1 << 7) - 1,
-1,
-((1 << 5) - 1),
-(1 << 5),
]:
check(1, o)
def test_2():
for o in [1 << 7, (1 << 8) - 1, -((1 << 5) + 1), -(1 << 7)]:
check(2, o)
def test_3():
for o in [1 << 8, (1 << 16) - 1, -((1 << 7) + 1), -(1 << 15)]:
check(3, o)
def test_5():
for o in [1 << 16, (1 << 32) - 1, -((1 << 15) + 1), -(1 << 31)]:
check(5, o)
def test_9():
for o in [
1 << 32,
(1 << 64) - 1,
-((1 << 31) + 1),
-(1 << 63),
1.0,
0.1,
-0.1,
-1.0,
]:
check(9, o)
def check_raw(overhead, num):
check(num + overhead, b" " * num)
def test_fixraw():
check_raw(1, 0)
check_raw(1, (1 << 5) - 1)
def test_raw16():
check_raw(3, 1 << 5)
check_raw(3, (1 << 16) - 1)
def test_raw32():
check_raw(5, 1 << 16)
def check_array(overhead, num):
check(num + overhead, (None,) * num)
def test_fixarray():
check_array(1, 0)
check_array(1, (1 << 4) - 1)
def test_array16():
check_array(3, 1 << 4)
check_array(3, (1 << 16) - 1)
def test_array32():
check_array(5, (1 << 16))
def match(obj, buf):
assert packb(obj) == buf
assert unpackb(buf, use_list=0) == obj
def test_match():
cases = [
(None, b"\xc0"),
(False, b"\xc2"),
(True, b"\xc3"),
(0, b"\x00"),
(127, b"\x7f"),
(128, b"\xcc\x80"),
(256, b"\xcd\x01\x00"),
(-1, b"\xff"),
(-33, b"\xd0\xdf"),
(-129, b"\xd1\xff\x7f"),
({1: 1}, b"\x81\x01\x01"),
(1.0, b"\xcb\x3f\xf0\x00\x00\x00\x00\x00\x00"),
((), b"\x90"),
(
tuple(range(15)),
b"\x9f\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e",
),
(
tuple(range(16)),
b"\xdc\x00\x10\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
),
({}, b"\x80"),
(
dict([(x, x) for x in range(15)]),
b"\x8f\x00\x00\x01\x01\x02\x02\x03\x03\x04\x04\x05\x05\x06\x06\x07\x07\x08\x08\t\t\n\n\x0b\x0b\x0c\x0c\r\r\x0e\x0e",
),
(
dict([(x, x) for x in range(16)]),
b"\xde\x00\x10\x00\x00\x01\x01\x02\x02\x03\x03\x04\x04\x05\x05\x06\x06\x07\x07\x08\x08\t\t\n\n\x0b\x0b\x0c\x0c\r\r\x0e\x0e\x0f\x0f",
),
]
for v, p in cases:
match(v, p)
def test_unicode():
assert unpackb(packb("foobar"), use_list=1) == b"foobar"
srsly-release-v2.5.1/srsly/tests/msgpack/test_except.py 0000664 0000000 0000000 00000003222 14742310675 0023353 0 ustar 00root root 0000000 0000000 from pytest import raises
import datetime
from srsly.msgpack import packb, unpackb, Unpacker, FormatError, StackError, OutOfData
class DummyException(Exception):
pass
def test_raise_on_find_unsupported_value():
with raises(TypeError):
packb(datetime.datetime.now())
def test_raise_from_object_hook():
def hook(obj):
raise DummyException
raises(DummyException, unpackb, packb({}), object_hook=hook)
raises(DummyException, unpackb, packb({"fizz": "buzz"}), object_hook=hook)
raises(DummyException, unpackb, packb({"fizz": "buzz"}), object_pairs_hook=hook)
raises(DummyException, unpackb, packb({"fizz": {"buzz": "spam"}}), object_hook=hook)
raises(
DummyException,
unpackb,
packb({"fizz": {"buzz": "spam"}}),
object_pairs_hook=hook,
)
def test_invalidvalue():
incomplete = b"\xd9\x97#DL_" # raw8 - length=0x97
with raises(ValueError):
unpackb(incomplete)
with raises(OutOfData):
unpacker = Unpacker()
unpacker.feed(incomplete)
unpacker.unpack()
with raises(FormatError):
unpackb(b"\xc1") # (undefined tag)
with raises(FormatError):
unpackb(b"\x91\xc1") # fixarray(len=1) [ (undefined tag) ]
with raises(StackError):
unpackb(b"\x91" * 3000) # nested fixarray(len=1)
def test_strict_map_key():
valid = {u"unicode": 1, b"bytes": 2}
packed = packb(valid, use_bin_type=True)
assert valid == unpackb(packed, raw=False, strict_map_key=True)
invalid = {42: 1}
packed = packb(invalid, use_bin_type=True)
with raises(ValueError):
unpackb(packed, raw=False, strict_map_key=True)
srsly-release-v2.5.1/srsly/tests/msgpack/test_extension.py 0000664 0000000 0000000 00000005011 14742310675 0024075 0 ustar 00root root 0000000 0000000 import array
from srsly import msgpack
from srsly.msgpack.ext import ExtType
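# ExtType(code, data) is msgpack's container for application-specific types:
# `code` is an integer type tag (0-127 are available to applications) and
# `data` holds the raw payload bytes, serialized with the fixext/ext 8/
# ext 16/ext 32 headers exercised below.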
def test_pack_ext_type():
def p(s):
packer = msgpack.Packer()
packer.pack_ext_type(0x42, s)
return packer.bytes()
assert p(b"A") == b"\xd4\x42A" # fixext 1
assert p(b"AB") == b"\xd5\x42AB" # fixext 2
assert p(b"ABCD") == b"\xd6\x42ABCD" # fixext 4
assert p(b"ABCDEFGH") == b"\xd7\x42ABCDEFGH" # fixext 8
assert p(b"A" * 16) == b"\xd8\x42" + b"A" * 16 # fixext 16
assert p(b"ABC") == b"\xc7\x03\x42ABC" # ext 8
assert p(b"A" * 0x0123) == b"\xc8\x01\x23\x42" + b"A" * 0x0123 # ext 16
assert (
p(b"A" * 0x00012345) == b"\xc9\x00\x01\x23\x45\x42" + b"A" * 0x00012345
) # ext 32
def test_unpack_ext_type():
def check(b, expected):
assert msgpack.unpackb(b) == expected
check(b"\xd4\x42A", ExtType(0x42, b"A")) # fixext 1
check(b"\xd5\x42AB", ExtType(0x42, b"AB")) # fixext 2
check(b"\xd6\x42ABCD", ExtType(0x42, b"ABCD")) # fixext 4
check(b"\xd7\x42ABCDEFGH", ExtType(0x42, b"ABCDEFGH")) # fixext 8
check(b"\xd8\x42" + b"A" * 16, ExtType(0x42, b"A" * 16)) # fixext 16
check(b"\xc7\x03\x42ABC", ExtType(0x42, b"ABC")) # ext 8
check(b"\xc8\x01\x23\x42" + b"A" * 0x0123, ExtType(0x42, b"A" * 0x0123)) # ext 16
check(
b"\xc9\x00\x01\x23\x45\x42" + b"A" * 0x00012345,
ExtType(0x42, b"A" * 0x00012345),
) # ext 32
def test_extension_type():
def default(obj):
print("default called", obj)
if isinstance(obj, array.array):
typecode = 123 # application specific typecode
data = obj.tobytes()
return ExtType(typecode, data)
raise TypeError("Unknown type object %r" % (obj,))
def ext_hook(code, data):
print("ext_hook called", code, data)
assert code == 123
obj = array.array("d")
obj.frombytes(data)
return obj
obj = [42, b"hello", array.array("d", [1.1, 2.2, 3.3])]
s = msgpack.packb(obj, default=default)
obj2 = msgpack.unpackb(s, ext_hook=ext_hook)
assert obj == obj2
def test_overriding_hooks():
def default(obj):
if isinstance(obj, int):
return {"__type__": "long", "__data__": str(obj)}
else:
return obj
obj = {"testval": int(1823746192837461928374619)}
refobj = {"testval": default(obj["testval"])}
refout = msgpack.packb(refobj)
assert isinstance(refout, (str, bytes))
testout = msgpack.packb(obj, default=default)
assert refout == testout
srsly-release-v2.5.1/srsly/tests/msgpack/test_format.py 0000664 0000000 0000000 00000004004 14742310675 0023352 0 ustar 00root root 0000000 0000000 from srsly.msgpack import unpackb
def check(src, should, use_list=0):
assert unpackb(src, use_list=use_list) == should
def testSimpleValue():
check(b"\x93\xc0\xc2\xc3", (None, False, True))
def testFixnum():
check(b"\x92\x93\x00\x40\x7f\x93\xe0\xf0\xff", ((0, 64, 127), (-32, -16, -1)))
def testFixArray():
check(b"\x92\x90\x91\x91\xc0", ((), ((None,),)))
def testFixRaw():
check(b"\x94\xa0\xa1a\xa2bc\xa3def", (b"", b"a", b"bc", b"def"))
def testFixMap():
check(
b"\x82\xc2\x81\xc0\xc0\xc3\x81\xc0\x80", {False: {None: None}, True: {None: {}}}
)
def testUnsignedInt():
check(
b"\x99\xcc\x00\xcc\x80\xcc\xff\xcd\x00\x00\xcd\x80\x00"
b"\xcd\xff\xff\xce\x00\x00\x00\x00\xce\x80\x00\x00\x00"
b"\xce\xff\xff\xff\xff",
(0, 128, 255, 0, 32768, 65535, 0, 2147483648, 4294967295),
)
def testSignedInt():
check(
b"\x99\xd0\x00\xd0\x80\xd0\xff\xd1\x00\x00\xd1\x80\x00"
b"\xd1\xff\xff\xd2\x00\x00\x00\x00\xd2\x80\x00\x00\x00"
b"\xd2\xff\xff\xff\xff",
(0, -128, -1, 0, -32768, -1, 0, -2147483648, -1),
)
def testRaw():
check(
b"\x96\xda\x00\x00\xda\x00\x01a\xda\x00\x02ab\xdb\x00\x00"
b"\x00\x00\xdb\x00\x00\x00\x01a\xdb\x00\x00\x00\x02ab",
(b"", b"a", b"ab", b"", b"a", b"ab"),
)
def testArray():
check(
b"\x96\xdc\x00\x00\xdc\x00\x01\xc0\xdc\x00\x02\xc2\xc3\xdd\x00"
b"\x00\x00\x00\xdd\x00\x00\x00\x01\xc0\xdd\x00\x00\x00\x02"
b"\xc2\xc3",
((), (None,), (False, True), (), (None,), (False, True)),
)
def testMap():
check(
b"\x96"
b"\xde\x00\x00"
b"\xde\x00\x01\xc0\xc2"
b"\xde\x00\x02\xc0\xc2\xc3\xc2"
b"\xdf\x00\x00\x00\x00"
b"\xdf\x00\x00\x00\x01\xc0\xc2"
b"\xdf\x00\x00\x00\x02\xc0\xc2\xc3\xc2",
(
{},
{None: False},
{True: False, None: False},
{},
{None: False},
{True: False, None: False},
),
)
srsly-release-v2.5.1/srsly/tests/msgpack/test_limits.py 0000664 0000000 0000000 00000006063 14742310675 0023372 0 ustar 00root root 0000000 0000000 import pytest
from srsly.msgpack import packb, unpackb, Packer, Unpacker, ExtType
from srsly.msgpack import PackOverflowError, PackValueError, UnpackValueError
def test_integer():
x = -(2 ** 63)
assert unpackb(packb(x)) == x
with pytest.raises(PackOverflowError):
packb(x - 1)
x = 2 ** 64 - 1
assert unpackb(packb(x)) == x
with pytest.raises(PackOverflowError):
packb(x + 1)
def test_array_header():
packer = Packer()
packer.pack_array_header(2 ** 32 - 1)
with pytest.raises(PackValueError):
packer.pack_array_header(2 ** 32)
def test_map_header():
packer = Packer()
packer.pack_map_header(2 ** 32 - 1)
with pytest.raises(PackValueError):
        packer.pack_map_header(2 ** 32)
def test_max_str_len():
d = "x" * 3
packed = packb(d)
unpacker = Unpacker(max_str_len=3, raw=False)
unpacker.feed(packed)
assert unpacker.unpack() == d
unpacker = Unpacker(max_str_len=2, raw=False)
with pytest.raises(UnpackValueError):
unpacker.feed(packed)
unpacker.unpack()
def test_max_bin_len():
d = b"x" * 3
packed = packb(d, use_bin_type=True)
unpacker = Unpacker(max_bin_len=3)
unpacker.feed(packed)
assert unpacker.unpack() == d
unpacker = Unpacker(max_bin_len=2)
with pytest.raises(UnpackValueError):
unpacker.feed(packed)
unpacker.unpack()
def test_max_array_len():
d = [1, 2, 3]
packed = packb(d)
unpacker = Unpacker(max_array_len=3)
unpacker.feed(packed)
assert unpacker.unpack() == d
unpacker = Unpacker(max_array_len=2)
with pytest.raises(UnpackValueError):
unpacker.feed(packed)
unpacker.unpack()
def test_max_map_len():
d = {1: 2, 3: 4, 5: 6}
packed = packb(d)
unpacker = Unpacker(max_map_len=3)
unpacker.feed(packed)
assert unpacker.unpack() == d
unpacker = Unpacker(max_map_len=2)
with pytest.raises(UnpackValueError):
unpacker.feed(packed)
unpacker.unpack()
def test_max_ext_len():
d = ExtType(42, b"abc")
packed = packb(d)
unpacker = Unpacker(max_ext_len=3)
unpacker.feed(packed)
assert unpacker.unpack() == d
unpacker = Unpacker(max_ext_len=2)
with pytest.raises(UnpackValueError):
unpacker.feed(packed)
unpacker.unpack()
# PyPy fails following tests because of constant folding?
# https://bugs.pypy.org/issue1721
# @pytest.mark.skipif(True, reason="Requires very large memory.")
# def test_binary():
# x = b'x' * (2**32 - 1)
# assert unpackb(packb(x)) == x
# del x
# x = b'x' * (2**32)
# with pytest.raises(ValueError):
# packb(x)
#
#
# @pytest.mark.skipif(True, reason="Requires very large memory.")
# def test_string():
# x = 'x' * (2**32 - 1)
# assert unpackb(packb(x)) == x
# x += 'y'
# with pytest.raises(ValueError):
# packb(x)
#
#
# @pytest.mark.skipif(True, reason="Requires very large memory.")
# def test_array():
# x = [0] * (2**32 - 1)
# assert unpackb(packb(x)) == x
# x.append(0)
# with pytest.raises(ValueError):
# packb(x)
srsly-release-v2.5.1/srsly/tests/msgpack/test_memoryview.py 0000664 0000000 0000000 00000005025 14742310675 0024271 0 ustar 00root root 0000000 0000000 from array import array
from srsly.msgpack import packb, unpackb
make_memoryview = memoryview
def make_array(f, data):
a = array(f)
a.frombytes(data)
return a
def get_data(a):
return a.tobytes()
def _runtest(format, nbytes, expected_header, expected_prefix, use_bin_type):
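    # Descriptive note (added): format is the array typecode ("B" or "f"),
    # nbytes the total payload size in bytes, expected_header the single
    # msgpack type byte the packed memoryview must start with, and
    # expected_prefix the length prefix (empty for fix* types) that follows
    # it; use_bin_type selects bin vs. raw/str packing for bytes-like data.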
# create a new array
original_array = array(format)
original_array.fromlist([255] * (nbytes // original_array.itemsize))
original_data = get_data(original_array)
view = make_memoryview(original_array)
# pack, unpack, and reconstruct array
packed = packb(view, use_bin_type=use_bin_type)
unpacked = unpackb(packed)
reconstructed_array = make_array(format, unpacked)
# check that we got the right amount of data
assert len(original_data) == nbytes
# check packed header
assert packed[:1] == expected_header
# check packed length prefix, if any
assert packed[1 : 1 + len(expected_prefix)] == expected_prefix
# check packed data
assert packed[1 + len(expected_prefix) :] == original_data
# check array unpacked correctly
assert original_array == reconstructed_array
def test_fixstr_from_byte():
_runtest("B", 1, b"\xa1", b"", False)
_runtest("B", 31, b"\xbf", b"", False)
def test_fixstr_from_float():
_runtest("f", 4, b"\xa4", b"", False)
_runtest("f", 28, b"\xbc", b"", False)
def test_str16_from_byte():
_runtest("B", 2 ** 8, b"\xda", b"\x01\x00", False)
_runtest("B", 2 ** 16 - 1, b"\xda", b"\xff\xff", False)
def test_str16_from_float():
_runtest("f", 2 ** 8, b"\xda", b"\x01\x00", False)
_runtest("f", 2 ** 16 - 4, b"\xda", b"\xff\xfc", False)
def test_str32_from_byte():
_runtest("B", 2 ** 16, b"\xdb", b"\x00\x01\x00\x00", False)
def test_str32_from_float():
_runtest("f", 2 ** 16, b"\xdb", b"\x00\x01\x00\x00", False)
def test_bin8_from_byte():
_runtest("B", 1, b"\xc4", b"\x01", True)
_runtest("B", 2 ** 8 - 1, b"\xc4", b"\xff", True)
def test_bin8_from_float():
_runtest("f", 4, b"\xc4", b"\x04", True)
_runtest("f", 2 ** 8 - 4, b"\xc4", b"\xfc", True)
def test_bin16_from_byte():
_runtest("B", 2 ** 8, b"\xc5", b"\x01\x00", True)
_runtest("B", 2 ** 16 - 1, b"\xc5", b"\xff\xff", True)
def test_bin16_from_float():
_runtest("f", 2 ** 8, b"\xc5", b"\x01\x00", True)
_runtest("f", 2 ** 16 - 4, b"\xc5", b"\xff\xfc", True)
def test_bin32_from_byte():
_runtest("B", 2 ** 16, b"\xc6", b"\x00\x01\x00\x00", True)
def test_bin32_from_float():
_runtest("f", 2 ** 16, b"\xc6", b"\x00\x01\x00\x00", True)
srsly-release-v2.5.1/srsly/tests/msgpack/test_newspec.py 0000664 0000000 0000000 00000005054 14742310675 0023534 0 ustar 00root root 0000000 0000000 from srsly.msgpack import packb, unpackb, ExtType
def test_str8():
header = b"\xd9"
data = b"x" * 32
b = packb(data.decode(), use_bin_type=True)
assert len(b) == len(data) + 2
assert b[0:2] == header + b"\x20"
assert b[2:] == data
assert unpackb(b) == data
data = b"x" * 255
b = packb(data.decode(), use_bin_type=True)
assert len(b) == len(data) + 2
assert b[0:2] == header + b"\xff"
assert b[2:] == data
assert unpackb(b) == data
def test_bin8():
header = b"\xc4"
data = b""
b = packb(data, use_bin_type=True)
assert len(b) == len(data) + 2
assert b[0:2] == header + b"\x00"
assert b[2:] == data
assert unpackb(b) == data
data = b"x" * 255
b = packb(data, use_bin_type=True)
assert len(b) == len(data) + 2
assert b[0:2] == header + b"\xff"
assert b[2:] == data
assert unpackb(b) == data
def test_bin16():
header = b"\xc5"
data = b"x" * 256
b = packb(data, use_bin_type=True)
assert len(b) == len(data) + 3
assert b[0:1] == header
assert b[1:3] == b"\x01\x00"
assert b[3:] == data
assert unpackb(b) == data
data = b"x" * 65535
b = packb(data, use_bin_type=True)
assert len(b) == len(data) + 3
assert b[0:1] == header
assert b[1:3] == b"\xff\xff"
assert b[3:] == data
assert unpackb(b) == data
def test_bin32():
header = b"\xc6"
data = b"x" * 65536
b = packb(data, use_bin_type=True)
assert len(b) == len(data) + 5
assert b[0:1] == header
assert b[1:5] == b"\x00\x01\x00\x00"
assert b[5:] == data
assert unpackb(b) == data
def test_ext():
def check(ext, packed):
assert packb(ext) == packed
assert unpackb(packed) == ext
check(ExtType(0x42, b"Z"), b"\xd4\x42Z") # fixext 1
check(ExtType(0x42, b"ZZ"), b"\xd5\x42ZZ") # fixext 2
check(ExtType(0x42, b"Z" * 4), b"\xd6\x42" + b"Z" * 4) # fixext 4
check(ExtType(0x42, b"Z" * 8), b"\xd7\x42" + b"Z" * 8) # fixext 8
check(ExtType(0x42, b"Z" * 16), b"\xd8\x42" + b"Z" * 16) # fixext 16
# ext 8
check(ExtType(0x42, b""), b"\xc7\x00\x42")
check(ExtType(0x42, b"Z" * 255), b"\xc7\xff\x42" + b"Z" * 255)
# ext 16
check(ExtType(0x42, b"Z" * 256), b"\xc8\x01\x00\x42" + b"Z" * 256)
check(ExtType(0x42, b"Z" * 0xFFFF), b"\xc8\xff\xff\x42" + b"Z" * 0xFFFF)
# ext 32
check(ExtType(0x42, b"Z" * 0x10000), b"\xc9\x00\x01\x00\x00\x42" + b"Z" * 0x10000)
# needs large memory
# check(ExtType(0x42, b'Z'*0xffffffff),
# b'\xc9\xff\xff\xff\xff\x42' + b'Z'*0xffffffff)
srsly-release-v2.5.1/srsly/tests/msgpack/test_numpy.py 0000664 0000000 0000000 00000020641 14742310675 0023237 0 ustar 00root root 0000000 0000000 from unittest import TestCase
from numpy.testing import assert_equal, assert_array_equal
import numpy as np
from srsly import msgpack
class ThirdParty(object):
def __init__(self, foo=b"bar"):
self.foo = foo
def __eq__(self, other):
return isinstance(other, ThirdParty) and self.foo == other.foo
class test_numpy_msgpack(TestCase):
def encode_decode(self, x, use_bin_type=False, raw=True):
x_enc = msgpack.packb(x, use_bin_type=use_bin_type)
return msgpack.unpackb(x_enc, raw=raw)
def encode_thirdparty(self, obj):
return dict(__thirdparty__=True, foo=obj.foo)
def decode_thirdparty(self, obj):
if b"__thirdparty__" in obj:
return ThirdParty(foo=obj[b"foo"])
return obj
def encode_decode_thirdparty(self, x, use_bin_type=False, raw=True):
x_enc = msgpack.packb(
x, default=self.encode_thirdparty, use_bin_type=use_bin_type
)
return msgpack.unpackb(x_enc, raw=raw, object_hook=self.decode_thirdparty)
def test_bin(self):
# Since bytes == str in Python 2.7, the following
# should pass on both 2.7 and 3.*
assert_equal(type(self.encode_decode(b"foo")), bytes)
def test_str(self):
assert_equal(type(self.encode_decode("foo")), bytes)
def test_numpy_scalar_bool(self):
x = np.bool_(True)
x_rec = self.encode_decode(x)
assert_equal(x, x_rec)
assert_equal(type(x), type(x_rec))
x = np.bool_(False)
x_rec = self.encode_decode(x)
assert_equal(x, x_rec)
assert_equal(type(x), type(x_rec))
def test_numpy_scalar_float(self):
x = np.float32(np.random.rand())
x_rec = self.encode_decode(x)
assert_equal(x, x_rec)
assert_equal(type(x), type(x_rec))
def test_numpy_scalar_complex(self):
x = np.complex64(np.random.rand() + 1j * np.random.rand())
x_rec = self.encode_decode(x)
assert_equal(x, x_rec)
assert_equal(type(x), type(x_rec))
def test_scalar_float(self):
x = np.random.rand()
x_rec = self.encode_decode(x)
assert_equal(x, x_rec)
assert_equal(type(x), type(x_rec))
def test_scalar_complex(self):
x = np.random.rand() + 1j * np.random.rand()
x_rec = self.encode_decode(x)
assert_equal(x, x_rec)
assert_equal(type(x), type(x_rec))
def test_list_numpy_float(self):
x = [np.float32(np.random.rand()) for i in range(5)]
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_array_equal([type(e) for e in x], [type(e) for e in x_rec])
def test_list_numpy_float_complex(self):
x = [np.float32(np.random.rand()) for i in range(5)] + [
np.complex128(np.random.rand() + 1j * np.random.rand()) for i in range(5)
]
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_array_equal([type(e) for e in x], [type(e) for e in x_rec])
def test_list_float(self):
x = [np.random.rand() for i in range(5)]
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_array_equal([type(e) for e in x], [type(e) for e in x_rec])
def test_list_float_complex(self):
x = [(np.random.rand() + 1j * np.random.rand()) for i in range(5)]
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_array_equal([type(e) for e in x], [type(e) for e in x_rec])
def test_list_str(self):
x = [b"x" * i for i in range(5)]
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_array_equal([type(e) for e in x_rec], [bytes] * 5)
def test_dict_float(self):
x = {b"foo": 1.0, b"bar": 2.0}
x_rec = self.encode_decode(x)
assert_array_equal(sorted(x.values()), sorted(x_rec.values()))
assert_array_equal(
[type(e) for e in sorted(x.values())],
[type(e) for e in sorted(x_rec.values())],
)
assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
assert_array_equal(
[type(e) for e in sorted(x.keys())], [type(e) for e in sorted(x_rec.keys())]
)
def test_dict_complex(self):
x = {b"foo": 1.0 + 1.0j, b"bar": 2.0 + 2.0j}
x_rec = self.encode_decode(x)
assert_array_equal(
sorted(x.values(), key=np.linalg.norm),
sorted(x_rec.values(), key=np.linalg.norm),
)
assert_array_equal(
[type(e) for e in sorted(x.values(), key=np.linalg.norm)],
[type(e) for e in sorted(x_rec.values(), key=np.linalg.norm)],
)
assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
assert_array_equal(
[type(e) for e in sorted(x.keys())], [type(e) for e in sorted(x_rec.keys())]
)
def test_dict_str(self):
x = {b"foo": b"xxx", b"bar": b"yyyy"}
x_rec = self.encode_decode(x)
assert_array_equal(sorted(x.values()), sorted(x_rec.values()))
assert_array_equal(
[type(e) for e in sorted(x.values())],
[type(e) for e in sorted(x_rec.values())],
)
assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
assert_array_equal(
[type(e) for e in sorted(x.keys())], [type(e) for e in sorted(x_rec.keys())]
)
def test_dict_numpy_float(self):
x = {b"foo": np.float32(1.0), b"bar": np.float32(2.0)}
x_rec = self.encode_decode(x)
assert_array_equal(sorted(x.values()), sorted(x_rec.values()))
assert_array_equal(
[type(e) for e in sorted(x.values())],
[type(e) for e in sorted(x_rec.values())],
)
assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
assert_array_equal(
[type(e) for e in sorted(x.keys())], [type(e) for e in sorted(x_rec.keys())]
)
def test_dict_numpy_complex(self):
x = {b"foo": np.complex128(1.0 + 1.0j), b"bar": np.complex128(2.0 + 2.0j)}
x_rec = self.encode_decode(x)
assert_array_equal(
sorted(x.values(), key=np.linalg.norm),
sorted(x_rec.values(), key=np.linalg.norm),
)
assert_array_equal(
[type(e) for e in sorted(x.values(), key=np.linalg.norm)],
[type(e) for e in sorted(x_rec.values(), key=np.linalg.norm)],
)
assert_array_equal(sorted(x.keys()), sorted(x_rec.keys()))
assert_array_equal(
[type(e) for e in sorted(x.keys())], [type(e) for e in sorted(x_rec.keys())]
)
def test_numpy_array_float(self):
x = np.random.rand(5).astype(np.float32)
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_equal(x.dtype, x_rec.dtype)
def test_numpy_array_complex(self):
x = (np.random.rand(5) + 1j * np.random.rand(5)).astype(np.complex128)
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_equal(x.dtype, x_rec.dtype)
def test_numpy_array_float_2d(self):
x = np.random.rand(5, 5).astype(np.float32)
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_equal(x.dtype, x_rec.dtype)
def test_numpy_array_str(self):
x = np.array([b"aaa", b"bbbb", b"ccccc"])
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_equal(x.dtype, x_rec.dtype)
def test_numpy_array_mixed(self):
x = np.array(
[(1, 2, b"a", [1.0, 2.0])],
np.dtype(
[
("arg0", np.uint32),
("arg1", np.uint32),
("arg2", "S1"),
("arg3", np.float32, (2,)),
]
),
)
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_equal(x.dtype, x_rec.dtype)
def test_numpy_array_noncontiguous(self):
x = np.ones((10, 10), np.uint32)[0:5, 0:5]
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_equal(x.dtype, x_rec.dtype)
def test_list_mixed(self):
x = [1.0, np.float32(3.5), np.complex128(4.25), b"foo"]
x_rec = self.encode_decode(x)
assert_array_equal(x, x_rec)
assert_array_equal([type(e) for e in x], [type(e) for e in x_rec])
def test_chain(self):
x = ThirdParty(foo=b"test marshal/unmarshal")
x_rec = self.encode_decode_thirdparty(x)
self.assertEqual(x, x_rec)
srsly-release-v2.5.1/srsly/tests/msgpack/test_pack.py 0000664 0000000 0000000 00000011717 14742310675 0023011 0 ustar 00root root 0000000 0000000 import struct
import pytest
from collections import OrderedDict
from io import BytesIO
from srsly.msgpack import packb, unpackb, Unpacker, Packer
def check(data, use_list=False):
re = unpackb(packb(data), use_list=use_list)
assert re == data
def testPack():
test_data = [
0,
1,
127,
128,
255,
256,
65535,
65536,
4294967295,
4294967296,
-1,
-32,
-33,
-128,
-129,
-32768,
-32769,
-4294967296,
-4294967297,
1.0,
b"",
b"a",
b"a" * 31,
b"a" * 32,
None,
True,
False,
(),
((),),
((), None),
{None: 0},
(1 << 23),
]
for td in test_data:
check(td)
def testPackUnicode():
test_data = ["", "abcd", ["defgh"], "Русский текст"]
for td in test_data:
re = unpackb(packb(td), use_list=1, raw=False)
assert re == td
packer = Packer()
data = packer.pack(td)
re = Unpacker(BytesIO(data), raw=False, use_list=1).unpack()
assert re == td
def testPackUTF32(): # deprecated
re = unpackb(packb("", encoding="utf-32"), use_list=1, encoding="utf-32")
assert re == ""
re = unpackb(packb("abcd", encoding="utf-32"), use_list=1, encoding="utf-32")
assert re == "abcd"
re = unpackb(packb(["defgh"], encoding="utf-32"), use_list=1, encoding="utf-32")
assert re == ["defgh"]
try:
packb("Русский текст", encoding="utf-32")
except LookupError as e:
pytest.xfail(str(e))
# try:
# test_data = ["", "abcd", ["defgh"], "Русский текст"]
# for td in test_data:
# except LookupError as e:
# pytest.xfail(e)
def testPackBytes():
test_data = [b"", b"abcd", (b"defgh",)]
for td in test_data:
check(td)
def testPackByteArrays():
test_data = [bytearray(b""), bytearray(b"abcd"), (bytearray(b"defgh"),)]
for td in test_data:
check(td)
def testIgnoreUnicodeErrors(): # deprecated
re = unpackb(
packb(b"abc\xeddef"), encoding="utf-8", unicode_errors="ignore", use_list=1
)
assert re == "abcdef"
def testStrictUnicodeUnpack():
with pytest.raises(UnicodeDecodeError):
unpackb(packb(b"abc\xeddef"), raw=False, use_list=1)
def testStrictUnicodePack(): # deprecated
with pytest.raises(UnicodeEncodeError):
packb("abc\xeddef", encoding="ascii", unicode_errors="strict")
def testIgnoreErrorsPack(): # deprecated
re = unpackb(
packb("abcФФФdef", encoding="ascii", unicode_errors="ignore"),
raw=False,
use_list=1,
)
assert re == "abcdef"
def testDecodeBinary():
re = unpackb(packb(b"abc"), encoding=None, use_list=1)
assert re == b"abc"
def testPackFloat():
assert packb(1.0, use_single_float=True) == b"\xca" + struct.pack(str(">f"), 1.0)
assert packb(1.0, use_single_float=False) == b"\xcb" + struct.pack(str(">d"), 1.0)
def testArraySize(sizes=[0, 5, 50, 1000]):
bio = BytesIO()
packer = Packer()
for size in sizes:
bio.write(packer.pack_array_header(size))
for i in range(size):
bio.write(packer.pack(i))
bio.seek(0)
unpacker = Unpacker(bio, use_list=1)
for size in sizes:
assert unpacker.unpack() == list(range(size))
def test_manualreset(sizes=[0, 5, 50, 1000]):
packer = Packer(autoreset=False)
for size in sizes:
packer.pack_array_header(size)
for i in range(size):
packer.pack(i)
bio = BytesIO(packer.bytes())
unpacker = Unpacker(bio, use_list=1)
for size in sizes:
assert unpacker.unpack() == list(range(size))
packer.reset()
assert packer.bytes() == b""
def testMapSize(sizes=[0, 5, 50, 1000]):
bio = BytesIO()
packer = Packer()
for size in sizes:
bio.write(packer.pack_map_header(size))
for i in range(size):
bio.write(packer.pack(i)) # key
bio.write(packer.pack(i * 2)) # value
bio.seek(0)
unpacker = Unpacker(bio)
for size in sizes:
assert unpacker.unpack() == dict((i, i * 2) for i in range(size))
def test_odict():
seq = [(b"one", 1), (b"two", 2), (b"three", 3), (b"four", 4)]
od = OrderedDict(seq)
assert unpackb(packb(od), use_list=1) == dict(seq)
def pair_hook(seq):
return list(seq)
assert unpackb(packb(od), object_pairs_hook=pair_hook, use_list=1) == seq
def test_pairlist():
pairlist = [(b"a", 1), (2, b"b"), (b"foo", b"bar")]
packer = Packer()
packed = packer.pack_map_pairs(pairlist)
unpacked = unpackb(packed, object_pairs_hook=list)
assert pairlist == unpacked
def test_get_buffer():
packer = Packer(autoreset=0, use_bin_type=True)
packer.pack([1, 2])
strm = BytesIO()
strm.write(packer.getbuffer())
written = strm.getvalue()
expected = packb([1, 2], use_bin_type=True)
assert written == expected
srsly-release-v2.5.1/srsly/tests/msgpack/test_read_size.py 0000664 0000000 0000000 00000003410 14742310675 0024027 0 ustar 00root root 0000000 0000000 from srsly.msgpack import packb, Unpacker, OutOfData
UnexpectedTypeException = ValueError
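# Both read_array_header() and read_map_header() raise a plain ValueError when
# the next object in the stream is not the expected container type, hence the
# alias above.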
def test_read_array_header():
unpacker = Unpacker()
unpacker.feed(packb(["a", "b", "c"]))
assert unpacker.read_array_header() == 3
assert unpacker.unpack() == b"a"
assert unpacker.unpack() == b"b"
assert unpacker.unpack() == b"c"
try:
unpacker.unpack()
assert 0, "should raise exception"
except OutOfData:
assert 1, "okay"
def test_read_map_header():
unpacker = Unpacker()
unpacker.feed(packb({"a": "A"}))
assert unpacker.read_map_header() == 1
assert unpacker.unpack() == b"a"
assert unpacker.unpack() == b"A"
try:
unpacker.unpack()
assert 0, "should raise exception"
except OutOfData:
assert 1, "okay"
def test_incorrect_type_array():
unpacker = Unpacker()
unpacker.feed(packb(1))
try:
unpacker.read_array_header()
assert 0, "should raise exception"
except UnexpectedTypeException:
assert 1, "okay"
def test_incorrect_type_map():
unpacker = Unpacker()
unpacker.feed(packb(1))
try:
unpacker.read_map_header()
assert 0, "should raise exception"
except UnexpectedTypeException:
assert 1, "okay"
def test_correct_type_nested_array():
unpacker = Unpacker()
unpacker.feed(packb({"a": ["b", "c", "d"]}))
try:
unpacker.read_array_header()
assert 0, "should raise exception"
except UnexpectedTypeException:
assert 1, "okay"
def test_incorrect_type_nested_map():
unpacker = Unpacker()
unpacker.feed(packb([{"a": "b"}]))
try:
unpacker.read_map_header()
assert 0, "should raise exception"
except UnexpectedTypeException:
assert 1, "okay"
srsly-release-v2.5.1/srsly/tests/msgpack/test_seq.py 0000664 0000000 0000000 00000002145 14742310675 0022656 0 ustar 00root root 0000000 0000000 import io
from srsly import msgpack
binarydata = bytes(bytearray(range(256)))
def gen_binary_data(idx):
return binarydata[: idx % 300]
def test_exceeding_unpacker_read_size():
dumpf = io.BytesIO()
packer = msgpack.Packer()
NUMBER_OF_STRINGS = 6
read_size = 16
# 5 ok for read_size=16, while 6 glibc detected *** python: double free or corruption (fasttop):
# 20 ok for read_size=256, while 25 segfaults / glibc detected *** python: double free or corruption (!prev)
# 40 ok for read_size=1024, while 50 introduces errors
# 7000 ok for read_size=1024*1024, while 8000 leads to glibc detected *** python: double free or corruption (!prev):
for idx in range(NUMBER_OF_STRINGS):
data = gen_binary_data(idx)
dumpf.write(packer.pack(data))
f = io.BytesIO(dumpf.getvalue())
dumpf.close()
unpacker = msgpack.Unpacker(f, read_size=read_size, use_list=1)
read_count = 0
for idx, o in enumerate(unpacker):
assert type(o) == bytes
assert o == gen_binary_data(idx)
read_count += 1
assert read_count == NUMBER_OF_STRINGS
srsly-release-v2.5.1/srsly/tests/msgpack/test_sequnpack.py 0000664 0000000 0000000 00000007060 14742310675 0024061 0 ustar 00root root 0000000 0000000 import io
import pytest
from srsly.msgpack import Unpacker, BufferFull
from srsly.msgpack import pack
from srsly.msgpack.exceptions import OutOfData
def test_partialdata():
unpacker = Unpacker()
unpacker.feed(b"\xa5")
with pytest.raises(StopIteration):
next(iter(unpacker))
unpacker.feed(b"h")
with pytest.raises(StopIteration):
next(iter(unpacker))
unpacker.feed(b"a")
with pytest.raises(StopIteration):
next(iter(unpacker))
unpacker.feed(b"l")
with pytest.raises(StopIteration):
next(iter(unpacker))
unpacker.feed(b"l")
with pytest.raises(StopIteration):
next(iter(unpacker))
unpacker.feed(b"o")
assert next(iter(unpacker)) == b"hallo"
def test_foobar():
unpacker = Unpacker(read_size=3, use_list=1)
unpacker.feed(b"foobar")
assert unpacker.unpack() == ord(b"f")
assert unpacker.unpack() == ord(b"o")
assert unpacker.unpack() == ord(b"o")
assert unpacker.unpack() == ord(b"b")
assert unpacker.unpack() == ord(b"a")
assert unpacker.unpack() == ord(b"r")
with pytest.raises(OutOfData):
unpacker.unpack()
unpacker.feed(b"foo")
unpacker.feed(b"bar")
k = 0
for o, e in zip(unpacker, "foobarbaz"):
assert o == ord(e)
k += 1
assert k == len(b"foobar")
def test_foobar_skip():
unpacker = Unpacker(read_size=3, use_list=1)
unpacker.feed(b"foobar")
assert unpacker.unpack() == ord(b"f")
unpacker.skip()
assert unpacker.unpack() == ord(b"o")
unpacker.skip()
assert unpacker.unpack() == ord(b"a")
unpacker.skip()
with pytest.raises(OutOfData):
unpacker.unpack()
def test_maxbuffersize():
with pytest.raises(ValueError):
Unpacker(read_size=5, max_buffer_size=3)
unpacker = Unpacker(read_size=3, max_buffer_size=3, use_list=1)
unpacker.feed(b"fo")
with pytest.raises(BufferFull):
unpacker.feed(b"ob")
unpacker.feed(b"o")
assert ord("f") == next(unpacker)
unpacker.feed(b"b")
assert ord("o") == next(unpacker)
assert ord("o") == next(unpacker)
assert ord("b") == next(unpacker)
def test_readbytes():
unpacker = Unpacker(read_size=3)
unpacker.feed(b"foobar")
assert unpacker.unpack() == ord(b"f")
assert unpacker.read_bytes(3) == b"oob"
assert unpacker.unpack() == ord(b"a")
assert unpacker.unpack() == ord(b"r")
# Test buffer refill
unpacker = Unpacker(io.BytesIO(b"foobar"), read_size=3)
assert unpacker.unpack() == ord(b"f")
assert unpacker.read_bytes(3) == b"oob"
assert unpacker.unpack() == ord(b"a")
assert unpacker.unpack() == ord(b"r")
def test_issue124():
unpacker = Unpacker()
unpacker.feed(b"\xa1?\xa1!")
assert tuple(unpacker) == (b"?", b"!")
assert tuple(unpacker) == ()
unpacker.feed(b"\xa1?\xa1")
assert tuple(unpacker) == (b"?",)
assert tuple(unpacker) == ()
unpacker.feed(b"!")
assert tuple(unpacker) == (b"!",)
assert tuple(unpacker) == ()
def test_unpack_tell():
stream = io.BytesIO()
messages = [2 ** i - 1 for i in range(65)]
messages += [-(2 ** i) for i in range(1, 64)]
messages += [
b"hello",
b"hello" * 1000,
list(range(20)),
{i: bytes(i) * i for i in range(10)},
{i: bytes(i) * i for i in range(32)},
]
offsets = []
for m in messages:
pack(m, stream)
offsets.append(stream.tell())
stream.seek(0)
unpacker = Unpacker(stream)
for m, o in zip(messages, offsets):
m2 = next(unpacker)
assert m == m2
assert o == unpacker.tell()
srsly-release-v2.5.1/srsly/tests/msgpack/test_stricttype.py 0000664 0000000 0000000 00000003365 14742310675 0024305 0 ustar 00root root 0000000 0000000 from collections import namedtuple
from srsly.msgpack import packb, unpackb, ExtType
def test_namedtuple():
T = namedtuple("T", "foo bar")
def default(o):
if isinstance(o, T):
return dict(o._asdict())
raise TypeError("Unsupported type %s" % (type(o),))
packed = packb(T(1, 42), strict_types=True, use_bin_type=True, default=default)
unpacked = unpackb(packed, raw=False)
assert unpacked == {"foo": 1, "bar": 42}
def test_tuple():
t = ("one", 2, b"three", (4,))
def default(o):
if isinstance(o, tuple):
return {"__type__": "tuple", "value": list(o)}
raise TypeError("Unsupported type %s" % (type(o),))
def convert(o):
if o.get("__type__") == "tuple":
return tuple(o["value"])
return o
data = packb(t, strict_types=True, use_bin_type=True, default=default)
expected = unpackb(data, raw=False, object_hook=convert)
assert expected == t
def test_tuple_ext():
t = ("one", 2, b"three", (4,))
MSGPACK_EXT_TYPE_TUPLE = 0
def default(o):
if isinstance(o, tuple):
# Convert to list and pack
payload = packb(
list(o), strict_types=True, use_bin_type=True, default=default
)
return ExtType(MSGPACK_EXT_TYPE_TUPLE, payload)
raise TypeError(repr(o))
def convert(code, payload):
if code == MSGPACK_EXT_TYPE_TUPLE:
# Unpack and convert to tuple
return tuple(unpackb(payload, raw=False, ext_hook=convert))
raise ValueError("Unknown Ext code {}".format(code))
data = packb(t, strict_types=True, use_bin_type=True, default=default)
expected = unpackb(data, raw=False, ext_hook=convert)
assert expected == t
srsly-release-v2.5.1/srsly/tests/msgpack/test_subtype.py 0000664 0000000 0000000 00000000567 14742310675 0023567 0 ustar 00root root 0000000 0000000 from collections import namedtuple
from srsly.msgpack import packb
class MyList(list):
pass
class MyDict(dict):
pass
class MyTuple(tuple):
pass
MyNamedTuple = namedtuple("MyNamedTuple", "x y")
def test_types():
assert packb(MyDict()) == packb(dict())
assert packb(MyList()) == packb(list())
assert packb(MyNamedTuple(1, 2)) == packb((1, 2))
srsly-release-v2.5.1/srsly/tests/msgpack/test_unpack.py 0000664 0000000 0000000 00000003434 14742310675 0023351 0 ustar 00root root 0000000 0000000 from io import BytesIO
import sys
import pytest
from srsly.msgpack import Unpacker, packb, OutOfData, ExtType
def test_unpack_array_header_from_file():
f = BytesIO(packb([1, 2, 3, 4]))
unpacker = Unpacker(f)
assert unpacker.read_array_header() == 4
assert unpacker.unpack() == 1
assert unpacker.unpack() == 2
assert unpacker.unpack() == 3
assert unpacker.unpack() == 4
with pytest.raises(OutOfData):
unpacker.unpack()
@pytest.mark.skipif(
"not hasattr(sys, 'getrefcount') == True",
reason="sys.getrefcount() is needed to pass this test",
)
def test_unpacker_hook_refcnt():
result = []
def hook(x):
result.append(x)
return x
basecnt = sys.getrefcount(hook)
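    # The Unpacker holds references to object_hook and list_hook while alive;
    # after it is deleted, the hook's refcount should drop back to the baseline.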
up = Unpacker(object_hook=hook, list_hook=hook)
assert sys.getrefcount(hook) >= basecnt + 2
up.feed(packb([{}]))
up.feed(packb([{}]))
assert up.unpack() == [{}]
assert up.unpack() == [{}]
assert result == [{}, [{}], {}, [{}]]
del up
assert sys.getrefcount(hook) == basecnt
def test_unpacker_ext_hook():
class MyUnpacker(Unpacker):
def __init__(self):
super(MyUnpacker, self).__init__(ext_hook=self._hook, raw=False)
def _hook(self, code, data):
if code == 1:
return int(data)
else:
return ExtType(code, data)
unpacker = MyUnpacker()
unpacker.feed(packb({"a": 1}))
assert unpacker.unpack() == {"a": 1}
unpacker.feed(packb({"a": ExtType(1, b"123")}))
assert unpacker.unpack() == {"a": 123}
unpacker.feed(packb({"a": ExtType(2, b"321")}))
assert unpacker.unpack() == {"a": ExtType(2, b"321")}
if __name__ == "__main__":
test_unpack_array_header_from_file()
test_unpacker_hook_refcnt()
test_unpacker_ext_hook()
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/ 0000775 0000000 0000000 00000000000 14742310675 0021335 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/ruamel_yaml/__init__.py 0000664 0000000 0000000 00000000000 14742310675 0023434 0 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/ruamel_yaml/roundtrip.py 0000775 0000000 0000000 00000022422 14742310675 0023742 0 ustar 00root root 0000000 0000000 """
helper routines for testing round trip of commented YAML data
"""
import sys
import textwrap
from pathlib import Path
enforce = object()
def dedent(data):
try:
position_of_first_newline = data.index("\n")
for idx in range(position_of_first_newline):
if not data[idx].isspace():
raise ValueError
except ValueError:
pass
else:
data = data[position_of_first_newline + 1 :]
return textwrap.dedent(data)
def round_trip_load(inp, preserve_quotes=None, version=None):
import srsly.ruamel_yaml # NOQA
dinp = dedent(inp)
return srsly.ruamel_yaml.load(
dinp,
Loader=srsly.ruamel_yaml.RoundTripLoader,
preserve_quotes=preserve_quotes,
version=version,
)
def round_trip_load_all(inp, preserve_quotes=None, version=None):
import srsly.ruamel_yaml # NOQA
dinp = dedent(inp)
return srsly.ruamel_yaml.load_all(
dinp,
Loader=srsly.ruamel_yaml.RoundTripLoader,
preserve_quotes=preserve_quotes,
version=version,
)
def round_trip_dump(
data,
stream=None,
indent=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
explicit_start=None,
explicit_end=None,
version=None,
):
import srsly.ruamel_yaml # NOQA
return srsly.ruamel_yaml.round_trip_dump(
data,
stream=stream,
indent=indent,
block_seq_indent=block_seq_indent,
top_level_colon_align=top_level_colon_align,
prefix_colon=prefix_colon,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
)
def diff(inp, outp, file_name="stdin"):
import difflib
inl = inp.splitlines(True) # True for keepends
outl = outp.splitlines(True)
diff = difflib.unified_diff(inl, outl, file_name, "round trip YAML")
# 2.6 difflib has trailing space on filename lines %-)
strip_trailing_space = sys.version_info < (2, 7)
for line in diff:
if strip_trailing_space and line[:4] in ["--- ", "+++ "]:
line = line.rstrip() + "\n"
sys.stdout.write(line)
def round_trip(
inp,
outp=None,
extra=None,
intermediate=None,
indent=None,
block_seq_indent=None,
top_level_colon_align=None,
prefix_colon=None,
preserve_quotes=None,
explicit_start=None,
explicit_end=None,
version=None,
dump_data=None,
):
"""
inp: input string to parse
outp: expected output (equals input if not specified)
"""
if outp is None:
outp = inp
doutp = dedent(outp)
if extra is not None:
doutp += extra
data = round_trip_load(inp, preserve_quotes=preserve_quotes)
if dump_data:
print("data", data)
if intermediate is not None:
if isinstance(intermediate, dict):
for k, v in intermediate.items():
if data[k] != v:
print("{0!r} <> {1!r}".format(data[k], v))
raise ValueError
res = round_trip_dump(
data,
indent=indent,
block_seq_indent=block_seq_indent,
top_level_colon_align=top_level_colon_align,
prefix_colon=prefix_colon,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
)
if res != doutp:
diff(doutp, res, "input string")
print("\nroundtrip data:\n", res, sep="")
assert res == doutp
res = round_trip_dump(
data,
indent=indent,
block_seq_indent=block_seq_indent,
top_level_colon_align=top_level_colon_align,
prefix_colon=prefix_colon,
explicit_start=explicit_start,
explicit_end=explicit_end,
version=version,
)
print("roundtrip second round data:\n", res, sep="")
assert res == doutp
return data
def na_round_trip(
inp,
outp=None,
extra=None,
intermediate=None,
indent=None,
top_level_colon_align=None,
prefix_colon=None,
preserve_quotes=None,
explicit_start=None,
explicit_end=None,
version=None,
dump_data=None,
):
"""
inp: input string to parse
outp: expected output (equals input if not specified)
"""
inp = dedent(inp)
if outp is None:
outp = inp
if version is not None:
version = version
doutp = dedent(outp)
if extra is not None:
doutp += extra
yaml = YAML()
yaml.preserve_quotes = preserve_quotes
yaml.scalar_after_indicator = False # newline after every directives end
data = yaml.load(inp)
if dump_data:
print("data", data)
if intermediate is not None:
if isinstance(intermediate, dict):
for k, v in intermediate.items():
if data[k] != v:
print("{0!r} <> {1!r}".format(data[k], v))
raise ValueError
yaml.indent = indent
yaml.top_level_colon_align = top_level_colon_align
yaml.prefix_colon = prefix_colon
yaml.explicit_start = explicit_start
yaml.explicit_end = explicit_end
res = yaml.dump(data, compare=doutp)
return res
def YAML(**kw):
import srsly.ruamel_yaml # NOQA
class MyYAML(srsly.ruamel_yaml.YAML):
"""auto dedent string parameters on load"""
def load(self, stream):
if isinstance(stream, str):
if stream and stream[0] == "\n":
stream = stream[1:]
stream = textwrap.dedent(stream)
return srsly.ruamel_yaml.YAML.load(self, stream)
def load_all(self, stream):
if isinstance(stream, str):
if stream and stream[0] == "\n":
stream = stream[1:]
stream = textwrap.dedent(stream)
for d in srsly.ruamel_yaml.YAML.load_all(self, stream):
yield d
def dump(self, data, **kw):
from srsly.ruamel_yaml.compat import StringIO, BytesIO # NOQA
assert ("stream" in kw) ^ ("compare" in kw)
if "stream" in kw:
                return srsly.ruamel_yaml.YAML.dump(self, data, **kw)
lkw = kw.copy()
expected = textwrap.dedent(lkw.pop("compare"))
unordered_lines = lkw.pop("unordered_lines", False)
if expected and expected[0] == "\n":
expected = expected[1:]
lkw["stream"] = st = StringIO()
srsly.ruamel_yaml.YAML.dump(self, data, **lkw)
res = st.getvalue()
print(res)
if unordered_lines:
res = sorted(res.splitlines())
expected = sorted(expected.splitlines())
assert res == expected
def round_trip(self, stream, **kw):
from srsly.ruamel_yaml.compat import StringIO, BytesIO # NOQA
assert isinstance(stream, (srsly.ruamel_yaml.compat.text_type, str))
lkw = kw.copy()
if stream and stream[0] == "\n":
stream = stream[1:]
stream = textwrap.dedent(stream)
data = srsly.ruamel_yaml.YAML.load(self, stream)
outp = lkw.pop("outp", stream)
lkw["stream"] = st = StringIO()
srsly.ruamel_yaml.YAML.dump(self, data, **lkw)
res = st.getvalue()
if res != outp:
diff(outp, res, "input string")
assert res == outp
def round_trip_all(self, stream, **kw):
from srsly.ruamel_yaml.compat import StringIO, BytesIO # NOQA
assert isinstance(stream, (srsly.ruamel_yaml.compat.text_type, str))
lkw = kw.copy()
if stream and stream[0] == "\n":
stream = stream[1:]
stream = textwrap.dedent(stream)
data = list(srsly.ruamel_yaml.YAML.load_all(self, stream))
outp = lkw.pop("outp", stream)
lkw["stream"] = st = StringIO()
srsly.ruamel_yaml.YAML.dump_all(self, data, **lkw)
res = st.getvalue()
if res != outp:
diff(outp, res, "input string")
assert res == outp
return MyYAML(**kw)
def save_and_run(program, base_dir=None, output=None, file_name=None, optimized=False):
"""
    save and run a python program, thereby circumventing any restrictions on module-level
    imports
"""
from subprocess import check_output, STDOUT, CalledProcessError
if not hasattr(base_dir, "hash"):
base_dir = Path(str(base_dir))
if file_name is None:
file_name = "safe_and_run_tmp.py"
file_name = base_dir / file_name
file_name.write_text(dedent(program))
try:
cmd = [sys.executable]
if optimized:
cmd.append("-O")
cmd.append(str(file_name))
print("running:", *cmd)
res = check_output(cmd, stderr=STDOUT, universal_newlines=True)
if output is not None:
if "__pypy__" in sys.builtin_module_names:
res = res.splitlines(True)
res = [line for line in res if "no version info" not in line]
res = "".join(res)
print("result: ", res, end="")
print("expected:", output, end="")
assert res == output
except CalledProcessError as exception:
print("##### Running '{} {}' FAILED #####".format(sys.executable, file_name))
print(exception.output)
return exception.returncode
return 0
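# --- Illustration (added for exposition; not part of the original helpers) ---
# A minimal sketch of how the helpers above are typically driven: dedent() a
# triple-quoted fixture, then round_trip() loads it with the round-trip loader,
# dumps it twice, and asserts both dumps match the dedented input. The __main__
# guard keeps test modules that import this file free of side effects.
if __name__ == "__main__":  # pragma: no cover - illustration only
    sample = """
    first: 1  # a comment that should survive the round trip
    second: two
    """
    assert dedent(sample) == "first: 1  # a comment that should survive the round trip\nsecond: two\n"
    round_trip(sample)  # asserts that dumping the loaded data reproduces the input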
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_a_dedent.py 0000775 0000000 0000000 00000002145 14742310675 0024516 0 ustar 00root root 0000000 0000000 from .roundtrip import dedent
class TestDedent:
def test_start_newline(self):
# fmt: off
x = dedent("""
123
456
""")
# fmt: on
assert x == "123\n 456\n"
def test_start_space_newline(self):
# special construct to prevent stripping of following whitespace
# fmt: off
x = dedent(" " """
123
""")
# fmt: on
assert x == "123\n"
def test_start_no_newline(self):
        # special construct to prevent stripping of following whitespace
x = dedent(
"""\
123
456
"""
)
assert x == "123\n 456\n"
def test_preserve_no_newline_at_end(self):
x = dedent(
"""
123"""
)
assert x == "123"
def test_preserve_no_newline_at_all(self):
x = dedent(
"""\
123"""
)
assert x == "123"
def test_multiple_dedent(self):
x = dedent(
dedent(
"""
123
"""
)
)
assert x == "123\n"
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_add_xxx.py 0000775 0000000 0000000 00000012343 14742310675 0024413 0 ustar 00root root 0000000 0000000 # coding: utf-8
import re
import pytest # NOQA
from .roundtrip import dedent
# from PyYAML docs
class Dice(tuple):
def __new__(cls, a, b):
return tuple.__new__(cls, [a, b])
def __repr__(self):
return "Dice(%s,%s)" % self
def dice_constructor(loader, node):
value = loader.construct_scalar(node)
a, b = map(int, value.split("d"))
return Dice(a, b)
def dice_representer(dumper, data):
return dumper.represent_scalar(u"!dice", u"{}d{}".format(*data))
def test_dice_constructor():
import srsly.ruamel_yaml # NOQA
srsly.ruamel_yaml.add_constructor(u"!dice", dice_constructor)
with pytest.raises(ValueError):
data = srsly.ruamel_yaml.load(
"initial hit points: !dice 8d4", Loader=srsly.ruamel_yaml.Loader
)
assert str(data) == "{'initial hit points': Dice(8,4)}"
def test_dice_constructor_with_loader():
import srsly.ruamel_yaml # NOQA
with pytest.raises(ValueError):
srsly.ruamel_yaml.add_constructor(
u"!dice", dice_constructor, Loader=srsly.ruamel_yaml.Loader
)
data = srsly.ruamel_yaml.load(
"initial hit points: !dice 8d4", Loader=srsly.ruamel_yaml.Loader
)
assert str(data) == "{'initial hit points': Dice(8,4)}"
def test_dice_representer():
import srsly.ruamel_yaml # NOQA
srsly.ruamel_yaml.add_representer(Dice, dice_representer)
# srsly.ruamel_yaml 0.15.8+ no longer forces quotes tagged scalars
assert (
srsly.ruamel_yaml.dump(dict(gold=Dice(10, 6)), default_flow_style=False)
== "gold: !dice 10d6\n"
)
def test_dice_implicit_resolver():
import srsly.ruamel_yaml # NOQA
pattern = re.compile(r"^\d+d\d+$")
with pytest.raises(ValueError):
srsly.ruamel_yaml.add_implicit_resolver(u"!dice", pattern)
assert (
srsly.ruamel_yaml.dump(dict(treasure=Dice(10, 20)), default_flow_style=False)
== "treasure: 10d20\n"
)
assert srsly.ruamel_yaml.load(
"damage: 5d10", Loader=srsly.ruamel_yaml.Loader
) == dict(damage=Dice(5, 10))
class Obj1(dict):
def __init__(self, suffix):
self._suffix = suffix
self._node = None
def add_node(self, n):
self._node = n
def __repr__(self):
return "Obj1(%s->%s)" % (self._suffix, self.items())
def dump(self):
return repr(self._node)
class YAMLObj1(object):
yaml_tag = u"!obj:"
@classmethod
def from_yaml(cls, loader, suffix, node):
import srsly.ruamel_yaml # NOQA
obj1 = Obj1(suffix)
if isinstance(node, srsly.ruamel_yaml.MappingNode):
obj1.add_node(loader.construct_mapping(node))
else:
raise NotImplementedError
return obj1
@classmethod
def to_yaml(cls, dumper, data):
return dumper.represent_scalar(cls.yaml_tag + data._suffix, data.dump())
def test_yaml_obj():
import srsly.ruamel_yaml # NOQA
srsly.ruamel_yaml.add_representer(Obj1, YAMLObj1.to_yaml)
srsly.ruamel_yaml.add_multi_constructor(YAMLObj1.yaml_tag, YAMLObj1.from_yaml)
with pytest.raises(ValueError):
x = srsly.ruamel_yaml.load("!obj:x.2\na: 1", Loader=srsly.ruamel_yaml.Loader)
print(x)
assert srsly.ruamel_yaml.dump(x) == """!obj:x.2 "{'a': 1}"\n"""
def test_yaml_obj_with_loader_and_dumper():
import srsly.ruamel_yaml # NOQA
srsly.ruamel_yaml.add_representer(
Obj1, YAMLObj1.to_yaml, Dumper=srsly.ruamel_yaml.Dumper
)
srsly.ruamel_yaml.add_multi_constructor(
YAMLObj1.yaml_tag, YAMLObj1.from_yaml, Loader=srsly.ruamel_yaml.Loader
)
with pytest.raises(ValueError):
x = srsly.ruamel_yaml.load("!obj:x.2\na: 1", Loader=srsly.ruamel_yaml.Loader)
# x = srsly.ruamel_yaml.load('!obj:x.2\na: 1')
print(x)
assert srsly.ruamel_yaml.dump(x) == """!obj:x.2 "{'a': 1}"\n"""
# ToDo use nullege to search add_multi_representer and add_path_resolver
# and add some test code
# Issue 127 reported by Tommy Wang
def test_issue_127():
import srsly.ruamel_yaml # NOQA
class Ref(srsly.ruamel_yaml.YAMLObject):
yaml_constructor = srsly.ruamel_yaml.RoundTripConstructor
yaml_representer = srsly.ruamel_yaml.RoundTripRepresenter
yaml_tag = u"!Ref"
def __init__(self, logical_id):
self.logical_id = logical_id
@classmethod
def from_yaml(cls, loader, node):
return cls(loader.construct_scalar(node))
@classmethod
def to_yaml(cls, dumper, data):
if isinstance(data.logical_id, srsly.ruamel_yaml.scalarstring.ScalarString):
style = data.logical_id.style # srsly.ruamel_yaml>0.15.8
else:
style = None
return dumper.represent_scalar(cls.yaml_tag, data.logical_id, style=style)
document = dedent(
"""\
AList:
- !Ref One
- !Ref 'Two'
- !Ref
Two and a half
BList: [!Ref Three, !Ref "Four"]
CList:
- Five Six
- 'Seven Eight'
"""
)
data = srsly.ruamel_yaml.round_trip_load(document, preserve_quotes=True)
assert srsly.ruamel_yaml.round_trip_dump(
data, indent=4, block_seq_indent=2
) == document.replace("\n Two and", " Two and")
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_anchor.py 0000775 0000000 0000000 00000034134 14742310675 0024230 0 ustar 00root root 0000000 0000000 """
testing of anchors and the aliases referring to them
"""
import pytest
from textwrap import dedent
import platform
import srsly
from .roundtrip import (
round_trip,
dedent,
round_trip_load,
round_trip_dump,
YAML,
) # NOQA
def load(s):
return round_trip_load(dedent(s))
def compare(d, s):
assert round_trip_dump(d) == dedent(s)
class TestAnchorsAliases:
def test_anchor_id_renumber(self):
from srsly.ruamel_yaml.serializer import Serializer
assert Serializer.ANCHOR_TEMPLATE == "id%03d"
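        # anchor names matching this auto-generated template are renumbered on
        # dump (and dropped when unreferenced), unlike user-chosen names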
data = load(
"""
a: &id002
b: 1
c: 2
d: *id002
"""
)
compare(
data,
"""
a: &id001
b: 1
c: 2
d: *id001
""",
)
def test_template_matcher(self):
"""test if id matches the anchor template"""
from srsly.ruamel_yaml.serializer import templated_id
assert templated_id(u"id001")
assert templated_id(u"id999")
assert templated_id(u"id1000")
assert templated_id(u"id0001")
assert templated_id(u"id0000")
assert not templated_id(u"id02")
assert not templated_id(u"id000")
assert not templated_id(u"x000")
# def test_re_matcher(self):
# import re
# assert re.compile(u'id(?!000)\\d{3,}').match('id001')
# assert not re.compile(u'id(?!000\\d*)\\d{3,}').match('id000')
# assert re.compile(u'id(?!000$)\\d{3,}').match('id0001')
def test_anchor_assigned(self):
from srsly.ruamel_yaml.comments import CommentedMap
data = load(
"""
a: &id002
b: 1
c: 2
d: *id002
e: &etemplate
b: 1
c: 2
f: *etemplate
"""
)
d = data["d"]
assert isinstance(d, CommentedMap)
assert d.yaml_anchor() is None # got dropped as it matches pattern
e = data["e"]
assert isinstance(e, CommentedMap)
assert e.yaml_anchor().value == "etemplate"
assert e.yaml_anchor().always_dump is False
def test_anchor_id_retained(self):
data = load(
"""
a: &id002
b: 1
c: 2
d: *id002
e: &etemplate
b: 1
c: 2
f: *etemplate
"""
)
compare(
data,
"""
a: &id001
b: 1
c: 2
d: *id001
e: &etemplate
b: 1
c: 2
f: *etemplate
""",
)
@pytest.mark.skipif(
platform.python_implementation() == "Jython",
reason="Jython throws RepresenterError",
)
def test_alias_before_anchor(self):
from srsly.ruamel_yaml.composer import ComposerError
with pytest.raises(ComposerError):
data = load(
"""
d: *id002
a: &id002
b: 1
c: 2
"""
)
data = data
def test_anchor_on_sequence(self):
# as reported by Bjorn Stabell
# https://bitbucket.org/ruamel/yaml/issue/7/anchor-names-not-preserved
from srsly.ruamel_yaml.comments import CommentedSeq
data = load(
"""
nut1: &alice
- 1
- 2
nut2: &blake
- some data
- *alice
nut3:
- *blake
- *alice
"""
)
r = data["nut1"]
assert isinstance(r, CommentedSeq)
assert r.yaml_anchor() is not None
assert r.yaml_anchor().value == "alice"
merge_yaml = dedent(
"""
- &CENTER {x: 1, y: 2}
- &LEFT {x: 0, y: 2}
- &BIG {r: 10}
- &SMALL {r: 1}
# All the following maps are equal:
# Explicit keys
- x: 1
y: 2
r: 10
label: center/small
# Merge one map
- <<: *CENTER
r: 10
label: center/medium
# Merge multiple maps
- <<: [*CENTER, *BIG]
label: center/big
# Override
- <<: [*BIG, *LEFT, *SMALL]
x: 1
label: center/huge
"""
)
def test_merge_00(self):
data = load(self.merge_yaml)
d = data[4]
ok = True
for k in d:
for o in [5, 6, 7]:
x = d.get(k)
y = data[o].get(k)
if not isinstance(x, int):
x = x.split("/")[0]
y = y.split("/")[0]
if x != y:
ok = False
print("key", k, d.get(k), data[o].get(k))
assert ok
def test_merge_accessible(self):
from srsly.ruamel_yaml.comments import CommentedMap, merge_attrib
data = load(
"""
k: &level_2 { a: 1, b2 }
l: &level_1 { a: 10, c: 3 }
m:
<<: *level_1
c: 30
d: 40
"""
)
d = data["m"]
assert isinstance(d, CommentedMap)
assert hasattr(d, merge_attrib)
def test_merge_01(self):
data = load(self.merge_yaml)
compare(data, self.merge_yaml)
def test_merge_nested(self):
yaml = """
a:
<<: &content
1: plugh
2: plover
0: xyzzy
b:
<<: *content
"""
data = round_trip(yaml) # NOQA
def test_merge_nested_with_sequence(self):
yaml = """
a:
<<: &content
<<: &y2
1: plugh
2: plover
0: xyzzy
b:
<<: [*content, *y2]
"""
data = round_trip(yaml) # NOQA
def test_add_anchor(self):
from srsly.ruamel_yaml.comments import CommentedMap
data = CommentedMap()
data_a = CommentedMap()
data["a"] = data_a
data_a["c"] = 3
data["b"] = 2
data.yaml_set_anchor("klm", always_dump=True)
data["a"].yaml_set_anchor("xyz", always_dump=True)
compare(
data,
"""
&klm
a: &xyz
c: 3
b: 2
""",
)
# this is an error in PyYAML
def test_reused_anchor(self):
from srsly.ruamel_yaml.error import ReusedAnchorWarning
yaml = """
- &a
x: 1
- <<: *a
- &a
x: 2
- <<: *a
"""
with pytest.warns(ReusedAnchorWarning):
data = round_trip(yaml) # NOQA
def test_issue_130(self):
# issue 130 reported by Devid Fee
ys = dedent(
"""\
components:
server: &server_component
type: spark.server:ServerComponent
host: 0.0.0.0
port: 8000
shell: &shell_component
type: spark.shell:ShellComponent
services:
server: &server_service
<<: *server_component
shell: &shell_service
<<: *shell_component
components:
server: {<<: *server_service}
"""
)
data = srsly.ruamel_yaml.safe_load(ys)
assert data["services"]["shell"]["components"]["server"]["port"] == 8000
def test_issue_130a(self):
# issue 130 reported by Devid Fee
ys = dedent(
"""\
components:
server: &server_component
type: spark.server:ServerComponent
host: 0.0.0.0
port: 8000
shell: &shell_component
type: spark.shell:ShellComponent
services:
server: &server_service
<<: *server_component
port: 4000
shell: &shell_service
<<: *shell_component
components:
server: {<<: *server_service}
"""
)
data = srsly.ruamel_yaml.safe_load(ys)
assert data["services"]["shell"]["components"]["server"]["port"] == 4000
class TestMergeKeysValues:
yaml_str = dedent(
"""\
- &mx
a: x1
b: x2
c: x3
- &my
a: y1
b: y2 # masked by the one in &mx
d: y4
-
a: 1
<<: [*mx, *my]
m: 6
"""
)
# in the following d always has "expanded" the merges
def test_merge_for(self):
from srsly.ruamel_yaml import safe_load
d = safe_load(self.yaml_str)
data = round_trip_load(self.yaml_str)
count = 0
for x in data[2]:
count += 1
print(count, x)
assert count == len(d[2])
def test_merge_keys(self):
from srsly.ruamel_yaml import safe_load
d = safe_load(self.yaml_str)
data = round_trip_load(self.yaml_str)
count = 0
for x in data[2].keys():
count += 1
print(count, x)
assert count == len(d[2])
def test_merge_values(self):
from srsly.ruamel_yaml import safe_load
d = safe_load(self.yaml_str)
data = round_trip_load(self.yaml_str)
count = 0
for x in data[2].values():
count += 1
print(count, x)
assert count == len(d[2])
def test_merge_items(self):
from srsly.ruamel_yaml import safe_load
d = safe_load(self.yaml_str)
data = round_trip_load(self.yaml_str)
count = 0
for x in data[2].items():
count += 1
print(count, x)
assert count == len(d[2])
def test_len_items_delete(self):
from srsly.ruamel_yaml import safe_load
from srsly.ruamel_yaml.compat import PY3
d = safe_load(self.yaml_str)
data = round_trip_load(self.yaml_str)
x = data[2].items()
print("d2 items", d[2].items(), len(d[2].items()), x, len(x))
ref = len(d[2].items())
print("ref", ref)
assert len(x) == ref
del data[2]["m"]
if PY3:
ref -= 1
assert len(x) == ref
del data[2]["d"]
if PY3:
ref -= 1
assert len(x) == ref
del data[2]["a"]
if PY3:
ref -= 1
assert len(x) == ref
def test_issue_196_cast_of_dict(self, capsys):
from srsly.ruamel_yaml import YAML
yaml = YAML()
mapping = yaml.load(
"""\
anchored: &anchor
a : 1
mapping:
<<: *anchor
b: 2
"""
)["mapping"]
for k in mapping:
print("k", k)
for k in mapping.copy():
print("kc", k)
print("v", list(mapping.keys()))
print("v", list(mapping.values()))
print("v", list(mapping.items()))
print(len(mapping))
print("-----")
# print({**mapping})
# print(type({**mapping}))
# assert 'a' in {**mapping}
assert "a" in mapping
x = {}
for k in mapping:
x[k] = mapping[k]
assert "a" in x
assert "a" in mapping.keys()
assert mapping["a"] == 1
assert mapping.__getitem__("a") == 1
assert "a" in dict(mapping)
assert "a" in dict(mapping.items())
def test_values_of_merged(self):
from srsly.ruamel_yaml import YAML
yaml = YAML()
data = yaml.load(dedent(self.yaml_str))
assert list(data[2].values()) == [1, 6, "x2", "x3", "y4"]
def test_issue_213_copy_of_merge(self):
from srsly.ruamel_yaml import YAML
yaml = YAML()
d = yaml.load(
"""\
foo: &foo
a: a
foo2:
<<: *foo
b: b
"""
)["foo2"]
assert d["a"] == "a"
d2 = d.copy()
assert d2["a"] == "a"
print("d", d)
del d["a"]
assert "a" not in d
assert "a" in d2
class TestDuplicateKeyThroughAnchor:
def test_duplicate_key_00(self):
from srsly.ruamel_yaml import version_info
from srsly.ruamel_yaml import safe_load, round_trip_load
from srsly.ruamel_yaml.constructor import (
DuplicateKeyFutureWarning,
DuplicateKeyError,
)
s = dedent(
"""\
&anchor foo:
foo: bar
*anchor : duplicate key
baz: bat
*anchor : duplicate key
"""
)
if version_info < (0, 15, 1):
pass
elif version_info < (0, 16, 0):
with pytest.warns(DuplicateKeyFutureWarning):
safe_load(s)
with pytest.warns(DuplicateKeyFutureWarning):
round_trip_load(s)
else:
with pytest.raises(DuplicateKeyError):
safe_load(s)
with pytest.raises(DuplicateKeyError):
round_trip_load(s)
def test_duplicate_key_01(self):
# so issue https://stackoverflow.com/a/52852106/1307905
from srsly.ruamel_yaml import version_info
from srsly.ruamel_yaml.constructor import DuplicateKeyError
s = dedent(
"""\
- &name-name
a: 1
- &help-name
b: 2
- <<: *name-name
<<: *help-name
"""
)
if version_info < (0, 15, 1):
pass
else:
with pytest.raises(DuplicateKeyError):
yaml = YAML(typ="safe")
yaml.load(s)
with pytest.raises(DuplicateKeyError):
yaml = YAML()
yaml.load(s)
class TestFullCharSetAnchors:
def test_master_of_orion(self):
# https://bitbucket.org/ruamel/yaml/issues/72/not-allowed-in-anchor-names
# submitted by Shalon Wood
yaml_str = """
- collection: &Backend.Civilizations.RacialPerk
items:
- key: perk_population_growth_modifier
- *Backend.Civilizations.RacialPerk
"""
data = load(yaml_str) # NOQA
def test_roundtrip_00(self):
yaml_str = """
- &dotted.words.here
a: 1
b: 2
- *dotted.words.here
"""
data = round_trip(yaml_str) # NOQA
def test_roundtrip_01(self):
yaml_str = """
- &dotted.words.here[a, b]
- *dotted.words.here
"""
data = load(yaml_str) # NOQA
compare(data, yaml_str.replace("[", " [")) # an extra space is inserted
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_api_change.py 0000775 0000000 0000000 00000015646 14742310675 0025043 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
"""
testing of the newer YAML() class-based API
"""
import sys
import textwrap
import pytest
from pathlib import Path
pytestmark = pytest.mark.filterwarnings(
"ignore::pytest.PytestUnraisableExceptionWarning"
)
class TestNewAPI:
def test_duplicate_keys_00(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.constructor import DuplicateKeyError
yaml = YAML()
with pytest.raises(DuplicateKeyError):
yaml.load("{a: 1, a: 2}")
def test_duplicate_keys_01(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.constructor import DuplicateKeyError
yaml = YAML(typ="safe", pure=True)
with pytest.raises(DuplicateKeyError):
yaml.load("{a: 1, a: 2}")
def test_duplicate_keys_02(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.constructor import DuplicateKeyError
yaml = YAML(typ="safe")
with pytest.raises(DuplicateKeyError):
yaml.load("{a: 1, a: 2}")
def test_issue_135(self):
# reported by Andrzej Ostrowski
from srsly.ruamel_yaml import YAML
data = {"a": 1, "b": 2}
yaml = YAML(typ="safe")
# originally on 2.7: with pytest.raises(TypeError):
yaml.dump(data, sys.stdout)
def test_issue_135_temporary_workaround(self):
# never raised error
from srsly.ruamel_yaml import YAML
data = {"a": 1, "b": 2}
yaml = YAML(typ="safe", pure=True)
yaml.dump(data, sys.stdout)
class TestWrite:
def test_dump_path(self, tmpdir):
from srsly.ruamel_yaml import YAML
fn = Path(str(tmpdir)) / "test.yaml"
yaml = YAML()
data = yaml.map()
data["a"] = 1
data["b"] = 2
yaml.dump(data, fn)
assert fn.read_text() == "a: 1\nb: 2\n"
def test_dump_file(self, tmpdir):
from srsly.ruamel_yaml import YAML
fn = Path(str(tmpdir)) / "test.yaml"
yaml = YAML()
data = yaml.map()
data["a"] = 1
data["b"] = 2
with open(str(fn), "w") as fp:
yaml.dump(data, fp)
assert fn.read_text() == "a: 1\nb: 2\n"
def test_dump_missing_stream(self):
from srsly.ruamel_yaml import YAML
yaml = YAML()
data = yaml.map()
data["a"] = 1
data["b"] = 2
with pytest.raises(TypeError):
yaml.dump(data)
def test_dump_too_many_args(self, tmpdir):
from srsly.ruamel_yaml import YAML
fn = Path(str(tmpdir)) / "test.yaml"
yaml = YAML()
data = yaml.map()
data["a"] = 1
data["b"] = 2
with pytest.raises(TypeError):
yaml.dump(data, fn, True)
def test_transform(self, tmpdir):
from srsly.ruamel_yaml import YAML
def tr(s):
return s.replace(" ", " ")
fn = Path(str(tmpdir)) / "test.yaml"
yaml = YAML()
data = yaml.map()
data["a"] = 1
data["b"] = 2
yaml.dump(data, fn, transform=tr)
assert fn.read_text() == "a: 1\nb: 2\n"
def test_print(self, capsys):
from srsly.ruamel_yaml import YAML
yaml = YAML()
data = yaml.map()
data["a"] = 1
data["b"] = 2
yaml.dump(data, sys.stdout)
out, err = capsys.readouterr()
assert out == "a: 1\nb: 2\n"
class TestRead:
def test_multi_load(self):
# make sure reader, scanner, parser get reset
from srsly.ruamel_yaml import YAML
yaml = YAML()
yaml.load("a: 1")
yaml.load("a: 1") # did not work in 0.15.4
def test_parse(self):
# ensure `parse` method is functional and can parse "unsafe" yaml
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.constructor import ConstructorError
yaml = YAML(typ="safe")
s = "- !User0 {age: 18, name: Anthon}"
# should fail to load
with pytest.raises(ConstructorError):
yaml.load(s)
# should parse fine
yaml = YAML(typ="safe")
for _ in yaml.parse(s):
pass
class TestLoadAll:
def test_multi_document_load(self, tmpdir):
"""this went wrong on 3.7 because of StopIteration, PR 37 and Issue 211"""
from srsly.ruamel_yaml import YAML
fn = Path(str(tmpdir)) / "test.yaml"
fn.write_text(
textwrap.dedent(
u"""\
---
- a
---
- b
...
"""
)
)
yaml = YAML()
assert list(yaml.load_all(fn)) == [["a"], ["b"]]
class TestDuplSet:
def test_dupl_set_00(self):
# round-trip-loader should except
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.constructor import DuplicateKeyError
yaml = YAML()
with pytest.raises(DuplicateKeyError):
yaml.load(
textwrap.dedent(
"""\
!!set
? a
? b
? c
? a
"""
)
)
class TestDumpLoadUnicode:
# test triggered by SamH on stackoverflow (https://stackoverflow.com/q/45281596/1307905)
# and answer by randomir (https://stackoverflow.com/a/45281922/1307905)
def test_write_unicode(self, tmpdir):
from srsly.ruamel_yaml import YAML
yaml = YAML()
text_dict = {"text": u"HELLO_WORLD©"}
file_name = str(tmpdir) + "/tstFile.yaml"
yaml.dump(text_dict, open(file_name, "w", encoding="utf8", newline="\n"))
assert open(file_name, "rb").read().decode("utf-8") == u"text: HELLO_WORLD©\n"
def test_read_unicode(self, tmpdir):
from srsly.ruamel_yaml import YAML
yaml = YAML()
file_name = str(tmpdir) + "/tstFile.yaml"
with open(file_name, "wb") as fp:
fp.write(u"text: HELLO_WORLD©\n".encode("utf-8"))
with open(file_name, "r", encoding="utf8") as fp:
text_dict = yaml.load(fp)
assert text_dict["text"] == u"HELLO_WORLD©"
class TestFlowStyle:
def test_flow_style(self, capsys):
# https://stackoverflow.com/questions/45791712/
from srsly.ruamel_yaml import YAML
yaml = YAML()
yaml.default_flow_style = None
data = yaml.map()
data["b"] = 1
data["a"] = [[1, 2], [3, 4]]
yaml.dump(data, sys.stdout)
out, err = capsys.readouterr()
assert out == "b: 1\na:\n- [1, 2]\n- [3, 4]\n"
class TestOldAPI:
@pytest.mark.skipif(sys.version_info >= (3, 0), reason="ok on Py3")
def test_duplicate_keys_02(self):
# Issue 165 unicode keys in error/warning
from srsly.ruamel_yaml import safe_load
from srsly.ruamel_yaml.constructor import DuplicateKeyError
with pytest.raises(DuplicateKeyError):
safe_load("type: Doméstica\ntype: International")
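# A minimal, self-contained sketch of the "new" API exercised in this module
# (not part of the original tests; the temporary file name is only illustrative):
if __name__ == "__main__":
    import tempfile

    from srsly.ruamel_yaml import YAML

    demo = YAML()  # round-trip loader/dumper by default
    doc = demo.load("a: 1\nb: 2\n")
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "demo.yaml"
        demo.dump(doc, target)  # dumping always needs an explicit stream or path
        print(target.read_text())  # -> "a: 1\nb: 2\n"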
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_appliance.py 0000775 0000000 0000000 00000017000 14742310675 0024703 0 ustar 00root root 0000000 0000000 from __future__ import print_function
import sys
import os
import types
import traceback
import pprint
import argparse
from srsly.ruamel_yaml.compat import PY3
# DATA = 'tests/data'
# determine the position of data dynamically relative to program
# this allows running test while the current path is not the top of the
# repository, e.g. from the tests/data directory: python ../test_yaml.py
DATA = __file__.rsplit(os.sep, 2)[0] + "/data"
def find_test_functions(collections):
if not isinstance(collections, list):
collections = [collections]
functions = []
for collection in collections:
if not isinstance(collection, dict):
collection = vars(collection)
for key in sorted(collection):
value = collection[key]
if isinstance(value, types.FunctionType) and hasattr(value, "unittest"):
functions.append(value)
return functions
def find_test_filenames(directory):
filenames = {}
for filename in os.listdir(directory):
if os.path.isfile(os.path.join(directory, filename)):
base, ext = os.path.splitext(filename)
if base.endswith("-py2" if PY3 else "-py3"):
continue
filenames.setdefault(base, []).append(ext)
filenames = sorted(filenames.items())
return filenames
def parse_arguments(args):
""""""
parser = argparse.ArgumentParser(
usage=""" run the yaml tests. By default
all functions on all appropriate test_files are run. Functions have
unittest attributes that determine the required extensions to filenames
that need to be available in order to run that test. E.g.\n\n
python test_yaml.py test_constructor_types\n
python test_yaml.py --verbose test_tokens spec-02-05\n\n
The presence of an extension in the .skip attribute of a function
disables the test for that function."""
)
# ToDo: make into int and test > 0 in functions
parser.add_argument(
"--verbose",
"-v",
action="store_true",
default="YAML_TEST_VERBOSE" in os.environ,
help="set verbosity output",
)
parser.add_argument(
"--list-functions",
action="store_true",
help="""list all functions with required file extensions for test files
""",
)
parser.add_argument("function", nargs="?", help="""restrict function to run""")
parser.add_argument(
"filenames",
nargs="*",
help="""basename of filename set, extensions (.code, .data) have to
be a superset of those in the unittest attribute of the selected
function""",
)
args = parser.parse_args(args)
# print('args', args)
verbose = args.verbose
include_functions = [args.function] if args.function else []
include_filenames = args.filenames
# if args is None:
# args = sys.argv[1:]
# verbose = False
# if '-v' in args:
# verbose = True
# args.remove('-v')
# if '--verbose' in args:
# verbose = True
# args.remove('--verbose') # never worked without this
# if 'YAML_TEST_VERBOSE' in os.environ:
# verbose = True
# include_functions = []
# if args:
# include_functions.append(args.pop(0))
if "YAML_TEST_FUNCTIONS" in os.environ:
include_functions.extend(os.environ["YAML_TEST_FUNCTIONS"].split())
# include_filenames = []
# include_filenames.extend(args)
if "YAML_TEST_FILENAMES" in os.environ:
include_filenames.extend(os.environ["YAML_TEST_FILENAMES"].split())
return include_functions, include_filenames, verbose, args
def execute(function, filenames, verbose):
if PY3:
name = function.__name__
else:
if hasattr(function, "unittest_name"):
name = function.unittest_name
else:
name = function.func_name
if verbose:
sys.stdout.write("=" * 75 + "\n")
sys.stdout.write("%s(%s)...\n" % (name, ", ".join(filenames)))
try:
function(verbose=verbose, *filenames)
except Exception as exc:
info = sys.exc_info()
if isinstance(exc, AssertionError):
kind = "FAILURE"
else:
kind = "ERROR"
if verbose:
traceback.print_exc(limit=1, file=sys.stdout)
else:
sys.stdout.write(kind[0])
sys.stdout.flush()
else:
kind = "SUCCESS"
info = None
if not verbose:
sys.stdout.write(".")
sys.stdout.flush()
return (name, filenames, kind, info)
def display(results, verbose):
if results and not verbose:
sys.stdout.write("\n")
total = len(results)
failures = 0
errors = 0
for name, filenames, kind, info in results:
if kind == "SUCCESS":
continue
if kind == "FAILURE":
failures += 1
if kind == "ERROR":
errors += 1
sys.stdout.write("=" * 75 + "\n")
sys.stdout.write("%s(%s): %s\n" % (name, ", ".join(filenames), kind))
if kind == "ERROR":
traceback.print_exception(file=sys.stdout, *info)
else:
sys.stdout.write("Traceback (most recent call last):\n")
traceback.print_tb(info[2], file=sys.stdout)
sys.stdout.write("%s: see below\n" % info[0].__name__)
sys.stdout.write("~" * 75 + "\n")
for arg in info[1].args:
pprint.pprint(arg, stream=sys.stdout)
for filename in filenames:
sys.stdout.write("-" * 75 + "\n")
sys.stdout.write("%s:\n" % filename)
if PY3:
with open(filename, "r", errors="replace") as fp:
data = fp.read()
else:
with open(filename, "rb") as fp:
data = fp.read()
sys.stdout.write(data)
if data and data[-1] != "\n":
sys.stdout.write("\n")
sys.stdout.write("=" * 75 + "\n")
sys.stdout.write("TESTS: %s\n" % total)
ret_val = 0
if failures:
sys.stdout.write("FAILURES: %s\n" % failures)
ret_val = 1
if errors:
sys.stdout.write("ERRORS: %s\n" % errors)
ret_val = 2
return ret_val
def run(collections, args=None):
test_functions = find_test_functions(collections)
test_filenames = find_test_filenames(DATA)
include_functions, include_filenames, verbose, a = parse_arguments(args)
if a.list_functions:
print("test functions:")
for f in test_functions:
print(" {:30s} {}".format(f.__name__, f.unittest))
return
results = []
for function in test_functions:
if include_functions and function.__name__ not in include_functions:
continue
if function.unittest:
for base, exts in test_filenames:
if include_filenames and base not in include_filenames:
continue
filenames = []
for ext in function.unittest:
if ext not in exts:
break
filenames.append(os.path.join(DATA, base + ext))
else:
skip_exts = getattr(function, "skip", [])
for skip_ext in skip_exts:
if skip_ext in exts:
break
else:
result = execute(function, filenames, verbose)
results.append(result)
else:
result = execute(function, [], verbose)
results.append(result)
return display(results, verbose=verbose)
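# A sketch of how this appliance is typically driven (hypothetical test module,
# not part of this file): test functions carry a `unittest` attribute listing
# the data-file extensions they need, and the module is handed to run():
#
#     def test_tokens(data_filename, canonical_filename, verbose=False):
#         ...
#     test_tokens.unittest = [".data", ".canonical"]
#
#     if __name__ == "__main__":
#         import sys
#         sys.exit(run(sys.modules[__name__]))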
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_class_register.py 0000775 0000000 0000000 00000006414 14742310675 0025767 0 ustar 00root root 0000000 0000000 # coding: utf-8
"""
testing of YAML.register_class and @yaml_object
"""
from .roundtrip import YAML
class User0(object):
def __init__(self, name, age):
self.name = name
self.age = age
class User1(object):
yaml_tag = u"!user"
def __init__(self, name, age):
self.name = name
self.age = age
@classmethod
def to_yaml(cls, representer, node):
return representer.represent_scalar(
cls.yaml_tag, u"{.name}-{.age}".format(node, node)
)
@classmethod
def from_yaml(cls, constructor, node):
return cls(*node.value.split("-"))
class TestRegisterClass(object):
def test_register_0_rt(self):
yaml = YAML()
yaml.register_class(User0)
ys = """
- !User0
name: Anthon
age: 18
"""
d = yaml.load(ys)
yaml.dump(d, compare=ys, unordered_lines=True)
def test_register_0_safe(self):
# default_flow_style = None
yaml = YAML(typ="safe")
yaml.register_class(User0)
ys = """
- !User0 {age: 18, name: Anthon}
"""
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_register_0_unsafe(self):
# default_flow_style = None
yaml = YAML(typ="unsafe")
yaml.register_class(User0)
ys = """
- !User0 {age: 18, name: Anthon}
"""
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_register_1_rt(self):
yaml = YAML()
yaml.register_class(User1)
ys = """
- !user Anthon-18
"""
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_register_1_safe(self):
yaml = YAML(typ="safe")
yaml.register_class(User1)
ys = """
[!user Anthon-18]
"""
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_register_1_unsafe(self):
yaml = YAML(typ="unsafe")
yaml.register_class(User1)
ys = """
[!user Anthon-18]
"""
d = yaml.load(ys)
yaml.dump(d, compare=ys)
class TestDecorator(object):
def test_decorator_implicit(self):
from srsly.ruamel_yaml import yaml_object
yml = YAML()
@yaml_object(yml)
class User2(object):
def __init__(self, name, age):
self.name = name
self.age = age
ys = """
- !User2
name: Anthon
age: 18
"""
d = yml.load(ys)
yml.dump(d, compare=ys, unordered_lines=True)
def test_decorator_explicit(self):
from srsly.ruamel_yaml import yaml_object
yml = YAML()
@yaml_object(yml)
class User3(object):
yaml_tag = u"!USER"
def __init__(self, name, age):
self.name = name
self.age = age
@classmethod
def to_yaml(cls, representer, node):
return representer.represent_scalar(
cls.yaml_tag, u"{.name}-{.age}".format(node, node)
)
@classmethod
def from_yaml(cls, constructor, node):
return cls(*node.value.split("-"))
ys = """
- !USER Anthon-18
"""
d = yml.load(ys)
yml.dump(d, compare=ys)
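# Illustrative sketch (not part of the original tests): registering a class on a
# plain srsly.ruamel_yaml.YAML instance and dumping one tagged scalar.
if __name__ == "__main__":
    import sys

    from srsly.ruamel_yaml import YAML as _YAML

    demo = _YAML()
    demo.register_class(User1)
    demo.dump([User1("Anthon", 18)], sys.stdout)  # expected to print roughly "- !user Anthon-18"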
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_collections.py 0000775 0000000 0000000 00000000774 14742310675 0025277 0 ustar 00root root 0000000 0000000 # coding: utf-8
"""
collections.OrderedDict is not supported by PyYAML (issue 83 by Frazer McLean),
but it is now so integrated in Python that it can be mapped to !!omap.
"""
import pytest # NOQA
from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
class TestOrderedDict:
def test_ordereddict(self):
from collections import OrderedDict
import srsly.ruamel_yaml # NOQA
assert srsly.ruamel_yaml.dump(OrderedDict()) == "!!omap []\n"
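# Illustrative sketch (not part of the original test): a non-empty OrderedDict
# dumps under the same !!omap tag, one key/value pair per sequence item.
if __name__ == "__main__":
    from collections import OrderedDict

    import srsly.ruamel_yaml

    print(srsly.ruamel_yaml.dump(OrderedDict([("a", 1), ("b", 2)])), end="")
    # prints the pairs under the !!omap tag; the exact flow/block layout depends
    # on the dumper's default_flow_style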
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_comment_manipulation.py 0000775 0000000 0000000 00000034773 14742310675 0027211 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
import pytest # NOQA
from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
def load(s):
return round_trip_load(dedent(s))
def compare(data, s, **kw):
assert round_trip_dump(data, **kw) == dedent(s)
def compare_eol(data, s):
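# "EOL" marks where a line in the expected text ends, so that trailing spaces
# survive dedenting (and editors that strip them); it is removed before the
# comparison, and newlines are made visible as "|" to ease debugging.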
assert "EOL" in s
ds = dedent(s).replace("EOL", "").replace("\n", "|\n")
assert round_trip_dump(data).replace("\n", "|\n") == ds
class TestCommentsManipulation:
# list
def test_seq_set_comment_on_existing_explicit_column(self):
data = load(
"""
- a # comment 1
- b
- c
"""
)
data.yaml_add_eol_comment("comment 2", key=1, column=6)
exp = """
- a # comment 1
- b # comment 2
- c
"""
compare(data, exp)
def test_seq_overwrite_comment_on_existing_explicit_column(self):
data = load(
"""
- a # comment 1
- b
- c
"""
)
data.yaml_add_eol_comment("comment 2", key=0, column=6)
exp = """
- a # comment 2
- b
- c
"""
compare(data, exp)
def test_seq_first_comment_explicit_column(self):
data = load(
"""
- a
- b
- c
"""
)
data.yaml_add_eol_comment("comment 1", key=1, column=6)
exp = """
- a
- b # comment 1
- c
"""
compare(data, exp)
def test_seq_set_comment_on_existing_column_prev(self):
data = load(
"""
- a # comment 1
- b
- c
- d # comment 3
"""
)
data.yaml_add_eol_comment("comment 2", key=1)
exp = """
- a # comment 1
- b # comment 2
- c
- d # comment 3
"""
compare(data, exp)
def test_seq_set_comment_on_existing_column_next(self):
data = load(
"""
- a # comment 1
- b
- c
- d # comment 3
"""
)
print(data._yaml_comment)
# print(type(data._yaml_comment._items[0][0].start_mark))
# srsly.ruamel_yaml.error.Mark
# print(type(data._yaml_comment._items[0][0].start_mark))
data.yaml_add_eol_comment("comment 2", key=2)
exp = """
- a # comment 1
- b
- c # comment 2
- d # comment 3
"""
compare(data, exp)
def test_seq_set_comment_on_existing_column_further_away(self):
"""
no comment line before or after, take the latest before
the new position
"""
data = load(
"""
- a # comment 1
- b
- c
- d
- e
- f # comment 3
"""
)
print(data._yaml_comment)
# print(type(data._yaml_comment._items[0][0].start_mark))
# srsly.ruamel_yaml.error.Mark
# print(type(data._yaml_comment._items[0][0].start_mark))
data.yaml_add_eol_comment("comment 2", key=3)
exp = """
- a # comment 1
- b
- c
- d # comment 2
- e
- f # comment 3
"""
compare(data, exp)
def test_seq_set_comment_on_existing_explicit_column_with_hash(self):
data = load(
"""
- a # comment 1
- b
- c
"""
)
data.yaml_add_eol_comment("# comment 2", key=1, column=6)
exp = """
- a # comment 1
- b # comment 2
- c
"""
compare(data, exp)
# dict
def test_dict_set_comment_on_existing_explicit_column(self):
data = load(
"""
a: 1 # comment 1
b: 2
c: 3
d: 4
e: 5
"""
)
data.yaml_add_eol_comment("comment 2", key="c", column=7)
exp = """
a: 1 # comment 1
b: 2
c: 3 # comment 2
d: 4
e: 5
"""
compare(data, exp)
def test_dict_overwrite_comment_on_existing_explicit_column(self):
data = load(
"""
a: 1 # comment 1
b: 2
c: 3
d: 4
e: 5
"""
)
data.yaml_add_eol_comment("comment 2", key="a", column=7)
exp = """
a: 1 # comment 2
b: 2
c: 3
d: 4
e: 5
"""
compare(data, exp)
def test_map_set_comment_on_existing_column_prev(self):
data = load(
"""
a: 1 # comment 1
b: 2
c: 3
d: 4
e: 5 # comment 3
"""
)
data.yaml_add_eol_comment("comment 2", key="b")
exp = """
a: 1 # comment 1
b: 2 # comment 2
c: 3
d: 4
e: 5 # comment 3
"""
compare(data, exp)
def test_map_set_comment_on_existing_column_next(self):
data = load(
"""
a: 1 # comment 1
b: 2
c: 3
d: 4
e: 5 # comment 3
"""
)
data.yaml_add_eol_comment("comment 2", key="d")
exp = """
a: 1 # comment 1
b: 2
c: 3
d: 4 # comment 2
e: 5 # comment 3
"""
compare(data, exp)
def test_map_set_comment_on_existing_column_further_away(self):
"""
no comment line before or after, take the latest before
the new position
"""
data = load(
"""
a: 1 # comment 1
b: 2
c: 3
d: 4
e: 5 # comment 3
"""
)
data.yaml_add_eol_comment("comment 2", key="c")
print(round_trip_dump(data))
exp = """
a: 1 # comment 1
b: 2
c: 3 # comment 2
d: 4
e: 5 # comment 3
"""
compare(data, exp)
def test_before_top_map_rt(self):
data = load(
"""
a: 1
b: 2
"""
)
data.yaml_set_start_comment("Hello\nWorld\n")
exp = """
# Hello
# World
a: 1
b: 2
"""
compare(data, exp.format(comment="#"))
def test_before_top_map_replace(self):
data = load(
"""
# abc
# def
a: 1 # 1
b: 2
"""
)
data.yaml_set_start_comment("Hello\nWorld\n")
exp = """
# Hello
# World
a: 1 # 1
b: 2
"""
compare(data, exp.format(comment="#"))
def test_before_top_map_from_scratch(self):
from srsly.ruamel_yaml.comments import CommentedMap
data = CommentedMap()
data["a"] = 1
data["b"] = 2
data.yaml_set_start_comment("Hello\nWorld\n")
# print(data.ca)
# print(data.ca._items)
exp = """
# Hello
# World
a: 1
b: 2
"""
compare(data, exp.format(comment="#"))
def test_before_top_seq_rt(self):
data = load(
"""
- a
- b
"""
)
data.yaml_set_start_comment("Hello\nWorld\n")
print(round_trip_dump(data))
exp = """
# Hello
# World
- a
- b
"""
compare(data, exp)
def test_before_top_seq_rt_replace(self):
s = """
# this
# that
- a
- b
"""
data = load(s.format(comment="#"))
data.yaml_set_start_comment("Hello\nWorld\n")
print(round_trip_dump(data))
exp = """
# Hello
# World
- a
- b
"""
compare(data, exp.format(comment="#"))
def test_before_top_seq_from_scratch(self):
from srsly.ruamel_yaml.comments import CommentedSeq
data = CommentedSeq()
data.append("a")
data.append("b")
data.yaml_set_start_comment("Hello\nWorld\n")
print(round_trip_dump(data))
exp = """
# Hello
# World
- a
- b
"""
compare(data, exp.format(comment="#"))
# nested variants
def test_before_nested_map_rt(self):
data = load(
"""
a: 1
b:
c: 2
d: 3
"""
)
data["b"].yaml_set_start_comment("Hello\nWorld\n")
exp = """
a: 1
b:
# Hello
# World
c: 2
d: 3
"""
compare(data, exp.format(comment="#"))
def test_before_nested_map_rt_indent(self):
data = load(
"""
a: 1
b:
c: 2
d: 3
"""
)
data["b"].yaml_set_start_comment("Hello\nWorld\n", indent=2)
exp = """
a: 1
b:
# Hello
# World
c: 2
d: 3
"""
compare(data, exp.format(comment="#"))
print(data["b"].ca)
def test_before_nested_map_from_scratch(self):
from srsly.ruamel_yaml.comments import CommentedMap
data = CommentedMap()
datab = CommentedMap()
data["a"] = 1
data["b"] = datab
datab["c"] = 2
datab["d"] = 3
data["b"].yaml_set_start_comment("Hello\nWorld\n")
exp = """
a: 1
b:
# Hello
# World
c: 2
d: 3
"""
compare(data, exp.format(comment="#"))
def test_before_nested_seq_from_scratch(self):
from srsly.ruamel_yaml.comments import CommentedMap, CommentedSeq
data = CommentedMap()
datab = CommentedSeq()
data["a"] = 1
data["b"] = datab
datab.append("c")
datab.append("d")
data["b"].yaml_set_start_comment("Hello\nWorld\n", indent=2)
exp = """
a: 1
b:
# Hello
# World
- c
- d
"""
compare(data, exp.format(comment="#"))
def test_before_nested_seq_from_scratch_block_seq_indent(self):
from srsly.ruamel_yaml.comments import CommentedMap, CommentedSeq
data = CommentedMap()
datab = CommentedSeq()
data["a"] = 1
data["b"] = datab
datab.append("c")
datab.append("d")
data["b"].yaml_set_start_comment("Hello\nWorld\n", indent=2)
exp = """
a: 1
b:
# Hello
# World
- c
- d
"""
compare(data, exp.format(comment="#"), indent=4, block_seq_indent=2)
def test_map_set_comment_before_and_after_non_first_key_00(self):
# http://stackoverflow.com/a/40705671/1307905
data = load(
"""
xyz:
a: 1 # comment 1
b: 2
test1:
test2:
test3: 3
"""
)
data.yaml_set_comment_before_after_key(
"test1", "before test1 (top level)", after="before test2"
)
data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
exp = """
xyz:
a: 1 # comment 1
b: 2
# before test1 (top level)
test1:
# before test2
test2:
# after test2
test3: 3
"""
compare(data, exp)
def Xtest_map_set_comment_before_and_after_non_first_key_01(self):
data = load(
"""
xyz:
a: 1 # comment 1
b: 2
test1:
test2:
test3: 3
"""
)
data.yaml_set_comment_before_after_key(
"test1", "before test1 (top level)", after="before test2\n\n"
)
data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
# EOL is needed here as dedenting gets rid of trailing spaces (as does Emacs)
exp = """
xyz:
a: 1 # comment 1
b: 2
# before test1 (top level)
test1:
# before test2
EOL
test2:
# after test2
test3: 3
"""
compare_eol(data, exp)
# EOL is no longer necessary
# fixed together with issue # 216
def test_map_set_comment_before_and_after_non_first_key_01(self):
data = load(
"""
xyz:
a: 1 # comment 1
b: 2
test1:
test2:
test3: 3
"""
)
data.yaml_set_comment_before_after_key(
"test1", "before test1 (top level)", after="before test2\n\n"
)
data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
exp = """
xyz:
a: 1 # comment 1
b: 2
# before test1 (top level)
test1:
# before test2
test2:
# after test2
test3: 3
"""
compare(data, exp)
def Xtest_map_set_comment_before_and_after_non_first_key_02(self):
data = load(
"""
xyz:
a: 1 # comment 1
b: 2
test1:
test2:
test3: 3
"""
)
data.yaml_set_comment_before_after_key(
"test1",
"xyz\n\nbefore test1 (top level)",
after="\nbefore test2",
after_indent=4,
)
data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
# EOL is needed here as dedenting gets rid of trailing spaces (as does Emacs)
exp = """
xyz:
a: 1 # comment 1
b: 2
# xyz
# before test1 (top level)
test1:
EOL
# before test2
test2:
# after test2
test3: 3
"""
compare_eol(data, exp)
def test_map_set_comment_before_and_after_non_first_key_02(self):
data = load(
"""
xyz:
a: 1 # comment 1
b: 2
test1:
test2:
test3: 3
"""
)
data.yaml_set_comment_before_after_key(
"test1",
"xyz\n\nbefore test1 (top level)",
after="\nbefore test2",
after_indent=4,
)
data["test1"]["test2"].yaml_set_start_comment("after test2", indent=4)
exp = """
xyz:
a: 1 # comment 1
b: 2
# xyz
# before test1 (top level)
test1:
# before test2
test2:
# after test2
test3: 3
"""
compare(data, exp)
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_comments.py 0000775 0000000 0000000 00000050111 14742310675 0024574 0 ustar 00root root 0000000 0000000 # coding: utf-8
"""
comment testing is all about roundtrips
these can be done in the "old" way by creating a file.data and file.roundtrip
but there is little flexibility in doing that
but some things are not easily tested, e.g. how a
roundtrip changes
"""
import pytest
import sys
from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump
class TestComments:
def test_no_end_of_file_eol(self):
"""not excluding comments caused some problems if at the end of
the file without a newline. First error, then included \0 """
x = """\
- europe: 10 # abc"""
round_trip(x, extra="\n")
with pytest.raises(AssertionError):
round_trip(x, extra="a\n")
def test_no_comments(self):
round_trip(
"""
- europe: 10
- usa:
- ohio: 2
- california: 9
"""
)
def test_round_trip_ordering(self):
round_trip(
"""
a: 1
b: 2
c: 3
b1: 2
b2: 2
d: 4
e: 5
f: 6
"""
)
def test_complex(self):
round_trip(
"""
- europe: 10 # top
- usa:
- ohio: 2
- california: 9 # o
"""
)
def test_dropped(self):
s = """\
# comment
scalar
...
"""
round_trip(s, "scalar\n...\n")
def test_main_mapping_begin_end(self):
round_trip(
"""
# C start a
# C start b
abc: 1
ghi: 2
klm: 3
# C end a
# C end b
"""
)
def test_reindent(self):
x = """\
a:
b: # comment 1
c: 1 # comment 2
"""
d = round_trip_load(x)
y = round_trip_dump(d, indent=4)
assert y == dedent(
"""\
a:
b: # comment 1
c: 1 # comment 2
"""
)
def test_main_mapping_begin_end_items_post(self):
round_trip(
"""
# C start a
# C start b
abc: 1 # abc comment
ghi: 2
klm: 3 # klm comment
# C end a
# C end b
"""
)
def test_main_sequence_begin_end(self):
round_trip(
"""
# C start a
# C start b
- abc
- ghi
- klm
# C end a
# C end b
"""
)
def test_main_sequence_begin_end_items_post(self):
round_trip(
"""
# C start a
# C start b
- abc # abc comment
- ghi
- klm # klm comment
# C end a
# C end b
"""
)
def test_main_mapping_begin_end_complex(self):
round_trip(
"""
# C start a
# C start b
abc: 1
ghi: 2
klm:
3a: alpha
3b: beta # it is all greek to me
# C end a
# C end b
"""
)
def test_09(self): # 2.9 from the examples in the spec
s = """\
hr: # 1998 hr ranking
- Mark McGwire
- Sammy Sosa
rbi:
# 1998 rbi ranking
- Sammy Sosa
- Ken Griffey
"""
round_trip(s, indent=4, block_seq_indent=2)
def test_09a(self):
round_trip(
"""
hr: # 1998 hr ranking
- Mark McGwire
- Sammy Sosa
rbi:
# 1998 rbi ranking
- Sammy Sosa
- Ken Griffey
"""
)
def test_simple_map_middle_comment(self):
round_trip(
"""
abc: 1
# C 3a
# C 3b
ghi: 2
"""
)
def test_map_in_map_0(self):
round_trip(
"""
map1: # comment 1
# comment 2
map2:
key1: val1
"""
)
def test_map_in_map_1(self):
# comment is moved from value to key
round_trip(
"""
map1:
# comment 1
map2:
key1: val1
"""
)
def test_application_arguments(self):
# application configuration
round_trip(
"""
args:
username: anthon
passwd: secret
fullname: Anthon van der Neut
tmux:
session-name: test
loop:
wait: 10
"""
)
def test_substitute(self):
x = """
args:
username: anthon # name
passwd: secret # password
fullname: Anthon van der Neut
tmux:
session-name: test
loop:
wait: 10
"""
data = round_trip_load(x)
data["args"]["passwd"] = "deleted password"
# note the requirement to add spaces for alignment of comment
x = x.replace(": secret ", ": deleted password")
assert round_trip_dump(data) == dedent(x)
def test_set_comment(self):
round_trip(
"""
!!set
# the beginning
? a
# next one is B (lowercase)
? b # You see? Promised you.
? c
# this is the end
"""
)
def test_omap_comment_roundtrip(self):
round_trip(
"""
!!omap
- a: 1
- b: 2 # two
- c: 3 # three
- d: 4
"""
)
def test_omap_comment_roundtrip_pre_comment(self):
round_trip(
"""
!!omap
- a: 1
- b: 2 # two
- c: 3 # three
# last one
- d: 4
"""
)
def test_non_ascii(self):
round_trip(
"""
verbosity: 1 # 0 is minimal output, -1 none
base_url: http://gopher.net
special_indices: [1, 5, 8]
also_special:
- a
- 19
- 32
asia and europe: &asia_europe
Turkey: Ankara
Russia: Moscow
countries:
Asia:
<<: *asia_europe
Japan: Tokyo # 東京
Europe:
<<: *asia_europe
Spain: Madrid
Italy: Rome
"""
)
def test_dump_utf8(self):
import srsly.ruamel_yaml # NOQA
x = dedent(
"""\
ab:
- x # comment
- y # more comment
"""
)
data = round_trip_load(x)
dumper = srsly.ruamel_yaml.RoundTripDumper
for utf in [True, False]:
y = srsly.ruamel_yaml.dump(
data, default_flow_style=False, Dumper=dumper, allow_unicode=utf
)
assert y == x
def test_dump_unicode_utf8(self):
import srsly.ruamel_yaml # NOQA
x = dedent(
u"""\
ab:
- x # comment
- y # more comment
"""
)
data = round_trip_load(x)
dumper = srsly.ruamel_yaml.RoundTripDumper
for utf in [True, False]:
y = srsly.ruamel_yaml.dump(
data, default_flow_style=False, Dumper=dumper, allow_unicode=utf
)
assert y == x
def test_mlget_00(self):
x = """\
a:
- b:
c: 42
- d:
f: 196
e:
g: 3.14
"""
d = round_trip_load(x)
assert d.mlget(["a", 1, "d", "f"], list_ok=True) == 196
with pytest.raises(AssertionError):
d.mlget(["a", 1, "d", "f"]) == 196
class TestInsertPopList:
"""list insertion is more complex than dict insertion, as you
need to move the values to subsequent keys on insert"""
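# In a CommentedSeq the end-of-line comments are attached per index, so
# insert()/pop() also has to shift those attachments; the expected dumps below
# check that "# a" .. "# d" stay with their original values.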
@property
def ins(self):
return """\
ab:
- a # a
- b # b
- c
- d # d
de:
- 1
- 2
"""
def test_insert_0(self):
d = round_trip_load(self.ins)
d["ab"].insert(0, "xyz")
y = round_trip_dump(d, indent=2)
assert y == dedent(
"""\
ab:
- xyz
- a # a
- b # b
- c
- d # d
de:
- 1
- 2
"""
)
def test_insert_1(self):
d = round_trip_load(self.ins)
d["ab"].insert(4, "xyz")
y = round_trip_dump(d, indent=2)
assert y == dedent(
"""\
ab:
- a # a
- b # b
- c
- d # d
- xyz
de:
- 1
- 2
"""
)
def test_insert_2(self):
d = round_trip_load(self.ins)
d["ab"].insert(1, "xyz")
y = round_trip_dump(d, indent=2)
assert y == dedent(
"""\
ab:
- a # a
- xyz
- b # b
- c
- d # d
de:
- 1
- 2
"""
)
def test_pop_0(self):
d = round_trip_load(self.ins)
d["ab"].pop(0)
y = round_trip_dump(d, indent=2)
print(y)
assert y == dedent(
"""\
ab:
- b # b
- c
- d # d
de:
- 1
- 2
"""
)
def test_pop_1(self):
d = round_trip_load(self.ins)
d["ab"].pop(1)
y = round_trip_dump(d, indent=2)
print(y)
assert y == dedent(
"""\
ab:
- a # a
- c
- d # d
de:
- 1
- 2
"""
)
def test_pop_2(self):
d = round_trip_load(self.ins)
d["ab"].pop(2)
y = round_trip_dump(d, indent=2)
print(y)
assert y == dedent(
"""\
ab:
- a # a
- b # b
- d # d
de:
- 1
- 2
"""
)
def test_pop_3(self):
d = round_trip_load(self.ins)
d["ab"].pop(3)
y = round_trip_dump(d, indent=2)
print(y)
assert y == dedent(
"""\
ab:
- a # a
- b # b
- c
de:
- 1
- 2
"""
)
# inspired by demux' question on stackoverflow
# http://stackoverflow.com/a/36970608/1307905
class TestInsertInMapping:
@property
def ins(self):
return """\
first_name: Art
occupation: Architect # This is an occupation comment
about: Art Vandelay is a fictional character that George invents...
"""
def test_insert_at_pos_1(self):
d = round_trip_load(self.ins)
d.insert(1, "last name", "Vandelay", comment="new key")
y = round_trip_dump(d)
print(y)
assert y == dedent(
"""\
first_name: Art
last name: Vandelay # new key
occupation: Architect # This is an occupation comment
about: Art Vandelay is a fictional character that George invents...
"""
)
def test_insert_at_pos_0(self):
d = round_trip_load(self.ins)
d.insert(0, "last name", "Vandelay", comment="new key")
y = round_trip_dump(d)
print(y)
assert y == dedent(
"""\
last name: Vandelay # new key
first_name: Art
occupation: Architect # This is an occupation comment
about: Art Vandelay is a fictional character that George invents...
"""
)
def test_insert_at_pos_3(self):
# much simpler if done by appending.
d = round_trip_load(self.ins)
d.insert(3, "last name", "Vandelay", comment="new key")
y = round_trip_dump(d)
print(y)
assert y == dedent(
"""\
first_name: Art
occupation: Architect # This is an occupation comment
about: Art Vandelay is a fictional character that George invents...
last name: Vandelay # new key
"""
)
class TestCommentedMapMerge:
def test_in_operator(self):
data = round_trip_load(
"""
x: &base
a: 1
b: 2
c: 3
y:
<<: *base
k: 4
l: 5
"""
)
assert data["x"]["a"] == 1
assert "a" in data["x"]
assert data["y"]["a"] == 1
assert "a" in data["y"]
def test_issue_60(self):
data = round_trip_load(
"""
x: &base
a: 1
y:
<<: *base
"""
)
assert data["x"]["a"] == 1
assert data["y"]["a"] == 1
if sys.version_info >= (3, 12):
assert str(data["y"]) == """ordereddict({'a': 1})"""
else:
assert str(data["y"]) == """ordereddict([('a', 1)])"""
def test_issue_60_1(self):
data = round_trip_load(
"""
x: &base
a: 1
y:
<<: *base
b: 2
"""
)
assert data["x"]["a"] == 1
assert data["y"]["a"] == 1
if sys.version_info >= (3, 12):
assert str(data["y"]) == """ordereddict({'b': 2, 'a': 1})"""
else:
assert str(data["y"]) == """ordereddict([('b', 2), ('a', 1)])"""
class TestEmptyLines:
# prompted by issue 46 from Alex Harvey
def test_issue_46(self):
yaml_str = dedent(
"""\
---
# Please add key/value pairs in alphabetical order
aws_s3_bucket: 'mys3bucket'
jenkins_ad_credentials:
bind_name: 'CN=svc-AAA-BBB-T,OU=Example,DC=COM,DC=EXAMPLE,DC=Local'
bind_pass: 'xxxxyyyy{'
"""
)
d = round_trip_load(yaml_str, preserve_quotes=True)
y = round_trip_dump(d, explicit_start=True)
assert yaml_str == y
def test_multispace_map(self):
round_trip(
"""
a: 1x
b: 2x
c: 3x
d: 4x
"""
)
@pytest.mark.xfail(strict=True)
def test_multispace_map_initial(self):
round_trip(
"""
a: 1x
b: 2x
c: 3x
d: 4x
"""
)
def test_embedded_map(self):
round_trip(
"""
- a: 1y
b: 2y
c: 3y
"""
)
def test_toplevel_seq(self):
round_trip(
"""\
- 1
- 2
- 3
"""
)
def test_embedded_seq(self):
round_trip(
"""
a:
b:
- 1
- 2
- 3
"""
)
def test_line_with_only_spaces(self):
# issue 54
yaml_str = "---\n\na: 'x'\n \nb: y\n"
d = round_trip_load(yaml_str, preserve_quotes=True)
y = round_trip_dump(d, explicit_start=True)
stripped = ""
for line in yaml_str.splitlines():
stripped += line.rstrip() + "\n"
print(line + "$")
assert stripped == y
def test_some_eol_spaces(self):
# spaces after tokens and on empty lines
yaml_str = '--- \n \na: "x" \n \nb: y \n'
d = round_trip_load(yaml_str, preserve_quotes=True)
y = round_trip_dump(d, explicit_start=True)
stripped = ""
for line in yaml_str.splitlines():
stripped += line.rstrip() + "\n"
print(line + "$")
assert stripped == y
def test_issue_54_not_ok(self):
yaml_str = dedent(
"""\
toplevel:
# some comment
sublevel: 300
"""
)
d = round_trip_load(yaml_str)
print(d.ca)
y = round_trip_dump(d, indent=4)
print(y.replace("\n", "$\n"))
assert yaml_str == y
def test_issue_54_ok(self):
yaml_str = dedent(
"""\
toplevel:
# some comment
sublevel: 300
"""
)
d = round_trip_load(yaml_str)
y = round_trip_dump(d, indent=4)
assert yaml_str == y
def test_issue_93(self):
round_trip(
"""\
a:
b:
- c1: cat # a1
# my comment on catfish
- c2: catfish # a2
"""
)
def test_issue_93_00(self):
round_trip(
"""\
a:
- - c1: cat # a1
# my comment on catfish
- c2: catfish # a2
"""
)
def test_issue_93_01(self):
round_trip(
"""\
- - c1: cat # a1
# my comment on catfish
- c2: catfish # a2
"""
)
def test_issue_93_02(self):
# never failed as there is no indent
round_trip(
"""\
- c1: cat
# my comment on catfish
- c2: catfish
"""
)
def test_issue_96(self):
# inserted extra line on trailing spaces
round_trip(
"""\
a:
b:
c: c_val
d:
e:
g: g_val
"""
)
class TestUnicodeComments:
@pytest.mark.skipif(sys.version_info < (2, 7), reason="wide unicode")
def test_issue_55(self): # reported by Haraguroicha Hsu
round_trip(
"""\
name: TEST
description: test using
author: Harguroicha
sql:
command: |-
select name from testtbl where no = :no
ci-test:
- :no: 04043709 # 小花
- :no: 05161690 # 茶
- :no: 05293147 # 〇𤋥川
- :no: 05338777 # 〇〇啓
- :no: 05273867 # 〇
- :no: 05205786 # 〇𤦌
"""
)
class TestEmptyValueBeforeComments:
def test_issue_25a(self):
round_trip(
"""\
- a: b
c: d
d: # foo
- e: f
"""
)
def test_issue_25a1(self):
round_trip(
"""\
- a: b
c: d
d: # foo
e: f
"""
)
def test_issue_25b(self):
round_trip(
"""\
var1: #empty
var2: something #notempty
"""
)
def test_issue_25c(self):
round_trip(
"""\
params:
a: 1 # comment a
b: # comment b
c: 3 # comment c
"""
)
def test_issue_25c1(self):
round_trip(
"""\
params:
a: 1 # comment a
b: # comment b
# extra
c: 3 # comment c
"""
)
def test_issue_25_00(self):
round_trip(
"""\
params:
a: 1 # comment a
b: # comment b
"""
)
def test_issue_25_01(self):
round_trip(
"""\
a: # comment 1
# comment 2
- b: # comment 3
c: 1 # comment 4
"""
)
def test_issue_25_02(self):
round_trip(
"""\
a: # comment 1
# comment 2
- b: 2 # comment 3
"""
)
def test_issue_25_03(self):
s = """\
a: # comment 1
# comment 2
- b: 2 # comment 3
"""
round_trip(s, indent=4, block_seq_indent=2)
def test_issue_25_04(self):
round_trip(
"""\
a: # comment 1
# comment 2
b: 1 # comment 3
"""
)
def test_flow_seq_within_seq(self):
round_trip(
"""\
# comment 1
- a
- b
# comment 2
- c
- d
# comment 3
- [e]
- f
# comment 4
- []
"""
)
test_block_scalar_commented_line_template = """\
y: p
# Some comment
a: |
x
{}b: y
"""
class TestBlockScalarWithComments:
# issue 99 reported by Colm O'Connor
def test_scalar_with_comments(self):
import srsly.ruamel_yaml # NOQA
for x in [
"",
"\n",
"\n# Another comment\n",
"\n\n",
"\n\n# abc\n#xyz\n",
"\n\n# abc\n#xyz\n",
"# abc\n\n#xyz\n",
"\n\n # abc\n #xyz\n",
]:
commented_line = test_block_scalar_commented_line_template.format(x)
data = srsly.ruamel_yaml.round_trip_load(commented_line)
assert srsly.ruamel_yaml.round_trip_dump(data) == commented_line
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_contextmanager.py 0000775 0000000 0000000 00000005364 14742310675 0026000 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
"""
testing of loading and dumping, with and without the YAML context manager
"""
import sys
import pytest
single_doc = """\
- a: 1
- b:
- 2
- 3
"""
single_data = [dict(a=1), dict(b=[2, 3])]
multi_doc = """\
---
- abc
- xyz
---
- a: 1
- b:
- 2
- 3
"""
multi_doc_data = [["abc", "xyz"], single_data]
def get_yaml():
from srsly.ruamel_yaml import YAML
return YAML()
class TestOldStyle:
def test_single_load(self):
d = get_yaml().load(single_doc)
print(d)
print(type(d[0]))
assert d == single_data
def test_single_load_no_arg(self):
with pytest.raises(TypeError):
assert get_yaml().load() == single_data
def test_multi_load(self):
data = list(get_yaml().load_all(multi_doc))
assert data == multi_doc_data
def test_single_dump(self, capsys):
get_yaml().dump(single_data, sys.stdout)
out, err = capsys.readouterr()
assert out == single_doc
def test_multi_dump(self, capsys):
yaml = get_yaml()
yaml.explicit_start = True
yaml.dump_all(multi_doc_data, sys.stdout)
out, err = capsys.readouterr()
assert out == multi_doc
class TestContextManager:
def test_single_dump(self, capsys):
from srsly.ruamel_yaml import YAML
with YAML(output=sys.stdout) as yaml:
yaml.dump(single_data)
out, err = capsys.readouterr()
print(err)
assert out == single_doc
def test_multi_dump(self, capsys):
from srsly.ruamel_yaml import YAML
with YAML(output=sys.stdout) as yaml:
yaml.explicit_start = True
yaml.dump(multi_doc_data[0])
yaml.dump(multi_doc_data[1])
out, err = capsys.readouterr()
print(err)
assert out == multi_doc
# input is not as simple with a context manager
# you need to indicate what you expect hence load and load_all
# @pytest.mark.xfail(strict=True)
# def test_single_load(self):
# from srsly.ruamel_yaml import YAML
# with YAML(input=single_doc) as yaml:
# assert yaml.load() == single_data
#
# @pytest.mark.xfail(strict=True)
# def test_multi_load(self):
# from srsly.ruamel_yaml import YAML
# with YAML(input=multi_doc) as yaml:
# for idx, data in enumerate(yaml.load()):
# assert data == multi_doc_data[0]
def test_roundtrip(self, capsys):
from srsly.ruamel_yaml import YAML
with YAML(output=sys.stdout) as yaml:
yaml.explicit_start = True
for data in yaml.load_all(multi_doc):
yaml.dump(data)
out, err = capsys.readouterr()
print(err)
assert out == multi_doc
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_copy.py 0000775 0000000 0000000 00000007103 14742310675 0023724 0 ustar 00root root 0000000 0000000 # coding: utf-8
"""
Testing copy and deepcopy, instigated by Issue 84 (Peter Amstutz)
"""
import copy
import pytest # NOQA
from .roundtrip import dedent, round_trip_load, round_trip_dump
class TestDeepCopy:
def test_preserve_flow_style_simple(self):
x = dedent(
"""\
{foo: bar, baz: quux}
"""
)
data = round_trip_load(x)
data_copy = copy.deepcopy(data)
y = round_trip_dump(data_copy)
print("x [{}]".format(x))
print("y [{}]".format(y))
assert y == x
assert data.fa.flow_style() == data_copy.fa.flow_style()
def test_deepcopy_flow_style_nested_dict(self):
x = dedent(
"""\
a: {foo: bar, baz: quux}
"""
)
data = round_trip_load(x)
assert data["a"].fa.flow_style() is True
data_copy = copy.deepcopy(data)
assert data_copy["a"].fa.flow_style() is True
data_copy["a"].fa.set_block_style()
assert data["a"].fa.flow_style() != data_copy["a"].fa.flow_style()
assert data["a"].fa._flow_style is True
assert data_copy["a"].fa._flow_style is False
y = round_trip_dump(data_copy)
print("x [{}]".format(x))
print("y [{}]".format(y))
assert y == dedent(
"""\
a:
foo: bar
baz: quux
"""
)
def test_deepcopy_flow_style_nested_list(self):
x = dedent(
"""\
a: [1, 2, 3]
"""
)
data = round_trip_load(x)
assert data["a"].fa.flow_style() is True
data_copy = copy.deepcopy(data)
assert data_copy["a"].fa.flow_style() is True
data_copy["a"].fa.set_block_style()
assert data["a"].fa.flow_style() != data_copy["a"].fa.flow_style()
assert data["a"].fa._flow_style is True
assert data_copy["a"].fa._flow_style is False
y = round_trip_dump(data_copy)
print("x [{}]".format(x))
print("y [{}]".format(y))
assert y == dedent(
"""\
a:
- 1
- 2
- 3
"""
)
class TestCopy:
def test_copy_flow_style_nested_dict(self):
x = dedent(
"""\
a: {foo: bar, baz: quux}
"""
)
data = round_trip_load(x)
assert data["a"].fa.flow_style() is True
data_copy = copy.copy(data)
assert data_copy["a"].fa.flow_style() is True
data_copy["a"].fa.set_block_style()
assert data["a"].fa.flow_style() == data_copy["a"].fa.flow_style()
assert data["a"].fa._flow_style is False
assert data_copy["a"].fa._flow_style is False
y = round_trip_dump(data_copy)
z = round_trip_dump(data)
assert y == z
assert y == dedent(
"""\
a:
foo: bar
baz: quux
"""
)
def test_copy_flow_style_nested_list(self):
x = dedent(
"""\
a: [1, 2, 3]
"""
)
data = round_trip_load(x)
assert data["a"].fa.flow_style() is True
data_copy = copy.copy(data)
assert data_copy["a"].fa.flow_style() is True
data_copy["a"].fa.set_block_style()
assert data["a"].fa.flow_style() == data_copy["a"].fa.flow_style()
assert data["a"].fa._flow_style is False
assert data_copy["a"].fa._flow_style is False
y = round_trip_dump(data_copy)
print("x [{}]".format(x))
print("y [{}]".format(y))
assert y == dedent(
"""\
a:
- 1
- 2
- 3
"""
)
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_datetime.py 0000775 0000000 0000000 00000007313 14742310675 0024551 0 ustar 00root root 0000000 0000000 # coding: utf-8
"""
http://yaml.org/type/timestamp.html specifies the regexp to use
for datetime.date and datetime.datetime construction. Date is simple
but datetime can have 'T' or 't' as well as 'Z' or a timezone offset (in
hours and minutes). This information was originally used to create
a UTC datetime and then discarded
examples from the above:
canonical: 2001-12-15T02:59:43.1Z
valid iso8601: 2001-12-14t21:59:43.10-05:00
space separated: 2001-12-14 21:59:43.10 -5
no time zone (Z): 2001-12-15 2:59:43.10
date (00:00:00Z): 2002-12-14
Please note that a fraction can only be included if not equal to 0
"""
import copy
import pytest # NOQA
from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
class TestDateTime:
def test_date_only(self):
inp = """
- 2011-10-02
"""
exp = """
- 2011-10-02
"""
round_trip(inp, exp)
def test_zero_fraction(self):
inp = """
- 2011-10-02 16:45:00.0
"""
exp = """
- 2011-10-02 16:45:00
"""
round_trip(inp, exp)
def test_long_fraction(self):
inp = """
- 2011-10-02 16:45:00.1234 # expand with zeros
- 2011-10-02 16:45:00.123456
- 2011-10-02 16:45:00.12345612 # round to microseconds
- 2011-10-02 16:45:00.1234565 # round up
- 2011-10-02 16:45:00.12345678 # round up
"""
exp = """
- 2011-10-02 16:45:00.123400 # expand with zeros
- 2011-10-02 16:45:00.123456
- 2011-10-02 16:45:00.123456 # round to microseconds
- 2011-10-02 16:45:00.123457 # round up
- 2011-10-02 16:45:00.123457 # round up
"""
round_trip(inp, exp)
def test_canonical(self):
inp = """
- 2011-10-02T16:45:00.1Z
"""
exp = """
- 2011-10-02T16:45:00.100000Z
"""
round_trip(inp, exp)
def test_spaced_timezone(self):
inp = """
- 2011-10-02T11:45:00 -5
"""
exp = """
- 2011-10-02T11:45:00-5
"""
round_trip(inp, exp)
def test_normal_timezone(self):
round_trip(
"""
- 2011-10-02T11:45:00-5
- 2011-10-02 11:45:00-5
- 2011-10-02T11:45:00-05:00
- 2011-10-02 11:45:00-05:00
"""
)
def test_no_timezone(self):
inp = """
- 2011-10-02 6:45:00
"""
exp = """
- 2011-10-02 06:45:00
"""
round_trip(inp, exp)
def test_explicit_T(self):
inp = """
- 2011-10-02T16:45:00
"""
exp = """
- 2011-10-02T16:45:00
"""
round_trip(inp, exp)
def test_explicit_t(self): # to upper
inp = """
- 2011-10-02t16:45:00
"""
exp = """
- 2011-10-02T16:45:00
"""
round_trip(inp, exp)
def test_no_T_multi_space(self):
inp = """
- 2011-10-02 16:45:00
"""
exp = """
- 2011-10-02 16:45:00
"""
round_trip(inp, exp)
def test_iso(self):
round_trip(
"""
- 2011-10-02T15:45:00+01:00
"""
)
def test_zero_tz(self):
round_trip(
"""
- 2011-10-02T15:45:00+0
"""
)
def test_issue_45(self):
round_trip(
"""
dt: 2016-08-19T22:45:47Z
"""
)
def test_deepcopy_datestring(self):
# reported by Quuxplusone, http://stackoverflow.com/a/41577841/1307905
x = dedent(
"""\
foo: 2016-10-12T12:34:56
"""
)
data = copy.deepcopy(round_trip_load(x))
assert round_trip_dump(data) == x
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_deprecation.py 0000775 0000000 0000000 00000000541 14742310675 0025246 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
import sys
import pytest # NOQA
@pytest.mark.skipif(sys.version_info < (3, 7) or sys.version_info >= (3, 9),
reason='collections not available?')
def test_collections_deprecation():
with pytest.warns(DeprecationWarning):
from collections import Hashable # NOQA
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_documents.py 0000775 0000000 0000000 00000003443 14742310675 0024756 0 ustar 00root root 0000000 0000000 # coding: utf-8
import pytest # NOQA
from .roundtrip import round_trip, round_trip_load_all
class TestDocument:
def test_single_doc_begin_end(self):
inp = """\
---
- a
- b
...
"""
round_trip(inp, explicit_start=True, explicit_end=True)
def test_multi_doc_begin_end(self):
from srsly.ruamel_yaml import dump_all, RoundTripDumper
inp = """\
---
- a
...
---
- b
...
"""
docs = list(round_trip_load_all(inp))
assert docs == [["a"], ["b"]]
out = dump_all(
docs, Dumper=RoundTripDumper, explicit_start=True, explicit_end=True
)
assert out == "---\n- a\n...\n---\n- b\n...\n"
def test_multi_doc_no_start(self):
inp = """\
- a
...
---
- b
...
"""
docs = list(round_trip_load_all(inp))
assert docs == [["a"], ["b"]]
def test_multi_doc_no_end(self):
inp = """\
- a
---
- b
"""
docs = list(round_trip_load_all(inp))
assert docs == [["a"], ["b"]]
def test_multi_doc_ends_only(self):
# this is ok in 1.2
inp = """\
- a
...
- b
...
"""
docs = list(round_trip_load_all(inp, version=(1, 2)))
assert docs == [["a"], ["b"]]
def test_multi_doc_ends_only_1_1(self):
from srsly.ruamel_yaml import parser
# this is not ok in 1.1
with pytest.raises(parser.ParserError):
inp = """\
- a
...
- b
...
"""
docs = list(round_trip_load_all(inp, version=(1, 1)))
assert docs == [["a"], ["b"]] # not True, but not reached
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_fail.py 0000775 0000000 0000000 00000014222 14742310675 0023665 0 ustar 00root root 0000000 0000000 # coding: utf-8
# there is some work to do
# provide a failing test xyz and a non-failing xyz_no_fail (to see
# what the current failing output is).
# on a fix of srsly.ruamel_yaml, move the marked test to the appropriate test (without mark)
# and remove the xyz_no_fail
import pytest
from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump
class TestCommentFailures:
@pytest.mark.xfail(strict=True)
def test_set_comment_before_tag(self):
# no comments before tags
round_trip(
"""
# the beginning
!!set
# or this one?
? a
# next one is B (lowercase)
? b # You see? Promised you.
? c
# this is the end
"""
)
def test_set_comment_before_tag_no_fail(self):
# no comments before tags
inp = """
# the beginning
!!set
# or this one?
? a
# next one is B (lowercase)
? b # You see? Promised you.
? c
# this is the end
"""
assert round_trip_dump(round_trip_load(inp)) == dedent(
"""
!!set
# or this one?
? a
# next one is B (lowercase)
? b # You see? Promised you.
? c
# this is the end
"""
)
@pytest.mark.xfail(strict=True)
def test_comment_dash_line(self):
round_trip(
"""
- # abc
a: 1
b: 2
"""
)
def test_comment_dash_line_fail(self):
x = """
- # abc
a: 1
b: 2
"""
data = round_trip_load(x)
# this is not nice
assert round_trip_dump(data) == dedent(
"""
# abc
- a: 1
b: 2
"""
)
class TestIndentFailures:
@pytest.mark.xfail(strict=True)
def test_indent_not_retained(self):
round_trip(
"""
verbosity: 1 # 0 is minimal output, -1 none
base_url: http://gopher.net
special_indices: [1, 5, 8]
also_special:
- a
- 19
- 32
asia and europe: &asia_europe
Turkey: Ankara
Russia: Moscow
countries:
Asia:
<<: *asia_europe
Japan: Tokyo # 東京
Europe:
<<: *asia_europe
Spain: Madrid
Italy: Rome
Antarctica:
- too cold
"""
)
def test_indent_not_retained_no_fail(self):
inp = """
verbosity: 1 # 0 is minimal output, -1 none
base_url: http://gopher.net
special_indices: [1, 5, 8]
also_special:
- a
- 19
- 32
asia and europe: &asia_europe
Turkey: Ankara
Russia: Moscow
countries:
Asia:
<<: *asia_europe
Japan: Tokyo # 東京
Europe:
<<: *asia_europe
Spain: Madrid
Italy: Rome
Antarctica:
- too cold
"""
assert round_trip_dump(round_trip_load(inp), indent=4) == dedent(
"""
verbosity: 1 # 0 is minimal output, -1 none
base_url: http://gopher.net
special_indices: [1, 5, 8]
also_special:
- a
- 19
- 32
asia and europe: &asia_europe
Turkey: Ankara
Russia: Moscow
countries:
Asia:
<<: *asia_europe
Japan: Tokyo # 東京
Europe:
<<: *asia_europe
Spain: Madrid
Italy: Rome
Antarctica:
- too cold
"""
)
def Xtest_indent_top_level_no_fail(self):
inp = """
- a:
- b
"""
round_trip(inp, indent=4)
class TestTagFailures:
@pytest.mark.xfail(strict=True)
def test_standard_short_tag(self):
round_trip(
"""\
!!map
name: Anthon
location: Germany
language: python
"""
)
def test_standard_short_tag_no_fail(self):
inp = """
!!map
name: Anthon
location: Germany
language: python
"""
exp = """
name: Anthon
location: Germany
language: python
"""
assert round_trip_dump(round_trip_load(inp)) == dedent(exp)
class TestFlowValues:
def test_flow_value_with_colon(self):
inp = """\
{a: bcd:efg}
"""
round_trip(inp)
def test_flow_value_with_colon_quoted(self):
inp = """\
{a: 'bcd:efg'}
"""
round_trip(inp, preserve_quotes=True)
class TestMappingKey:
def test_simple_mapping_key(self):
inp = """\
{a: 1, b: 2}: hello world
"""
round_trip(inp, preserve_quotes=True, dump_data=False)
def test_set_simple_mapping_key(self):
from srsly.ruamel_yaml.comments import CommentedKeyMap
d = {CommentedKeyMap([("a", 1), ("b", 2)]): "hello world"}
exp = dedent(
"""\
{a: 1, b: 2}: hello world
"""
)
assert round_trip_dump(d) == exp
def test_change_key_simple_mapping_key(self):
from srsly.ruamel_yaml.comments import CommentedKeyMap
inp = """\
{a: 1, b: 2}: hello world
"""
d = round_trip_load(inp, preserve_quotes=True)
d[CommentedKeyMap([("b", 1), ("a", 2)])] = d.pop(
CommentedKeyMap([("a", 1), ("b", 2)])
)
exp = dedent(
"""\
{b: 1, a: 2}: hello world
"""
)
assert round_trip_dump(d) == exp
def test_change_value_simple_mapping_key(self):
from srsly.ruamel_yaml.comments import CommentedKeyMap
inp = """\
{a: 1, b: 2}: hello world
"""
d = round_trip_load(inp, preserve_quotes=True)
d = {CommentedKeyMap([("a", 1), ("b", 2)]): "goodbye"}
exp = dedent(
"""\
{a: 1, b: 2}: goodbye
"""
)
assert round_trip_dump(d) == exp
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_float.py 0000775 0000000 0000000 00000004122 14742310675 0024055 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division, unicode_literals
import pytest # NOQA
from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
# http://yaml.org/type/int.html is where underscores in integers are defined
class TestFloat:
def test_round_trip_non_exp(self):
data = round_trip(
"""\
- 1.0
- 1.00
- 23.100
- -1.0
- -1.00
- -23.100
- 42.
- -42.
- +42.
- .5
- +.5
- -.5
"""
)
print(data)
assert 0.999 < data[0] < 1.001
assert 0.999 < data[1] < 1.001
assert 23.099 < data[2] < 23.101
assert 0.999 < -data[3] < 1.001
assert 0.999 < -data[4] < 1.001
assert 23.099 < -data[5] < 23.101
assert 41.999 < data[6] < 42.001
assert 41.999 < -data[7] < 42.001
assert 41.999 < data[8] < 42.001
assert 0.49 < data[9] < 0.51
assert 0.49 < data[10] < 0.51
assert -0.51 < data[11] < -0.49
def test_round_trip_zeros_0(self):
data = round_trip(
"""\
- 0.
- +0.
- -0.
- 0.0
- +0.0
- -0.0
- 0.00
- +0.00
- -0.00
"""
)
print(data)
for d in data:
assert -0.00001 < d < 0.00001
def Xtest_round_trip_non_exp_trailing_dot(self):
data = round_trip(
"""\
"""
)
print(data)
def test_yaml_1_1_no_dot(self):
from srsly.ruamel_yaml.error import MantissaNoDotYAML1_1Warning
with pytest.warns(MantissaNoDotYAML1_1Warning):
round_trip_load(
"""\
%YAML 1.1
---
- 1e6
"""
)
class TestCalculations(object):
def test_mul_00(self):
# issue 149 reported by jan.brezina@tul.cz
d = round_trip_load(
"""\
- 0.1
"""
)
d[0] *= -1
x = round_trip_dump(d)
assert x == "- -0.1\n"
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_flowsequencekey.py 0000775 0000000 0000000 00000000747 14742310675 0026172 0 ustar 00root root 0000000 0000000 # coding: utf-8
"""
test flow style sequences as keys roundtrip
"""
# import pytest
from .roundtrip import round_trip # , dedent, round_trip_load, round_trip_dump
class TestFlowStyleSequenceKey:
def test_so_39595807(self):
inp = """\
%YAML 1.2
---
[2, 3, 4]:
a:
- 1
- 2
b: Hello World!
c: 'Voilà!'
"""
round_trip(inp, preserve_quotes=True, explicit_start=True, version=(1, 2))
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_indentation.py 0000775 0000000 0000000 00000020500 14742310675 0025262 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import pytest # NOQA
from .roundtrip import round_trip, round_trip_load, round_trip_dump, dedent, YAML
def rt(s):
import srsly.ruamel_yaml
res = srsly.ruamel_yaml.dump(
srsly.ruamel_yaml.load(s, Loader=srsly.ruamel_yaml.RoundTripLoader),
Dumper=srsly.ruamel_yaml.RoundTripDumper,
)
return res.strip() + "\n"
class TestIndent:
def test_roundtrip_inline_list(self):
s = "a: [a, b, c]\n"
output = rt(s)
assert s == output
def test_roundtrip_mapping_of_inline_lists(self):
s = dedent(
"""\
a: [a, b, c]
j: [k, l, m]
"""
)
output = rt(s)
assert s == output
def test_roundtrip_mapping_of_inline_lists_comments(self):
s = dedent(
"""\
# comment A
a: [a, b, c]
# comment B
j: [k, l, m]
"""
)
output = rt(s)
assert s == output
def test_roundtrip_mapping_of_inline_sequence_eol_comments(self):
s = dedent(
"""\
# comment A
a: [a, b, c] # comment B
j: [k, l, m] # comment C
"""
)
output = rt(s)
assert s == output
# first test by explicitly setting flow style
def test_added_inline_list(self):
import srsly.ruamel_yaml
s1 = dedent(
"""
a:
- b
- c
- d
"""
)
s = "a: [b, c, d]\n"
data = srsly.ruamel_yaml.load(s1, Loader=srsly.ruamel_yaml.RoundTripLoader)
val = data["a"]
val.fa.set_flow_style()
# print(type(val), '_yaml_format' in dir(val))
output = srsly.ruamel_yaml.dump(data, Dumper=srsly.ruamel_yaml.RoundTripDumper)
assert s == output
# ############ flow mappings
def test_roundtrip_flow_mapping(self):
import srsly.ruamel_yaml
s = dedent(
"""\
- {a: 1, b: hallo}
- {j: fka, k: 42}
"""
)
data = srsly.ruamel_yaml.load(s, Loader=srsly.ruamel_yaml.RoundTripLoader)
output = srsly.ruamel_yaml.dump(data, Dumper=srsly.ruamel_yaml.RoundTripDumper)
assert s == output
def test_roundtrip_sequence_of_inline_mappings_eol_comments(self):
s = dedent(
"""\
# comment A
- {a: 1, b: hallo} # comment B
- {j: fka, k: 42} # comment C
"""
)
output = rt(s)
assert s == output
def test_indent_top_level(self):
inp = """
- a:
- b
"""
round_trip(inp, indent=4)
def test_set_indent_5_block_list_indent_1(self):
inp = """
a:
- b: c
- 1
- d:
- 2
"""
round_trip(inp, indent=5, block_seq_indent=1)
def test_set_indent_4_block_list_indent_2(self):
inp = """
a:
- b: c
- 1
- d:
- 2
"""
round_trip(inp, indent=4, block_seq_indent=2)
def test_set_indent_3_block_list_indent_0(self):
inp = """
a:
- b: c
- 1
- d:
- 2
"""
round_trip(inp, indent=3, block_seq_indent=0)
def Xtest_set_indent_3_block_list_indent_2(self):
inp = """
a:
-
b: c
-
1
-
d:
-
2
"""
round_trip(inp, indent=3, block_seq_indent=2)
def test_set_indent_3_block_list_indent_2(self):
inp = """
a:
- b: c
- 1
- d:
- 2
"""
round_trip(inp, indent=3, block_seq_indent=2)
def Xtest_set_indent_2_block_list_indent_2(self):
inp = """
a:
-
b: c
-
1
-
d:
-
2
"""
round_trip(inp, indent=2, block_seq_indent=2)
# this is how it should be: block_seq_indent stretches the indent
def test_set_indent_2_block_list_indent_2(self):
inp = """
a:
- b: c
- 1
- d:
- 2
"""
round_trip(inp, indent=2, block_seq_indent=2)
# have to set indent!
def test_roundtrip_four_space_indents(self):
# fmt: off
s = (
'a:\n'
'- foo\n'
'- bar\n'
)
# fmt: on
round_trip(s, indent=4)
def test_roundtrip_four_space_indents_no_fail(self):
inp = """
a:
- foo
- bar
"""
exp = """
a:
- foo
- bar
"""
assert round_trip_dump(round_trip_load(inp)) == dedent(exp)
class TestYpkgIndent:
def test_00(self):
inp = """
name : nano
version : 2.3.2
release : 1
homepage : http://www.nano-editor.org
source :
- http://www.nano-editor.org/dist/v2.3/nano-2.3.2.tar.gz : ff30924807ea289f5b60106be8
license : GPL-2.0
summary : GNU nano is an easy-to-use text editor
builddeps :
- ncurses-devel
description: |
GNU nano is an easy-to-use text editor originally designed
as a replacement for Pico, the ncurses-based editor from the non-free mailer
package Pine (itself now available under the Apache License as Alpine).
"""
round_trip(
inp,
indent=4,
block_seq_indent=2,
top_level_colon_align=True,
prefix_colon=" ",
)
def guess(s):
from srsly.ruamel_yaml.util import load_yaml_guess_indent
x, y, z = load_yaml_guess_indent(dedent(s))
return y, z
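# guess() returns just the (indent, block_seq_indent) pair that
# load_yaml_guess_indent derives from the input text; the parsed data itself
# is discarded here.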
class TestGuessIndent:
def test_guess_20(self):
inp = """\
a:
- 1
"""
assert guess(inp) == (2, 0)
def test_guess_42(self):
inp = """\
a:
- 1
"""
assert guess(inp) == (4, 2)
def test_guess_42a(self):
# block seq indent prevails over nested key indent level
inp = """\
b:
a:
- 1
"""
assert guess(inp) == (4, 2)
def test_guess_3None(self):
inp = """\
b:
a: 1
"""
assert guess(inp) == (3, None)
class TestSeparateMapSeqIndents:
# using an uncommon indent of 6 with an offset of 3, since an offset of 2 automatically
# gets you an indent of 4 even if not set
def test_00(self):
# old style
yaml = YAML()
yaml.indent = 6
yaml.block_seq_indent = 3
inp = """
a:
- 1
- [1, 2]
"""
yaml.round_trip(inp)
def test_01(self):
yaml = YAML()
yaml.indent(sequence=6)
yaml.indent(offset=3)
inp = """
a:
- 1
- {b: 3}
"""
yaml.round_trip(inp)
def test_02(self):
yaml = YAML()
yaml.indent(mapping=5, sequence=6, offset=3)
inp = """
a:
b:
- 1
- [1, 2]
"""
yaml.round_trip(inp)
def test_03(self):
inp = """
a:
b:
c:
- 1
- [1, 2]
"""
round_trip(inp, indent=4)
def test_04(self):
yaml = YAML()
yaml.indent(mapping=5, sequence=6)
inp = """
a:
b:
- 1
- [1, 2]
- {d: 3.14}
"""
yaml.round_trip(inp)
def test_issue_51(self):
yaml = YAML()
# yaml.map_indent = 2 # the default
yaml.indent(sequence=4, offset=2)
yaml.preserve_quotes = True
yaml.round_trip(
"""
role::startup::author::rsyslog_inputs:
imfile:
- ruleset: 'AEM-slinglog'
File: '/opt/aem/author/crx-quickstart/logs/error.log'
startmsg.regex: '^[-+T.:[:digit:]]*'
tag: 'error'
- ruleset: 'AEM-slinglog'
File: '/opt/aem/author/crx-quickstart/logs/stdout.log'
startmsg.regex: '^[-+T.:[:digit:]]*'
tag: 'stdout'
"""
)
# ############ indentation
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_int.py 0000775 0000000 0000000 00000001613 14742310675 0023544 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division, unicode_literals
import pytest # NOQA
from .roundtrip import dedent, round_trip_load, round_trip_dump
# http://yaml.org/type/int.html is where underscores in integers are defined
class TestBinHexOct:
def test_calculate(self):
# make sure type, leading zero(s) and underscore are preserved
s = dedent(
"""\
- 42
- 0b101010
- 0x_2a
- 0x2A
- 0o00_52
"""
)
d = round_trip_load(s)
for idx, elem in enumerate(d):
elem -= 21
d[idx] = elem
for idx, elem in enumerate(d):
elem *= 2
d[idx] = elem
for idx, elem in enumerate(d):
t = elem
elem **= 2
elem //= t
d[idx] = elem
assert round_trip_dump(d) == s
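# Illustrative sketch, not part of the original test file: even without the
# arithmetic above, a scalar such as 0x_2a loads as a plain int (value 42) while
# remembering its original form, so an unmodified round trip restores it.
def _example_int_format_preserved():
    d = round_trip_load("- 0x_2a\n")
    assert d[0] == 42
    assert round_trip_dump(d) == "- 0x_2a\n"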
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_issues.py 0000775 0000000 0000000 00000056241 14742310675 0024274 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import absolute_import, print_function, unicode_literals
import pytest # NOQA
import sys
from .roundtrip import (
round_trip,
na_round_trip,
round_trip_load,
round_trip_dump,
dedent,
save_and_run,
YAML,
) # NOQA
class TestIssues:
def test_issue_61(self):
import srsly.ruamel_yaml
s = dedent(
"""
def1: &ANCHOR1
key1: value1
def: &ANCHOR
<<: *ANCHOR1
key: value
comb:
<<: *ANCHOR
"""
)
data = srsly.ruamel_yaml.round_trip_load(s)
assert str(data["comb"]) == str(data["def"])
if sys.version_info >= (3, 12):
assert (
str(data["comb"]) == "ordereddict({'key': 'value', 'key1': 'value1'})"
)
else:
assert (
str(data["comb"]) == "ordereddict([('key', 'value'), ('key1', 'value1')])"
)
def test_issue_82(self, tmpdir):
program_src = r'''
from __future__ import print_function
import srsly.ruamel_yaml as yaml
import re
class SINumber(yaml.YAMLObject):
PREFIXES = {'k': 1e3, 'M': 1e6, 'G': 1e9}
yaml_loader = yaml.Loader
yaml_dumper = yaml.Dumper
yaml_tag = u'!si'
yaml_implicit_pattern = re.compile(
r'^(?P<value>[0-9]+(?:\.[0-9]+)?)(?P<prefix>[kMG])$')
@classmethod
def from_yaml(cls, loader, node):
return cls(node.value)
@classmethod
def to_yaml(cls, dumper, data):
return dumper.represent_scalar(cls.yaml_tag, str(data))
def __init__(self, *args):
m = self.yaml_implicit_pattern.match(args[0])
self.value = float(m.groupdict()['value'])
self.prefix = m.groupdict()['prefix']
def __str__(self):
return str(self.value)+self.prefix
def __int__(self):
return int(self.value*self.PREFIXES[self.prefix])
# This fails:
yaml.add_implicit_resolver(SINumber.yaml_tag, SINumber.yaml_implicit_pattern)
ret = yaml.load("""
[1,2,3, !si 10k, 100G]
""", Loader=yaml.Loader)
for idx, l in enumerate([1, 2, 3, 10000, 100000000000]):
assert int(ret[idx]) == l
'''
assert save_and_run(dedent(program_src), tmpdir) == 1
def test_issue_82rt(self, tmpdir):
yaml_str = "[1, 2, 3, !si 10k, 100G]\n"
x = round_trip(yaml_str, preserve_quotes=True) # NOQA
def test_issue_102(self):
yaml_str = dedent(
"""
var1: #empty
var2: something #notempty
var3: {} #empty object
var4: {a: 1} #filled object
var5: [] #empty array
"""
)
x = round_trip(yaml_str, preserve_quotes=True) # NOQA
def test_issue_150(self):
from srsly.ruamel_yaml import YAML
inp = """\
base: &base_key
first: 123
second: 234
child:
<<: *base_key
third: 345
"""
yaml = YAML()
data = yaml.load(inp)
child = data["child"]
assert "second" in dict(**child)
def test_issue_160(self):
from srsly.ruamel_yaml.compat import StringIO
s = dedent(
"""\
root:
# a comment
- {some_key: "value"}
foo: 32
bar: 32
"""
)
a = round_trip_load(s)
del a["root"][0]["some_key"]
buf = StringIO()
round_trip_dump(a, buf, block_seq_indent=4)
exp = dedent(
"""\
root:
# a comment
- {}
foo: 32
bar: 32
"""
)
assert buf.getvalue() == exp
def test_issue_161(self):
yaml_str = dedent(
"""\
mapping-A:
key-A:{}
mapping-B:
"""
)
for comment in ["", " # no-newline", " # some comment\n", "\n"]:
s = yaml_str.format(comment)
res = round_trip(s) # NOQA
def test_issue_161a(self):
yaml_str = dedent(
"""\
mapping-A:
key-A:{}
mapping-B:
"""
)
for comment in ["\n# between"]:
s = yaml_str.format(comment)
res = round_trip(s) # NOQA
def test_issue_163(self):
s = dedent(
"""\
some-list:
# List comment
- {}
"""
)
x = round_trip(s, preserve_quotes=True) # NOQA
json_str = (
r'{"sshKeys":[{"name":"AETROS\/google-k80-1","uses":0,"getLastUse":0,'
'"fingerprint":"MD5:19:dd:41:93:a1:a3:f5:91:4a:8e:9b:d0:ae:ce:66:4c",'
'"created":1509497961}]}'
)
json_str2 = '{"abc":[{"a":"1", "uses":0}]}'
def test_issue_172(self):
x = round_trip_load(TestIssues.json_str2) # NOQA
x = round_trip_load(TestIssues.json_str) # NOQA
def test_issue_176(self):
# basic request by Stuart Berg
from srsly.ruamel_yaml import YAML
yaml = YAML()
seq = yaml.load("[1,2,3]")
seq[:] = [1, 2, 3, 4]
def test_issue_176_preserve_comments_on_extended_slice_assignment(self):
yaml_str = dedent(
"""\
- a
- b # comment
- c # comment c
# comment c+
- d
- e # comment
"""
)
seq = round_trip_load(yaml_str)
seq[1::2] = ["B", "D"]
res = round_trip_dump(seq)
assert res == yaml_str.replace(" b ", " B ").replace(" d\n", " D\n")
def test_issue_176_test_slicing(self):
from srsly.ruamel_yaml.compat import PY2
mss = round_trip_load("[0, 1, 2, 3, 4]")
assert len(mss) == 5
assert mss[2:2] == []
assert mss[2:4] == [2, 3]
assert mss[1::2] == [1, 3]
# slice assignment
m = mss[:]
m[2:2] = [42]
assert m == [0, 1, 42, 2, 3, 4]
m = mss[:]
m[:3] = [42, 43, 44]
assert m == [42, 43, 44, 3, 4]
m = mss[:]
m[2:] = [42, 43, 44]
assert m == [0, 1, 42, 43, 44]
m = mss[:]
m[:] = [42, 43, 44]
assert m == [42, 43, 44]
# extend slice assignment
m = mss[:]
m[2:4] = [42, 43, 44]
assert m == [0, 1, 42, 43, 44, 4]
m = mss[:]
m[1::2] = [42, 43]
assert m == [0, 42, 2, 43, 4]
m = mss[:]
if PY2:
with pytest.raises(ValueError, match="attempt to assign"):
m[1::2] = [42, 43, 44]
else:
with pytest.raises(TypeError, match="too many"):
m[1::2] = [42, 43, 44]
if PY2:
with pytest.raises(ValueError, match="attempt to assign"):
m[1::2] = [42]
else:
with pytest.raises(TypeError, match="not enough"):
m[1::2] = [42]
m = mss[:]
m += [5]
m[1::2] = [42, 43, 44]
assert m == [0, 42, 2, 43, 4, 44]
# deleting
m = mss[:]
del m[1:3]
assert m == [0, 3, 4]
m = mss[:]
del m[::2]
assert m == [1, 3]
m = mss[:]
del m[:]
assert m == []
def test_issue_184(self):
yaml_str = dedent(
"""\
test::test:
# test
foo:
bar: baz
"""
)
d = round_trip_load(yaml_str)
d["bar"] = "foo"
d.yaml_add_eol_comment("test1", "bar")
assert round_trip_dump(d) == yaml_str + "bar: foo # test1\n"
def test_issue_219(self):
yaml_str = dedent(
"""\
[StackName: AWS::StackName]
"""
)
d = round_trip_load(yaml_str) # NOQA
def test_issue_219a(self):
yaml_str = dedent(
"""\
[StackName:
AWS::StackName]
"""
)
d = round_trip_load(yaml_str) # NOQA
def test_issue_220(self, tmpdir):
program_src = r'''
from srsly.ruamel_yaml import YAML
yaml_str = u"""\
---
foo: ["bar"]
"""
yaml = YAML(typ='safe', pure=True)
d = yaml.load(yaml_str)
print(d)
'''
assert save_and_run(dedent(program_src), tmpdir, optimized=True) == 0
def test_issue_221_add(self):
from srsly.ruamel_yaml.comments import CommentedSeq
a = CommentedSeq([1, 2, 3])
a + [4, 5]
def test_issue_221_sort(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.compat import StringIO
yaml = YAML()
inp = dedent(
"""\
- d
- a # 1
- c # 3
- e # 5
- b # 2
"""
)
a = yaml.load(dedent(inp))
a.sort()
buf = StringIO()
yaml.dump(a, buf)
exp = dedent(
"""\
- a # 1
- b # 2
- c # 3
- d
- e # 5
"""
)
assert buf.getvalue() == exp
def test_issue_221_sort_reverse(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.compat import StringIO
yaml = YAML()
inp = dedent(
"""\
- d
- a # 1
- c # 3
- e # 5
- b # 2
"""
)
a = yaml.load(dedent(inp))
a.sort(reverse=True)
buf = StringIO()
yaml.dump(a, buf)
exp = dedent(
"""\
- e # 5
- d
- c # 3
- b # 2
- a # 1
"""
)
assert buf.getvalue() == exp
def test_issue_221_sort_key(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.compat import StringIO
yaml = YAML()
inp = dedent(
"""\
- four
- One # 1
- Three # 3
- five # 5
- two # 2
"""
)
a = yaml.load(dedent(inp))
a.sort(key=str.lower)
buf = StringIO()
yaml.dump(a, buf)
exp = dedent(
"""\
- five # 5
- four
- One # 1
- Three # 3
- two # 2
"""
)
assert buf.getvalue() == exp
def test_issue_221_sort_key_reverse(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.compat import StringIO
yaml = YAML()
inp = dedent(
"""\
- four
- One # 1
- Three # 3
- five # 5
- two # 2
"""
)
a = yaml.load(dedent(inp))
a.sort(key=str.lower, reverse=True)
buf = StringIO()
yaml.dump(a, buf)
exp = dedent(
"""\
- two # 2
- Three # 3
- One # 1
- four
- five # 5
"""
)
assert buf.getvalue() == exp
def test_issue_222(self):
import srsly.ruamel_yaml
from srsly.ruamel_yaml.compat import StringIO
buf = StringIO()
srsly.ruamel_yaml.safe_dump(["012923"], buf)
assert buf.getvalue() == "['012923']\n"
def test_issue_223(self):
import srsly.ruamel_yaml
yaml = srsly.ruamel_yaml.YAML(typ="safe")
yaml.load("phone: 0123456789")
def test_issue_232(self):
import srsly.ruamel_yaml
import srsly.ruamel_yaml as yaml
with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
yaml.safe_load("]")
with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
yaml.safe_load("{]")
def test_issue_233(self):
from srsly.ruamel_yaml import YAML
import json
yaml = YAML()
data = yaml.load("{}")
json_str = json.dumps(data) # NOQA
def test_issue_233a(self):
from srsly.ruamel_yaml import YAML
import json
yaml = YAML()
data = yaml.load("[]")
json_str = json.dumps(data) # NOQA
def test_issue_234(self):
from srsly.ruamel_yaml import YAML
inp = dedent(
"""\
- key: key1
ctx: [one, two]
help: one
cmd: >
foo bar
foo bar
"""
)
yaml = YAML(typ="safe", pure=True)
data = yaml.load(inp)
fold = data[0]["cmd"]
print(repr(fold))
assert "\a" not in fold
def test_issue_236(self):
inp = """
conf:
xx: {a: "b", c: []}
asd: "nn"
"""
d = round_trip(inp, preserve_quotes=True) # NOQA
def test_issue_238(self, tmpdir):
program_src = r"""
import srsly.ruamel_yaml
from srsly.ruamel_yaml.compat import StringIO
yaml = srsly.ruamel_yaml.YAML(typ='unsafe')
class A:
def __setstate__(self, d):
self.__dict__ = d
class B:
pass
a = A()
b = B()
a.x = b
b.y = [b]
assert a.x.y[0] == a.x
buf = StringIO()
yaml.dump(a, buf)
data = yaml.load(buf.getvalue())
assert data.x.y[0] == data.x
"""
assert save_and_run(dedent(program_src), tmpdir) == 1
def test_issue_239(self):
inp = """
first_name: Art
occupation: Architect
# I'm safe
about: Art Vandelay is a fictional character that George invents...
# we are not :(
# help me!
---
# what?!
hello: world
# someone call the Batman
foo: bar # or quz
# Lost again
---
I: knew
# final words
"""
d = YAML().round_trip_all(inp) # NOQA
def test_issue_242(self):
from srsly.ruamel_yaml.comments import CommentedMap
d0 = CommentedMap([("a", "b")])
assert d0["a"] == "b"
def test_issue_245(self):
from srsly.ruamel_yaml import YAML
inp = """
d: yes
"""
for typ in ["safepure", "rt", "safe"]:
if typ.endswith("pure"):
pure = True
typ = typ[:-4]
else:
pure = None
yaml = YAML(typ=typ, pure=pure)
yaml.version = (1, 1)
d = yaml.load(inp)
print(typ, yaml.parser, yaml.resolver)
assert d["d"] is True
def test_issue_249(self):
yaml = YAML()
inp = dedent(
"""\
# comment
-
- 1
- 2
- 3
"""
)
exp = dedent(
"""\
# comment
- - 1
- 2
- 3
"""
)
yaml.round_trip(inp, outp=exp) # NOQA
def test_issue_250(self):
inp = """
# 1.
- - 1
# 2.
- map: 2
# 3.
- 4
"""
d = round_trip(inp) # NOQA
# @pytest.mark.xfail(strict=True, reason='bla bla', raises=AssertionError)
def test_issue_279(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.compat import StringIO
yaml = YAML()
yaml.indent(sequence=4, offset=2)
inp = dedent(
"""\
experiments:
- datasets:
# ATLAS EWK
- {dataset: ATLASWZRAP36PB, frac: 1.0}
- {dataset: ATLASZHIGHMASS49FB, frac: 1.0}
"""
)
a = yaml.load(inp)
buf = StringIO()
yaml.dump(a, buf)
print(buf.getvalue())
assert buf.getvalue() == inp
def test_issue_280(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.representer import RepresenterError
from collections import namedtuple
from sys import stdout
T = namedtuple("T", ("a", "b"))
t = T(1, 2)
yaml = YAML()
with pytest.raises(RepresenterError, match="cannot represent"):
yaml.dump({"t": t}, stdout)
def test_issue_282(self):
# update from list of tuples caused AttributeError
import srsly.ruamel_yaml
yaml_data = srsly.ruamel_yaml.comments.CommentedMap(
[("a", "apple"), ("b", "banana")]
)
yaml_data.update([("c", "cantaloupe")])
yaml_data.update({"d": "date", "k": "kiwi"})
assert "c" in yaml_data.keys()
assert "c" in yaml_data._ok
def test_issue_284(self):
import srsly.ruamel_yaml
inp = dedent(
"""\
plain key: in-line value
: # Both empty
"quoted key":
- entry
"""
)
yaml = srsly.ruamel_yaml.YAML(typ="rt")
yaml.version = (1, 2)
d = yaml.load(inp)
assert d[None] is None
yaml = srsly.ruamel_yaml.YAML(typ="rt")
yaml.version = (1, 1)
with pytest.raises(
srsly.ruamel_yaml.parser.ParserError, match="expected <block end>"
):
d = yaml.load(inp)
def test_issue_285(self):
from srsly.ruamel_yaml import YAML
yaml = YAML()
inp = dedent(
"""\
%YAML 1.1
---
- y
- n
- Y
- N
"""
)
a = yaml.load(inp)
assert a[0]
assert a[2]
assert not a[1]
assert not a[3]
def test_issue_286(self):
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.compat import StringIO
yaml = YAML()
inp = dedent(
"""\
parent_key:
- sub_key: sub_value
# xxx"""
)
a = yaml.load(inp)
a["new_key"] = "new_value"
buf = StringIO()
yaml.dump(a, buf)
assert buf.getvalue().endswith("xxx\nnew_key: new_value\n")
def test_issue_288(self):
import sys
from srsly.ruamel_yaml.compat import StringIO
from srsly.ruamel_yaml import YAML
yamldoc = dedent(
"""\
---
# Reusable values
aliases:
# First-element comment
- &firstEntry First entry
# Second-element comment
- &secondEntry Second entry
# Third-element comment is
# a multi-line value
- &thirdEntry Third entry
# EOF Comment
"""
)
yaml = YAML()
yaml.indent(mapping=2, sequence=4, offset=2)
yaml.explicit_start = True
yaml.preserve_quotes = True
yaml.width = sys.maxsize
data = yaml.load(yamldoc)
buf = StringIO()
yaml.dump(data, buf)
assert buf.getvalue() == yamldoc
def test_issue_288a(self):
import sys
from srsly.ruamel_yaml.compat import StringIO
from srsly.ruamel_yaml import YAML
yamldoc = dedent(
"""\
---
# Reusable values
aliases:
# First-element comment
- &firstEntry First entry
# Second-element comment
- &secondEntry Second entry
# Third-element comment is
# a multi-line value
- &thirdEntry Third entry
# EOF Comment
"""
)
yaml = YAML()
yaml.indent(mapping=2, sequence=4, offset=2)
yaml.explicit_start = True
yaml.preserve_quotes = True
yaml.width = sys.maxsize
data = yaml.load(yamldoc)
buf = StringIO()
yaml.dump(data, buf)
assert buf.getvalue() == yamldoc
def test_issue_290(self):
import sys
from srsly.ruamel_yaml.compat import StringIO
from srsly.ruamel_yaml import YAML
yamldoc = dedent(
"""\
---
aliases:
# Folded-element comment
# for a multi-line value
- &FoldedEntry >
THIS IS A
FOLDED, MULTI-LINE
VALUE
# Literal-element comment
# for a multi-line value
- &literalEntry |
THIS IS A
LITERAL, MULTI-LINE
VALUE
# Plain-element comment
- &plainEntry Plain entry
"""
)
yaml = YAML()
yaml.indent(mapping=2, sequence=4, offset=2)
yaml.explicit_start = True
yaml.preserve_quotes = True
yaml.width = sys.maxsize
data = yaml.load(yamldoc)
buf = StringIO()
yaml.dump(data, buf)
assert buf.getvalue() == yamldoc
def test_issue_290a(self):
import sys
from srsly.ruamel_yaml.compat import StringIO
from srsly.ruamel_yaml import YAML
yamldoc = dedent(
"""\
---
aliases:
# Folded-element comment
# for a multi-line value
- &FoldedEntry >
THIS IS A
FOLDED, MULTI-LINE
VALUE
# Literal-element comment
# for a multi-line value
- &literalEntry |
THIS IS A
LITERAL, MULTI-LINE
VALUE
# Plain-element comment
- &plainEntry Plain entry
"""
)
yaml = YAML()
yaml.indent(mapping=2, sequence=4, offset=2)
yaml.explicit_start = True
yaml.preserve_quotes = True
yaml.width = sys.maxsize
data = yaml.load(yamldoc)
buf = StringIO()
yaml.dump(data, buf)
assert buf.getvalue() == yamldoc
# @pytest.mark.xfail(strict=True, reason='should fail pre 0.15.100', raises=AssertionError)
def test_issue_295(self):
# deepcopy also makes a copy of the start and end mark, and these did not
# have any comparison beyond their ID, which of course changed, breaking
# some old merge_comment code
import copy
inp = dedent(
"""
A:
b:
# comment
- l1
- l2
C:
d: e
f:
# comment2
- - l31
- l32
- l33: '5'
"""
)
data = round_trip_load(inp) # NOQA
dc = copy.deepcopy(data)
assert round_trip_dump(dc) == inp
def test_issue_300(self):
from srsly.ruamel_yaml import YAML
inp = dedent(
"""
%YAML 1.2
%TAG ! tag:example.com,2019/path#fragment
---
null
"""
)
YAML().load(inp)
def test_issue_300a(self):
import srsly.ruamel_yaml
inp = dedent(
"""
%YAML 1.1
%TAG ! tag:example.com,2019/path#fragment
---
null
"""
)
yaml = YAML()
with pytest.raises(
srsly.ruamel_yaml.scanner.ScannerError, match="while scanning a directive"
):
yaml.load(inp)
def test_issue_304(self):
inp = """
%YAML 1.2
%TAG ! tag:example.com,2019:
---
!foo null
...
"""
d = na_round_trip(inp) # NOQA
def test_issue_305(self):
inp = """
%YAML 1.2
---
! null
...
"""
d = na_round_trip(inp) # NOQA
def test_issue_307(self):
inp = """
%YAML 1.2
%TAG ! tag:example.com,2019/path#
---
null
...
"""
d = na_round_trip(inp) # NOQA
# @pytest.mark.xfail(strict=True, reason='bla bla', raises=AssertionError)
# def test_issue_ xxx(self):
# inp = """
# """
# d = round_trip(inp) # NOQA
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_json_numbers.py 0000775 0000000 0000000 00000002654 14742310675 0025464 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
import pytest # NOQA
import json
def load(s, typ=float):
import srsly.ruamel_yaml
x = '{"low": %s }' % (s)
print("input: [%s]" % (s), repr(x))
# just to check it is loadable json
res = json.loads(x)
assert isinstance(res["low"], typ)
ret_val = srsly.ruamel_yaml.load(x, srsly.ruamel_yaml.RoundTripLoader)
print(ret_val)
return ret_val["low"]
class TestJSONNumbers:
# based on http://stackoverflow.com/a/30462009/1307905
# yaml number regex: http://yaml.org/spec/1.2/spec.html#id2804092
#
# -? [1-9] ( \. [0-9]* [1-9] )? ( e [-+] [1-9] [0-9]* )?
#
# which is not a superset of the JSON numbers
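    # Illustrative note, not part of the original file: of the samples below,
    # "1.0" and "3.1e5" are valid JSON numbers yet do not match the spec regex
    # quoted above, which is why the loader must accept more than the strict
    # YAML 1.2 pattern.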
def test_json_number_float(self):
for x in (
y.split("#")[0].strip()
for y in """
1.0 # should fail on YAML spec on 1-9 allowed as single digit
-1.0
1e-06
3.1e-5
3.1e+5
3.1e5 # should fail on YAML spec: no +- after e
""".splitlines()
):
if not x:
continue
res = load(x)
assert isinstance(res, float)
def test_json_number_int(self):
for x in (
y.split("#")[0].strip()
for y in """
42
""".splitlines()
):
if not x:
continue
res = load(x, int)
assert isinstance(res, int)
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_line_col.py 0000775 0000000 0000000 00000004072 14742310675 0024540 0 ustar 00root root 0000000 0000000 # coding: utf-8
import pytest # NOQA
from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
def load(s):
return round_trip_load(dedent(s))
class TestLineCol:
def test_item_00(self):
data = load(
"""
- a
- e
- [b, d]
- c
"""
)
assert data[2].lc.line == 2
assert data[2].lc.col == 2
def test_item_01(self):
data = load(
"""
- a
- e
- {x: 3}
- c
"""
)
assert data[2].lc.line == 2
assert data[2].lc.col == 2
def test_item_02(self):
data = load(
"""
- a
- e
- !!set {x, y}
- c
"""
)
assert data[2].lc.line == 2
assert data[2].lc.col == 2
def test_item_03(self):
data = load(
"""
- a
- e
- !!omap
- x: 1
- y: 3
- c
"""
)
assert data[2].lc.line == 2
assert data[2].lc.col == 2
def test_item_04(self):
data = load(
"""
# testing line and column based on SO
# http://stackoverflow.com/questions/13319067/
- key1: item 1
key2: item 2
- key3: another item 1
key4: another item 2
"""
)
assert data[0].lc.line == 2
assert data[0].lc.col == 2
assert data[1].lc.line == 4
assert data[1].lc.col == 2
def test_pos_mapping(self):
data = load(
"""
a: 1
b: 2
c: 3
# comment
klm: 42
d: 4
"""
)
assert data.lc.key("klm") == (4, 0)
assert data.lc.value("klm") == (4, 5)
def test_pos_sequence(self):
data = load(
"""
- a
- b
- c
# next one!
- klm
- d
"""
)
assert data.lc.item(3) == (4, 2)
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_literal.py 0000775 0000000 0000000 00000017153 14742310675 0024414 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
import pytest # NOQA
from .roundtrip import YAML # does an automatic dedent on load
"""
YAML 1.0 allowed root level literal style without indentation:
"Usually top level nodes are not indented" (example 4.21 in 4.6.3)
YAML 1.1 is a bit vague but says:
"Regardless of style, scalar content must always be indented by at least one space"
(4.4.3)
"In general, the document’s node is indented as if it has a parent indented at -1 spaces."
(4.3.3)
YAML 1.2 is again clear about root literal level scalar after directive in example 9.5:
%YAML 1.2
--- |
%!PS-Adobe-2.0
...
%YAML1.2
---
# Empty
...
"""
class TestNoIndent:
def test_root_literal_scalar_indent_example_9_5(self):
yaml = YAML()
s = "%!PS-Adobe-2.0"
inp = """
--- |
{}
"""
d = yaml.load(inp.format(s))
print(d)
assert d == s + "\n"
def test_root_literal_scalar_no_indent(self):
yaml = YAML()
s = "testing123"
inp = """
--- |
{}
"""
d = yaml.load(inp.format(s))
print(d)
assert d == s + "\n"
def test_root_literal_scalar_no_indent_1_1(self):
yaml = YAML()
s = "testing123"
inp = """
%YAML 1.1
--- |
{}
"""
d = yaml.load(inp.format(s))
print(d)
assert d == s + "\n"
def test_root_literal_scalar_no_indent_1_1_old_style(self):
from textwrap import dedent
from srsly.ruamel_yaml import safe_load
s = "testing123"
inp = """
%YAML 1.1
--- |
{}
"""
d = safe_load(dedent(inp.format(s)))
print(d)
assert d == s + "\n"
def test_root_literal_scalar_no_indent_1_1_no_raise(self):
# from srsly.ruamel_yaml.parser import ParserError
yaml = YAML()
yaml.root_level_block_style_scalar_no_indent_error_1_1 = True
s = "testing123"
# with pytest.raises(ParserError):
if True:
inp = """
%YAML 1.1
--- |
{}
"""
yaml.load(inp.format(s))
def test_root_literal_scalar_indent_offset_one(self):
yaml = YAML()
s = "testing123"
inp = """
--- |1
{}
"""
d = yaml.load(inp.format(s))
print(d)
assert d == s + "\n"
def test_root_literal_scalar_indent_offset_four(self):
yaml = YAML()
s = "testing123"
inp = """
--- |4
{}
"""
d = yaml.load(inp.format(s))
print(d)
assert d == s + "\n"
def test_root_literal_scalar_indent_offset_two_leading_space(self):
yaml = YAML()
s = " testing123"
inp = """
--- |4
{s}
{s}
"""
d = yaml.load(inp.format(s=s))
print(d)
assert d == (s + "\n") * 2
def test_root_literal_scalar_no_indent_special(self):
yaml = YAML()
s = "%!PS-Adobe-2.0"
inp = """
--- |
{}
"""
d = yaml.load(inp.format(s))
print(d)
assert d == s + "\n"
def test_root_folding_scalar_indent(self):
yaml = YAML()
s = "%!PS-Adobe-2.0"
inp = """
--- >
{}
"""
d = yaml.load(inp.format(s))
print(d)
assert d == s + "\n"
def test_root_folding_scalar_no_indent(self):
yaml = YAML()
s = "testing123"
inp = """
--- >
{}
"""
d = yaml.load(inp.format(s))
print(d)
assert d == s + "\n"
def test_root_folding_scalar_no_indent_special(self):
yaml = YAML()
s = "%!PS-Adobe-2.0"
inp = """
--- >
{}
"""
d = yaml.load(inp.format(s))
print(d)
assert d == s + "\n"
def test_root_literal_multi_doc(self):
yaml = YAML(typ="safe", pure=True)
s1 = "abc"
s2 = "klm"
inp = """
--- |-
{}
--- |
{}
"""
for idx, d1 in enumerate(yaml.load_all(inp.format(s1, s2))):
print("d1:", d1)
assert ["abc", "klm\n"][idx] == d1
def test_root_literal_doc_indent_directives_end(self):
yaml = YAML()
yaml.explicit_start = True
inp = """
--- |-
%YAML 1.3
---
this: is a test
"""
yaml.round_trip(inp)
def test_root_literal_doc_indent_document_end(self):
yaml = YAML()
yaml.explicit_start = True
inp = """
--- |-
some more
...
text
"""
yaml.round_trip(inp)
def test_root_literal_doc_indent_marker(self):
yaml = YAML()
yaml.explicit_start = True
inp = """
--- |2
some more
text
"""
d = yaml.load(inp)
print(type(d), repr(d))
yaml.round_trip(inp)
def test_nested_literal_doc_indent_marker(self):
yaml = YAML()
yaml.explicit_start = True
inp = """
---
a: |2
some more
text
"""
d = yaml.load(inp)
print(type(d), repr(d))
yaml.round_trip(inp)
class Test_RoundTripLiteral:
def test_rt_root_literal_scalar_no_indent(self):
yaml = YAML()
yaml.explicit_start = True
s = "testing123"
ys = """
--- |
{}
"""
ys = ys.format(s)
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_rt_root_literal_scalar_indent(self):
yaml = YAML()
yaml.explicit_start = True
yaml.indent = 4
s = "testing123"
ys = """
--- |
{}
"""
ys = ys.format(s)
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_rt_root_plain_scalar_no_indent(self):
yaml = YAML()
yaml.explicit_start = True
yaml.indent = 0
s = "testing123"
ys = """
---
{}
"""
ys = ys.format(s)
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_rt_root_plain_scalar_expl_indent(self):
yaml = YAML()
yaml.explicit_start = True
yaml.indent = 4
s = "testing123"
ys = """
---
{}
"""
ys = ys.format(s)
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_rt_root_sq_scalar_expl_indent(self):
yaml = YAML()
yaml.explicit_start = True
yaml.indent = 4
s = "'testing: 123'"
ys = """
---
{}
"""
ys = ys.format(s)
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_rt_root_dq_scalar_expl_indent(self):
# if yaml.indent is the default (None)
# then write after the directive indicator
yaml = YAML()
yaml.explicit_start = True
yaml.indent = 0
s = '"\'testing123"'
ys = """
---
{}
"""
ys = ys.format(s)
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_rt_root_literal_scalar_no_indent_no_eol(self):
yaml = YAML()
yaml.explicit_start = True
s = "testing123"
ys = """
--- |-
{}
"""
ys = ys.format(s)
d = yaml.load(ys)
yaml.dump(d, compare=ys)
def test_rt_non_root_literal_scalar(self):
yaml = YAML()
s = "testing123"
ys = """
- |
{}
"""
ys = ys.format(s)
d = yaml.load(ys)
yaml.dump(d, compare=ys)
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_none.py 0000775 0000000 0000000 00000002624 14742310675 0023714 0 ustar 00root root 0000000 0000000 # coding: utf-8
import pytest # NOQA
class TestNone:
def test_dump00(self):
import srsly.ruamel_yaml # NOQA
data = None
s = srsly.ruamel_yaml.round_trip_dump(data)
assert s == "null\n...\n"
d = srsly.ruamel_yaml.round_trip_load(s)
assert d == data
def test_dump01(self):
import srsly.ruamel_yaml # NOQA
data = None
s = srsly.ruamel_yaml.round_trip_dump(data, explicit_end=True)
assert s == "null\n...\n"
d = srsly.ruamel_yaml.round_trip_load(s)
assert d == data
def test_dump02(self):
import srsly.ruamel_yaml # NOQA
data = None
s = srsly.ruamel_yaml.round_trip_dump(data, explicit_end=False)
assert s == "null\n...\n"
d = srsly.ruamel_yaml.round_trip_load(s)
assert d == data
def test_dump03(self):
import srsly.ruamel_yaml # NOQA
data = None
s = srsly.ruamel_yaml.round_trip_dump(data, explicit_start=True)
assert s == "---\n...\n"
d = srsly.ruamel_yaml.round_trip_load(s)
assert d == data
def test_dump04(self):
import srsly.ruamel_yaml # NOQA
data = None
s = srsly.ruamel_yaml.round_trip_dump(
data, explicit_start=True, explicit_end=False
)
assert s == "---\n...\n"
d = srsly.ruamel_yaml.round_trip_load(s)
assert d == data
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_numpy.py 0000775 0000000 0000000 00000000750 14742310675 0024123 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, absolute_import, division, unicode_literals
try:
import numpy
except: # NOQA
numpy = None
def Xtest_numpy():
import srsly.ruamel_yaml
if numpy is None:
return
data = numpy.arange(10)
print("data", type(data), data)
yaml_str = srsly.ruamel_yaml.dump(data)
datb = srsly.ruamel_yaml.load(yaml_str)
print("datb", type(datb), datb)
print("\nYAML", yaml_str)
assert data == datb
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_program_config.py 0000775 0000000 0000000 00000003515 14742310675 0025751 0 ustar 00root root 0000000 0000000 import pytest # NOQA
# import srsly.ruamel_yaml
from .roundtrip import round_trip
class TestProgramConfig:
def test_application_arguments(self):
# application configuration
round_trip(
"""
args:
username: anthon
passwd: secret
fullname: Anthon van der Neut
tmux:
session-name: test
loop:
wait: 10
"""
)
def test_single(self):
# application configuration
round_trip(
"""
# default arguments for the program
args: # needed to prevent comment wrapping
# this should be your username
username: anthon
passwd: secret # this is plaintext don't reuse \
# important/system passwords
fullname: Anthon van der Neut
tmux:
session-name: test # make sure this doesn't clash with
# other sessions
loop: # looping related defaults
# experiment with the following
wait: 10
# no more argument info to pass
"""
)
def test_multi(self):
# application configuration
round_trip(
"""
# default arguments for the program
args: # needed to prevent comment wrapping
# this should be your username
username: anthon
passwd: secret # this is plaintext don't reuse
# important/system passwords
fullname: Anthon van der Neut
tmux:
session-name: test # make sure this doesn't clash with
# other sessions
loop: # looping related defaults
# experiment with the following
wait: 10
# no more argument info to pass
"""
)
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_spec_examples.py 0000775 0000000 0000000 00000013370 14742310675 0025605 0 ustar 00root root 0000000 0000000 from .roundtrip import YAML
import pytest # NOQA
def test_example_2_1():
yaml = YAML()
yaml.round_trip(
"""
- Mark McGwire
- Sammy Sosa
- Ken Griffey
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_2():
yaml = YAML()
yaml.mapping_value_align = True
yaml.round_trip(
"""
hr: 65 # Home runs
avg: 0.278 # Batting average
rbi: 147 # Runs Batted In
"""
)
def test_example_2_3():
yaml = YAML()
yaml.indent(sequence=4, offset=2)
yaml.round_trip(
"""
american:
- Boston Red Sox
- Detroit Tigers
- New York Yankees
national:
- New York Mets
- Chicago Cubs
- Atlanta Braves
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_4():
yaml = YAML()
yaml.mapping_value_align = True
yaml.round_trip(
"""
-
name: Mark McGwire
hr: 65
avg: 0.278
-
name: Sammy Sosa
hr: 63
avg: 0.288
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_5():
yaml = YAML()
yaml.flow_sequence_element_align = True
yaml.round_trip(
"""
- [name , hr, avg ]
- [Mark McGwire, 65, 0.278]
- [Sammy Sosa , 63, 0.288]
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_6():
yaml = YAML()
# yaml.flow_mapping_final_comma = False
yaml.flow_mapping_one_element_per_line = True
yaml.round_trip(
"""
Mark McGwire: {hr: 65, avg: 0.278}
Sammy Sosa: {
hr: 63,
avg: 0.288
}
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_7():
yaml = YAML()
yaml.round_trip_all(
"""
# Ranking of 1998 home runs
---
- Mark McGwire
- Sammy Sosa
- Ken Griffey
# Team ranking
---
- Chicago Cubs
- St Louis Cardinals
"""
)
def test_example_2_8():
yaml = YAML()
yaml.explicit_start = True
yaml.explicit_end = True
yaml.round_trip_all(
"""
---
time: 20:03:20
player: Sammy Sosa
action: strike (miss)
...
---
time: 20:03:47
player: Sammy Sosa
action: grand slam
...
"""
)
def test_example_2_9():
yaml = YAML()
yaml.explicit_start = True
yaml.indent(sequence=4, offset=2)
yaml.round_trip(
"""
---
hr: # 1998 hr ranking
- Mark McGwire
- Sammy Sosa
rbi:
# 1998 rbi ranking
- Sammy Sosa
- Ken Griffey
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_10():
yaml = YAML()
yaml.explicit_start = True
yaml.indent(sequence=4, offset=2)
yaml.round_trip(
"""
---
hr:
- Mark McGwire
# Following node labeled SS
- &SS Sammy Sosa
rbi:
- *SS # Subsequent occurrence
- Ken Griffey
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_11():
yaml = YAML()
yaml.round_trip(
"""
? - Detroit Tigers
- Chicago cubs
:
- 2001-07-23
? [ New York Yankees,
Atlanta Braves ]
: [ 2001-07-02, 2001-08-12,
2001-08-14 ]
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_12():
yaml = YAML()
yaml.explicit_start = True
yaml.round_trip(
"""
---
# Products purchased
- item : Super Hoop
quantity: 1
- item : Basketball
quantity: 4
- item : Big Shoes
quantity: 1
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_13():
yaml = YAML()
yaml.round_trip(
r"""
# ASCII Art
--- |
\//||\/||
// || ||__
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_14():
yaml = YAML()
yaml.explicit_start = True
yaml.indent(root_scalar=2) # needs to be added
yaml.round_trip(
"""
--- >
Mark McGwire's
year was crippled
by a knee injury.
"""
)
@pytest.mark.xfail(strict=True)
def test_example_2_15():
yaml = YAML()
yaml.round_trip(
"""
>
Sammy Sosa completed another
fine season with great stats.
63 Home Runs
0.288 Batting Average
What a year!
"""
)
def test_example_2_16():
yaml = YAML()
yaml.round_trip(
"""
name: Mark McGwire
accomplishment: >
Mark set a major league
home run record in 1998.
stats: |
65 Home Runs
0.278 Batting Average
"""
)
@pytest.mark.xfail(
strict=True, reason="cannot YAML dump escape sequences (\n) as hex and normal"
)
def test_example_2_17():
yaml = YAML()
yaml.allow_unicode = False
yaml.preserve_quotes = True
yaml.round_trip(
r"""
unicode: "Sosa did fine.\u263A"
control: "\b1998\t1999\t2000\n"
hex esc: "\x0d\x0a is \r\n"
single: '"Howdy!" he cried.'
quoted: ' # Not a ''comment''.'
tie-fighter: '|\-*-/|'
"""
)
@pytest.mark.xfail(
strict=True, reason="non-literal/folding multiline scalars not supported"
)
def test_example_2_18():
yaml = YAML()
yaml.round_trip(
"""
plain:
This unquoted scalar
spans many lines.
quoted: "So does this
quoted scalar.\n"
"""
)
@pytest.mark.xfail(strict=True, reason="leading + on decimal dropped")
def test_example_2_19():
yaml = YAML()
yaml.round_trip(
"""
canonical: 12345
decimal: +12345
octal: 0o14
hexadecimal: 0xC
"""
)
@pytest.mark.xfail(strict=True, reason="case of NaN not preserved")
def test_example_2_20():
yaml = YAML()
yaml.round_trip(
"""
canonical: 1.23015e+3
exponential: 12.3015e+02
fixed: 1230.15
negative infinity: -.inf
not a number: .NaN
"""
)
def Xtest_example_2_X():
yaml = YAML()
yaml.round_trip(
"""
"""
)
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_string.py 0000775 0000000 0000000 00000012743 14742310675 0024266 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
"""
various test cases for string scalars in YAML files
'|' for preserved newlines
'>' for folded (newlines become spaces)
and the chomping modifiers:
'-' for stripping: final line break and any trailing empty lines are excluded
'+' for keeping: final line break and empty lines are preserved
'' for clipping: final line break preserved, empty lines at end not
included in content (no modifier)
"""
import pytest
import platform
# from srsly.ruamel_yaml.compat import ordereddict
from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
class TestLiteralScalarString:
def test_basic_string(self):
round_trip(
"""
a: abcdefg
"""
)
def test_quoted_integer_string(self):
round_trip(
"""
a: '12345'
"""
)
@pytest.mark.skipif(
platform.python_implementation() == "Jython",
reason="Jython throws RepresenterError",
)
def test_preserve_string(self):
inp = """
a: |
abc
def
"""
round_trip(inp, intermediate=dict(a="abc\ndef\n"))
@pytest.mark.skipif(
platform.python_implementation() == "Jython",
reason="Jython throws RepresenterError",
)
def test_preserve_string_strip(self):
s = """
a: |-
abc
def
"""
round_trip(s, intermediate=dict(a="abc\ndef"))
@pytest.mark.skipif(
platform.python_implementation() == "Jython",
reason="Jython throws RepresenterError",
)
def test_preserve_string_keep(self):
# with pytest.raises(AssertionError) as excinfo:
inp = """
a: |+
ghi
jkl
b: x
"""
round_trip(inp, intermediate=dict(a="ghi\njkl\n\n\n", b="x"))
@pytest.mark.skipif(
platform.python_implementation() == "Jython",
reason="Jython throws RepresenterError",
)
def test_preserve_string_keep_at_end(self):
# at EOF you have to specify the ... to get proper "closure"
# of the multiline scalar
inp = """
a: |+
ghi
jkl
...
"""
round_trip(inp, intermediate=dict(a="ghi\njkl\n\n"))
def test_fold_string(self):
inp = """
a: >
abc
def
"""
round_trip(inp)
def test_fold_string_strip(self):
inp = """
a: >-
abc
def
"""
round_trip(inp)
def test_fold_string_keep(self):
with pytest.raises(AssertionError) as excinfo: # NOQA
inp = """
a: >+
abc
def
"""
round_trip(inp, intermediate=dict(a="abc def\n\n"))
class TestQuotedScalarString:
def test_single_quoted_string(self):
inp = """
a: 'abc'
"""
round_trip(inp, preserve_quotes=True)
def test_double_quoted_string(self):
inp = """
a: "abc"
"""
round_trip(inp, preserve_quotes=True)
def test_non_preserved_double_quoted_string(self):
inp = """
a: "abc"
"""
exp = """
a: abc
"""
round_trip(inp, outp=exp)
class TestReplace:
"""inspired by issue 110 from sandres23"""
def test_replace_preserved_scalar_string(self):
import srsly
s = dedent(
"""\
foo: |
foo
foo
bar
foo
"""
)
data = round_trip_load(s, preserve_quotes=True)
so = data["foo"].replace("foo", "bar", 2)
assert isinstance(so, srsly.ruamel_yaml.scalarstring.LiteralScalarString)
assert so == dedent(
"""
bar
bar
bar
foo
"""
)
def test_replace_double_quoted_scalar_string(self):
import srsly
s = dedent(
"""\
foo: "foo foo bar foo"
"""
)
data = round_trip_load(s, preserve_quotes=True)
so = data["foo"].replace("foo", "bar", 2)
assert isinstance(so, srsly.ruamel_yaml.scalarstring.DoubleQuotedScalarString)
assert so == "bar bar bar foo"
class TestWalkTree:
def test_basic(self):
from srsly.ruamel_yaml.comments import CommentedMap
from srsly.ruamel_yaml.scalarstring import walk_tree
data = CommentedMap()
data[1] = "a"
data[2] = "with\nnewline\n"
walk_tree(data)
exp = """\
1: a
2: |
with
newline
"""
assert round_trip_dump(data) == dedent(exp)
def test_map(self):
from srsly.ruamel_yaml.compat import ordereddict
from srsly.ruamel_yaml.comments import CommentedMap
from srsly.ruamel_yaml.scalarstring import walk_tree, preserve_literal
from srsly.ruamel_yaml.scalarstring import DoubleQuotedScalarString as dq
from srsly.ruamel_yaml.scalarstring import SingleQuotedScalarString as sq
data = CommentedMap()
data[1] = "a"
data[2] = "with\nnew : line\n"
data[3] = "${abc}"
data[4] = "almost:mapping"
m = ordereddict([("\n", preserve_literal), ("${", sq), (":", dq)])
walk_tree(data, map=m)
exp = """\
1: a
2: |
with
new : line
3: '${abc}'
4: "almost:mapping"
"""
assert round_trip_dump(data) == dedent(exp)
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_tag.py 0000775 0000000 0000000 00000007043 14742310675 0023530 0 ustar 00root root 0000000 0000000 # coding: utf-8
import pytest # NOQA
from .roundtrip import round_trip, round_trip_load, YAML
def register_xxx(**kw):
import srsly.ruamel_yaml as yaml
class XXX(yaml.comments.CommentedMap):
@staticmethod
def yaml_dump(dumper, data):
return dumper.represent_mapping(u"!xxx", data)
@classmethod
def yaml_load(cls, constructor, node):
data = cls()
yield data
constructor.construct_mapping(node, data)
yaml.add_constructor(u"!xxx", XXX.yaml_load, constructor=yaml.RoundTripConstructor)
yaml.add_representer(XXX, XXX.yaml_dump, representer=yaml.RoundTripRepresenter)
class TestIndentFailures:
def test_tag(self):
round_trip(
"""\
!!python/object:__main__.Developer
name: Anthon
location: Germany
language: python
"""
)
def test_full_tag(self):
round_trip(
"""\
!!tag:yaml.org,2002:python/object:__main__.Developer
name: Anthon
location: Germany
language: python
"""
)
def test_standard_tag(self):
round_trip(
"""\
!!tag:yaml.org,2002:python/object:map
name: Anthon
location: Germany
language: python
"""
)
def test_Y1(self):
round_trip(
"""\
!yyy
name: Anthon
location: Germany
language: python
"""
)
def test_Y2(self):
round_trip(
"""\
!!yyy
name: Anthon
location: Germany
language: python
"""
)
class TestRoundTripCustom:
def test_X1(self):
register_xxx()
round_trip(
"""\
!xxx
name: Anthon
location: Germany
language: python
"""
)
@pytest.mark.xfail(strict=True)
def test_X_pre_tag_comment(self):
register_xxx()
round_trip(
"""\
-
# hello
!xxx
name: Anthon
location: Germany
language: python
"""
)
@pytest.mark.xfail(strict=True)
def test_X_post_tag_comment(self):
register_xxx()
round_trip(
"""\
- !xxx
# hello
name: Anthon
location: Germany
language: python
"""
)
def test_scalar_00(self):
# https://stackoverflow.com/a/45967047/1307905
round_trip(
"""\
Outputs:
Vpc:
Value: !Ref: vpc # first tag
Export:
Name: !Sub "${AWS::StackName}-Vpc" # second tag
"""
)
class TestIssue201:
def test_encoded_unicode_tag(self):
round_trip_load(
"""
s: !!python/%75nicode 'abc'
"""
)
class TestImplicitTaggedNodes:
def test_scalar(self):
round_trip(
"""\
- !Scalar abcdefg
"""
)
def test_mapping(self):
round_trip(
"""\
- !Mapping {a: 1, b: 2}
"""
)
def test_sequence(self):
yaml = YAML()
yaml.brace_single_entry_mapping_in_flow_sequence = True
yaml.mapping_value_align = True
yaml.round_trip(
"""
- !Sequence [a, {b: 1}, {c: {d: 3}}]
"""
)
def test_sequence2(self):
yaml = YAML()
yaml.mapping_value_align = True
yaml.round_trip(
"""
- !Sequence [a, b: 1, c: {d: 3}]
"""
)
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_version.py 0000775 0000000 0000000 00000010367 14742310675 0024445 0 ustar 00root root 0000000 0000000 # coding: utf-8
import pytest # NOQA
from .roundtrip import dedent, round_trip, round_trip_load
def load(s, version=None):
import srsly.ruamel_yaml # NOQA
return srsly.ruamel_yaml.round_trip_load(dedent(s), version)
class TestVersions:
def test_explicit_1_2(self):
r = load(
"""\
%YAML 1.2
---
- 12:34:56
- 012
- 012345678
- 0o12
- on
- off
- yes
- no
- true
"""
)
assert r[0] == "12:34:56"
assert r[1] == 12
assert r[2] == 12345678
assert r[3] == 10
assert r[4] == "on"
assert r[5] == "off"
assert r[6] == "yes"
assert r[7] == "no"
assert r[8] is True
def test_explicit_1_1(self):
r = load(
"""\
%YAML 1.1
---
- 12:34:56
- 012
- 012345678
- 0o12
- on
- off
- yes
- no
- true
"""
)
assert r[0] == 45296
assert r[1] == 10
assert r[2] == "012345678"
assert r[3] == "0o12"
assert r[4] is True
assert r[5] is False
assert r[6] is True
assert r[7] is False
assert r[8] is True
def test_implicit_1_2(self):
r = load(
"""\
- 12:34:56
- 12:34:56.78
- 012
- 012345678
- 0o12
- on
- off
- yes
- no
- true
"""
)
assert r[0] == "12:34:56"
assert r[1] == "12:34:56.78"
assert r[2] == 12
assert r[3] == 12345678
assert r[4] == 10
assert r[5] == "on"
assert r[6] == "off"
assert r[7] == "yes"
assert r[8] == "no"
assert r[9] is True
def test_load_version_1_1(self):
inp = """\
- 12:34:56
- 12:34:56.78
- 012
- 012345678
- 0o12
- on
- off
- yes
- no
- true
"""
r = load(inp, version="1.1")
assert r[0] == 45296
assert r[1] == 45296.78
assert r[2] == 10
assert r[3] == "012345678"
assert r[4] == "0o12"
assert r[5] is True
assert r[6] is False
assert r[7] is True
assert r[8] is False
assert r[9] is True
class TestIssue62:
# bitbucket issue 62, issue_62
def test_00(self):
import srsly.ruamel_yaml # NOQA
s = dedent(
"""\
{}# Outside flow collection:
- ::vector
- ": - ()"
- Up, up, and away!
- -123
- http://example.com/foo#bar
# Inside flow collection:
- [::vector, ": - ()", "Down, down and away!", -456, http://example.com/foo#bar]
"""
)
with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True)
round_trip(s.format(""), preserve_quotes=True)
def test_00_single_comment(self):
import srsly.ruamel_yaml # NOQA
s = dedent(
"""\
{}# Outside flow collection:
- ::vector
- ": - ()"
- Up, up, and away!
- -123
- http://example.com/foo#bar
- [::vector, ": - ()", "Down, down and away!", -456, http://example.com/foo#bar]
"""
)
with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True)
round_trip(s.format(""), preserve_quotes=True)
# round_trip(s.format('%YAML 1.2\n---\n'), preserve_quotes=True, version=(1, 2))
def test_01(self):
import srsly.ruamel_yaml # NOQA
s = dedent(
"""\
{}[random plain value that contains a ? character]
"""
)
with pytest.raises(srsly.ruamel_yaml.parser.ParserError):
round_trip(s.format("%YAML 1.1\n---\n"), preserve_quotes=True)
round_trip(s.format(""), preserve_quotes=True)
# note the flow seq on the --- line!
round_trip(s.format("%YAML 1.2\n--- "), preserve_quotes=True, version="1.2")
def test_so_45681626(self):
# was not properly parsing
round_trip_load('{"in":{},"out":{}}')
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_yamlfile.py 0000775 0000000 0000000 00000013411 14742310675 0024553 0 ustar 00root root 0000000 0000000 from __future__ import print_function
"""
various test cases for YAML files
"""
import sys
import pytest # NOQA
import platform
from .roundtrip import round_trip, dedent, round_trip_load, round_trip_dump # NOQA
class TestYAML:
def test_backslash(self):
round_trip(
"""
handlers:
static_files: applications/\\1/static/\\2
"""
)
def test_omap_out(self):
# ordereddict mapped to !!omap
from srsly.ruamel_yaml.compat import ordereddict
import srsly.ruamel_yaml # NOQA
x = ordereddict([("a", 1), ("b", 2)])
res = srsly.ruamel_yaml.dump(x, default_flow_style=False)
assert res == dedent(
"""
!!omap
- a: 1
- b: 2
"""
)
def test_omap_roundtrip(self):
round_trip(
"""
!!omap
- a: 1
- b: 2
- c: 3
- d: 4
"""
)
@pytest.mark.skipif(sys.version_info < (2, 7), reason="collections not available")
def test_dump_collections_ordereddict(self):
from collections import OrderedDict
import srsly.ruamel_yaml # NOQA
# OrderedDict mapped to !!omap
x = OrderedDict([("a", 1), ("b", 2)])
res = srsly.ruamel_yaml.dump(
x, Dumper=srsly.ruamel_yaml.RoundTripDumper, default_flow_style=False
)
assert res == dedent(
"""
!!omap
- a: 1
- b: 2
"""
)
@pytest.mark.skipif(
sys.version_info >= (3, 0) or platform.python_implementation() != "CPython",
reason="srsly.ruamel_yaml not available",
)
def test_dump_ruamel_ordereddict(self):
from srsly.ruamel_yaml.compat import ordereddict
import srsly.ruamel_yaml # NOQA
# OrderedDict mapped to !!omap
x = ordereddict([("a", 1), ("b", 2)])
res = srsly.ruamel_yaml.dump(
x, Dumper=srsly.ruamel_yaml.RoundTripDumper, default_flow_style=False
)
assert res == dedent(
"""
!!omap
- a: 1
- b: 2
"""
)
def test_CommentedSet(self):
from srsly.ruamel_yaml.constructor import CommentedSet
s = CommentedSet(["a", "b", "c"])
s.remove("b")
s.add("d")
assert s == CommentedSet(["a", "c", "d"])
s.add("e")
s.add("f")
s.remove("e")
assert s == CommentedSet(["a", "c", "d", "f"])
def test_set_out(self):
# preferable would be the shorter format without the ': null'
import srsly.ruamel_yaml # NOQA
x = set(["a", "b", "c"])
res = srsly.ruamel_yaml.dump(x, default_flow_style=False)
assert res == dedent(
"""
!!set
a: null
b: null
c: null
"""
)
# ordering is not preserved in a set
def test_set_compact(self):
# this format is read and also should be written by default
round_trip(
"""
!!set
? a
? b
? c
"""
)
def test_blank_line_after_comment(self):
round_trip(
"""
# Comment with spaces after it.
a: 1
"""
)
def test_blank_line_between_seq_items(self):
round_trip(
"""
# Seq with empty lines in between items.
b:
- bar
- baz
"""
)
@pytest.mark.skipif(
platform.python_implementation() == "Jython",
reason="Jython throws RepresenterError",
)
def test_blank_line_after_literal_chip(self):
s = """
c:
- |
This item
has a blank line
following it.
- |
To visually separate it from this item.
This item contains a blank line.
"""
d = round_trip_load(dedent(s))
print(d)
round_trip(s)
assert d["c"][0].split("it.")[1] == "\n"
assert d["c"][1].split("line.")[1] == "\n"
@pytest.mark.skipif(
platform.python_implementation() == "Jython",
reason="Jython throws RepresenterError",
)
def test_blank_line_after_literal_keep(self):
""" have to insert an eof marker in YAML to test this"""
s = """
c:
- |+
This item
has a blank line
following it.
- |+
To visually separate it from this item.
This item contains a blank line.
...
"""
d = round_trip_load(dedent(s))
print(d)
round_trip(s)
assert d["c"][0].split("it.")[1] == "\n\n"
assert d["c"][1].split("line.")[1] == "\n\n\n"
@pytest.mark.skipif(
platform.python_implementation() == "Jython",
reason="Jython throws RepresenterError",
)
def test_blank_line_after_literal_strip(self):
s = """
c:
- |-
This item
has a blank line
following it.
- |-
To visually separate it from this item.
This item contains a blank line.
"""
d = round_trip_load(dedent(s))
print(d)
round_trip(s)
assert d["c"][0].split("it.")[1] == ""
assert d["c"][1].split("line.")[1] == ""
def test_load_all_perserve_quotes(self):
import srsly.ruamel_yaml # NOQA
s = dedent(
"""\
a: 'hello'
---
b: "goodbye"
"""
)
data = []
for x in srsly.ruamel_yaml.round_trip_load_all(s, preserve_quotes=True):
data.append(x)
out = srsly.ruamel_yaml.dump_all(data, Dumper=srsly.ruamel_yaml.RoundTripDumper)
print(type(data[0]["a"]), data[0]["a"])
# out = srsly.ruamel_yaml.round_trip_dump_all(data)
print(out)
assert out == s
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_yamlobject.py 0000775 0000000 0000000 00000004721 14742310675 0025106 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function
import sys
import pytest # NOQA
from .roundtrip import save_and_run # NOQA
def test_monster(tmpdir):
program_src = u'''\
import srsly.ruamel_yaml
from textwrap import dedent
class Monster(srsly.ruamel_yaml.YAMLObject):
yaml_tag = u'!Monster'
def __init__(self, name, hp, ac, attacks):
self.name = name
self.hp = hp
self.ac = ac
self.attacks = attacks
def __repr__(self):
return "%s(name=%r, hp=%r, ac=%r, attacks=%r)" % (
self.__class__.__name__, self.name, self.hp, self.ac, self.attacks)
data = srsly.ruamel_yaml.load(dedent("""\\
--- !Monster
name: Cave spider
hp: [2,6] # 2d6
ac: 16
attacks: [BITE, HURT]
"""), Loader=srsly.ruamel_yaml.Loader)
# normal dump, keys will be sorted
assert srsly.ruamel_yaml.dump(data) == dedent("""\\
!Monster
ac: 16
attacks: [BITE, HURT]
hp: [2, 6]
name: Cave spider
""")
'''
assert save_and_run(program_src, tmpdir) == 1
@pytest.mark.skipif(sys.version_info < (3, 0), reason="no __qualname__")
def test_qualified_name00(tmpdir):
"""issue 214"""
program_src = u"""\
from srsly.ruamel_yaml import YAML
from srsly.ruamel_yaml.compat import StringIO
class A:
def f(self):
pass
yaml = YAML(typ='unsafe', pure=True)
yaml.explicit_end = True
buf = StringIO()
yaml.dump(A.f, buf)
res = buf.getvalue()
print('res', repr(res))
assert res == "!!python/name:__main__.A.f ''\\n...\\n"
x = yaml.load(res)
assert x == A.f
"""
assert save_and_run(program_src, tmpdir) == 1
@pytest.mark.skipif(sys.version_info < (3, 0), reason="no __qualname__")
def test_qualified_name01(tmpdir):
"""issue 214"""
from srsly.ruamel_yaml import YAML
import srsly.ruamel_yaml.comments
from srsly.ruamel_yaml.compat import StringIO
with pytest.raises(ValueError):
yaml = YAML(typ="unsafe", pure=True)
yaml.explicit_end = True
buf = StringIO()
yaml.dump(srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor, buf)
res = buf.getvalue()
assert (
res
== "!!python/name:srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor ''\n...\n"
)
x = yaml.load(res)
assert x == srsly.ruamel_yaml.comments.CommentedBase.yaml_anchor
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_z_check_debug_leftovers.py 0000775 0000000 0000000 00000001573 14742310675 0027624 0 ustar 00root root 0000000 0000000 # coding: utf-8
import sys
import pytest # NOQA
from .roundtrip import round_trip_load, round_trip_dump, dedent
class TestLeftOverDebug:
# idea here is to capture round_trip_output via pytest stdout capture
# if there are any leftover debug statements they should show up
def test_00(self, capsys):
s = dedent(
"""
a: 1
b: []
c: [a, 1]
d: {f: 3.14, g: 42}
"""
)
d = round_trip_load(s)
round_trip_dump(d, sys.stdout)
out, err = capsys.readouterr()
assert out == s
def test_01(self, capsys):
s = dedent(
"""
- 1
- []
- [a, 1]
- {f: 3.14, g: 42}
- - 123
"""
)
d = round_trip_load(s)
round_trip_dump(d, sys.stdout)
out, err = capsys.readouterr()
assert out == s
srsly-release-v2.5.1/srsly/tests/ruamel_yaml/test_z_data.py 0000775 0000000 0000000 00000016227 14742310675 0024223 0 ustar 00root root 0000000 0000000 # coding: utf-8
from __future__ import print_function, unicode_literals
import sys
import pytest # NOQA
import warnings # NOQA
from pathlib import Path
base_path = Path("data") # that is srsly.ruamel_yaml.data
PY2 = sys.version_info[0] == 2
class YAMLData(object):
yaml_tag = "!YAML"
def __init__(self, s):
self._s = s
# Conversion tables for input. E.g. "<TAB>" is replaced by "\t"
# fmt: off
special = {
'SPC': ' ',
'TAB': '\t',
'---': '---',
'...': '...',
}
# fmt: on
@property
def value(self):
if hasattr(self, "_p"):
return self._p
assert " \n" not in self._s
assert "\t\n" not in self._s
self._p = self._s
for k, v in YAMLData.special.items():
k = "<" + k + ">"
self._p = self._p.replace(k, v)
return self._p
def test_rewrite(self, s):
assert " \n" not in s
assert "\t\n" not in s
for k, v in YAMLData.special.items():
k = "<" + k + ">"
s = s.replace(k, v)
return s
@classmethod
def from_yaml(cls, constructor, node):
from srsly.ruamel_yaml.nodes import MappingNode
if isinstance(node, MappingNode):
return cls(constructor.construct_mapping(node))
return cls(node.value)
class Python(YAMLData):
yaml_tag = "!Python"
class Output(YAMLData):
yaml_tag = "!Output"
class Assert(YAMLData):
yaml_tag = "!Assert"
@property
def value(self):
from srsly.ruamel_yaml.compat import Mapping
if hasattr(self, "_pa"):
return self._pa
if isinstance(self._s, Mapping):
self._s["lines"] = self.test_rewrite(self._s["lines"])
self._pa = self._s
return self._pa
def pytest_generate_tests(metafunc):
test_yaml = []
paths = sorted(base_path.glob("**/*.yaml"))
idlist = []
for path in paths:
stem = path.stem
if stem.startswith(".#"): # skip emacs temporary file
continue
idlist.append(stem)
test_yaml.append([path])
metafunc.parametrize(["yaml"], test_yaml, ids=idlist, scope="class")
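# Illustrative sketch of the expected data-file layout (an assumption; no file from
# data/ is reproduced here): each *.yaml found above is a multi-document stream of
# tagged documents (!YAML input plus optional !Output, !Assert or !Python documents)
# that TestYAMLData.test_yaml_data below dispatches on to round-trip, assert or run.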
class TestYAMLData(object):
def yaml(self, yaml_version=None):
from srsly.ruamel_yaml import YAML
y = YAML()
y.preserve_quotes = True
if yaml_version:
y.version = yaml_version
return y
def docs(self, path):
from srsly.ruamel_yaml import YAML
tyaml = YAML(typ="safe", pure=True)
tyaml.register_class(YAMLData)
tyaml.register_class(Python)
tyaml.register_class(Output)
tyaml.register_class(Assert)
return list(tyaml.load_all(path))
def yaml_load(self, value, yaml_version=None):
yaml = self.yaml(yaml_version=yaml_version)
data = yaml.load(value)
return yaml, data
def round_trip(self, input, output=None, yaml_version=None):
from srsly.ruamel_yaml.compat import StringIO
yaml, data = self.yaml_load(input.value, yaml_version=yaml_version)
buf = StringIO()
yaml.dump(data, buf)
expected = input.value if output is None else output.value
value = buf.getvalue()
if PY2:
value = value.decode("utf-8")
print("value", value)
# print('expected', expected)
assert value == expected
def load_assert(self, input, confirm, yaml_version=None):
from srsly.ruamel_yaml.compat import Mapping
d = self.yaml_load(input.value, yaml_version=yaml_version)[1] # NOQA
print("confirm.value", confirm.value, type(confirm.value))
if isinstance(confirm.value, Mapping):
r = range(confirm.value["range"])
lines = confirm.value["lines"].splitlines()
for idx in r: # NOQA
for line in lines:
line = "assert " + line
print(line)
exec(line)
else:
for line in confirm.value.splitlines():
line = "assert " + line
print(line)
exec(line)
def run_python(self, python, data, tmpdir):
from .roundtrip import save_and_run
assert save_and_run(python.value, base_dir=tmpdir, output=data.value) == 0
# this is executed by pytest; the methods with names not starting with test_
# are helpers
def test_yaml_data(self, yaml, tmpdir):
from srsly.ruamel_yaml.compat import Mapping
idx = 0
typ = None
yaml_version = None
docs = self.docs(yaml)
if isinstance(docs[0], Mapping):
d = docs[0]
typ = d.get("type")
yaml_version = d.get("yaml_version")
if "python" in d:
if not check_python_version(d["python"]):
pytest.skip("unsupported version")
idx += 1
data = output = confirm = python = None
for doc in docs[idx:]:
if isinstance(doc, Output):
output = doc
elif isinstance(doc, Assert):
confirm = doc
elif isinstance(doc, Python):
python = doc
if typ is None:
typ = "python_run"
elif isinstance(doc, YAMLData):
data = doc
else:
print("no handler for type:", type(doc), repr(doc))
raise AssertionError()
if typ is None:
if data is not None and output is not None:
typ = "rt"
elif data is not None and confirm is not None:
typ = "load_assert"
else:
assert data is not None
typ = "rt"
print("type:", typ)
if data is not None:
print("data:", data.value, end="")
print("output:", output.value if output is not None else output)
if typ == "rt":
self.round_trip(data, output, yaml_version=yaml_version)
elif typ == "python_run":
self.run_python(python, output if output is not None else data, tmpdir)
elif typ == "load_assert":
self.load_assert(data, confirm, yaml_version=yaml_version)
else:
print("\nrun type unknown:", typ)
raise AssertionError()
def check_python_version(match, current=None):
"""
version indication, return True if version matches.
match should be something like 3.6+, or [2.7, 3.3] etc. Floats
are converted to strings. Single values are made into lists.
"""
if current is None:
current = list(sys.version_info[:3])
if not isinstance(match, list):
match = [match]
for m in match:
minimal = False
if isinstance(m, float):
m = str(m)
if m.endswith("+"):
minimal = True
m = m[:-1]
# assert m[0].isdigit()
# assert m[-1].isdigit()
m = [int(x) for x in m.split(".")]
current_len = current[: len(m)]
# print(m, current, current_len)
if minimal:
if current_len >= m:
return True
else:
if current_len == m:
return True
return False
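# Illustrative reading of the matcher above (not part of the original file):
# check_python_version("3.6+") is True on any interpreter >= 3.6,
# check_python_version([2.7, 3.3]) is True only on a 2.7.* or 3.3.* interpreter,
# and a float such as 3.6 is first converted to the string "3.6".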
srsly-release-v2.5.1/srsly/tests/test_json_api.py 0000664 0000000 0000000 00000020143 14742310675 0022241 0 ustar 00root root 0000000 0000000 import pytest
from io import StringIO
from pathlib import Path
import gzip
import numpy
from .._json_api import (
read_json,
write_json,
read_jsonl,
write_jsonl,
read_gzip_jsonl,
write_gzip_jsonl,
)
from .._json_api import write_gzip_json, json_dumps, is_json_serializable
from .._json_api import json_loads
from ..util import force_string
from .util import make_tempdir
def test_json_dumps_sort_keys():
data = {"a": 1, "c": 3, "b": 2}
result = json_dumps(data, sort_keys=True)
assert result == '{"a":1,"b":2,"c":3}'
def test_read_json_file():
file_contents = '{\n "hello": "world"\n}'
with make_tempdir({"tmp.json": file_contents}) as temp_dir:
file_path = temp_dir / "tmp.json"
assert file_path.exists()
data = read_json(file_path)
assert len(data) == 1
assert data["hello"] == "world"
def test_read_json_file_invalid():
file_contents = '{\n "hello": world\n}'
with make_tempdir({"tmp.json": file_contents}) as temp_dir:
file_path = temp_dir / "tmp.json"
assert file_path.exists()
with pytest.raises(ValueError):
read_json(file_path)
def test_read_json_stdin(monkeypatch):
input_data = '{\n "hello": "world"\n}'
monkeypatch.setattr("sys.stdin", StringIO(input_data))
data = read_json("-")
assert len(data) == 1
assert data["hello"] == "world"
def test_write_json_file():
data = {"hello": "world", "test": 123}
# Provide two expected options, depending on how keys are ordered
expected = [
'{\n "hello":"world",\n "test":123\n}',
'{\n "test":123,\n "hello":"world"\n}',
]
with make_tempdir() as temp_dir:
file_path = temp_dir / "tmp.json"
write_json(file_path, data)
with Path(file_path).open("r", encoding="utf8") as f:
assert f.read() in expected
def test_write_json_file_gzip():
data = {"hello": "world", "test": 123}
# Provide two expected options, depending on how keys are ordered
expected = [
'{\n "hello":"world",\n "test":123\n}',
'{\n "test":123,\n "hello":"world"\n}',
]
with make_tempdir() as temp_dir:
file_path = force_string(temp_dir / "tmp.json")
write_gzip_json(file_path, data)
with gzip.open(file_path, "r") as f:
assert f.read().decode("utf8") in expected
def test_write_json_stdout(capsys):
data = {"hello": "world", "test": 123}
# Provide two expected options, depending on how keys are ordered
expected = [
'{\n "hello":"world",\n "test":123\n}\n',
'{\n "test":123,\n "hello":"world"\n}\n',
]
write_json("-", data)
captured = capsys.readouterr()
assert captured.out in expected
def test_read_jsonl_file():
file_contents = '{"hello": "world"}\n{"test": 123}'
with make_tempdir({"tmp.json": file_contents}) as temp_dir:
file_path = temp_dir / "tmp.json"
assert file_path.exists()
data = read_jsonl(file_path)
# Make sure this returns a generator, not just a list
assert not hasattr(data, "__len__")
data = list(data)
assert len(data) == 2
assert len(data[0]) == 1
assert len(data[1]) == 1
assert data[0]["hello"] == "world"
assert data[1]["test"] == 123
def test_read_jsonl_file_invalid():
file_contents = '{"hello": world}\n{"test": 123}'
with make_tempdir({"tmp.json": file_contents}) as temp_dir:
file_path = temp_dir / "tmp.json"
assert file_path.exists()
with pytest.raises(ValueError):
data = list(read_jsonl(file_path))
data = list(read_jsonl(file_path, skip=True))
assert len(data) == 1
assert len(data[0]) == 1
assert data[0]["test"] == 123
def test_read_jsonl_stdin(monkeypatch):
input_data = '{"hello": "world"}\n{"test": 123}'
monkeypatch.setattr("sys.stdin", StringIO(input_data))
data = read_jsonl("-")
# Make sure this returns a generator, not just a list
assert not hasattr(data, "__len__")
data = list(data)
assert len(data) == 2
assert len(data[0]) == 1
assert len(data[1]) == 1
assert data[0]["hello"] == "world"
assert data[1]["test"] == 123
def test_write_jsonl_file():
data = [{"hello": "world"}, {"test": 123}]
with make_tempdir() as temp_dir:
file_path = temp_dir / "tmp.json"
write_jsonl(file_path, data)
with Path(file_path).open("r", encoding="utf8") as f:
assert f.read() == '{"hello":"world"}\n{"test":123}\n'
def test_write_jsonl_file_append():
data = [{"hello": "world"}, {"test": 123}]
expected = '{"hello":"world"}\n{"test":123}\n\n{"hello":"world"}\n{"test":123}\n'
with make_tempdir() as temp_dir:
file_path = temp_dir / "tmp.json"
write_jsonl(file_path, data)
write_jsonl(file_path, data, append=True)
with Path(file_path).open("r", encoding="utf8") as f:
assert f.read() == expected
def test_write_jsonl_file_append_no_new_line():
data = [{"hello": "world"}, {"test": 123}]
expected = '{"hello":"world"}\n{"test":123}\n{"hello":"world"}\n{"test":123}\n'
with make_tempdir() as temp_dir:
file_path = temp_dir / "tmp.json"
write_jsonl(file_path, data)
write_jsonl(file_path, data, append=True, append_new_line=False)
with Path(file_path).open("r", encoding="utf8") as f:
assert f.read() == expected
def test_write_jsonl_stdout(capsys):
data = [{"hello": "world"}, {"test": 123}]
write_jsonl("-", data)
captured = capsys.readouterr()
assert captured.out == '{"hello":"world"}\n{"test":123}\n'
@pytest.mark.parametrize(
"obj,expected",
[
(["a", "b", 1, 2], True),
({"a": "b", "c": 123}, True),
("hello", True),
(lambda x: x, False),
],
)
def test_is_json_serializable(obj, expected):
assert is_json_serializable(obj) == expected
@pytest.mark.parametrize(
"obj,expected",
[
("-32", -32),
("32", 32),
("0", 0),
("-0", 0),
],
)
def test_json_loads_number_string(obj, expected):
assert json_loads(obj) == expected
@pytest.mark.parametrize(
"obj",
["HI", "-", "-?", "?!", "THIS IS A STRING"],
)
def test_json_loads_raises(obj):
with pytest.raises(ValueError):
json_loads(obj)
def test_unsupported_type_error():
f = numpy.float32()
with pytest.raises(TypeError):
s = json_dumps(f)
def test_write_jsonl_gzip():
"""Tests writing data to a gzipped .jsonl file."""
data = [{"hello": "world"}, {"test": 123}]
expected = ['{"hello":"world"}\n', '{"test":123}\n']
with make_tempdir() as temp_dir:
file_path = temp_dir / "tmp.json"
write_gzip_jsonl(file_path, data)
with gzip.open(file_path, "r") as f:
assert [line.decode("utf8") for line in f.readlines()] == expected
def test_write_jsonl_gzip_append():
"""Tests appending data to a gzipped .jsonl file."""
data = [{"hello": "world"}, {"test": 123}]
expected = [
'{"hello":"world"}\n',
'{"test":123}\n',
"\n",
'{"hello":"world"}\n',
'{"test":123}\n',
]
with make_tempdir() as temp_dir:
file_path = temp_dir / "tmp.json"
write_gzip_jsonl(file_path, data)
write_gzip_jsonl(file_path, data, append=True)
with gzip.open(file_path, "r") as f:
assert [line.decode("utf8") for line in f.readlines()] == expected
def test_read_jsonl_gzip():
"""Tests reading data from a gzipped .jsonl file."""
file_contents = [{"hello": "world"}, {"test": 123}]
with make_tempdir() as temp_dir:
file_path = temp_dir / "tmp.json"
with gzip.open(file_path, "w") as f:
f.writelines(
[(json_dumps(line) + "\n").encode("utf-8") for line in file_contents]
)
assert file_path.exists()
data = read_gzip_jsonl(file_path)
# Make sure this returns a generator, not just a list
assert not hasattr(data, "__len__")
data = list(data)
assert len(data) == 2
assert len(data[0]) == 1
assert len(data[1]) == 1
assert data[0]["hello"] == "world"
assert data[1]["test"] == 123
srsly-release-v2.5.1/srsly/tests/test_msgpack_api.py 0000664 0000000 0000000 00000006760 14742310675 0022726 0 ustar 00root root 0000000 0000000 import pytest
from pathlib import Path
import datetime
from mock import patch
import numpy
from .._msgpack_api import read_msgpack, write_msgpack
from .._msgpack_api import msgpack_loads, msgpack_dumps
from .._msgpack_api import msgpack_encoders, msgpack_decoders
from .util import make_tempdir
def test_msgpack_dumps():
data = {"hello": "world", "test": 123}
expected = [b"\x82\xa5hello\xa5world\xa4test{", b"\x82\xa4test{\xa5hello\xa5world"]
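# Byte-level reading of the expected payloads (per the msgpack spec): 0x82 is a
# fixmap with 2 entries, 0xa5/0xa4 are fixstrs of length 5/4, and "{" is 0x7b,
# the positive fixint 123; the two alternatives only differ in key order.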
msg = msgpack_dumps(data)
assert msg in expected
def test_msgpack_loads():
msg = b"\x82\xa5hello\xa5world\xa4test{"
data = msgpack_loads(msg)
assert len(data) == 2
assert data["hello"] == "world"
assert data["test"] == 123
def test_read_msgpack_file():
file_contents = b"\x81\xa5hello\xa5world"
with make_tempdir({"tmp.msg": file_contents}, mode="wb") as temp_dir:
file_path = temp_dir / "tmp.msg"
assert file_path.exists()
data = read_msgpack(file_path)
assert len(data) == 1
assert data["hello"] == "world"
def test_read_msgpack_file_invalid():
file_contents = b"\xa5hello\xa5world"
with make_tempdir({"tmp.msg": file_contents}, mode="wb") as temp_dir:
file_path = temp_dir / "tmp.msg"
assert file_path.exists()
with pytest.raises(ValueError):
read_msgpack(file_path)
def test_write_msgpack_file():
data = {"hello": "world", "test": 123}
expected = [b"\x82\xa5hello\xa5world\xa4test{", b"\x82\xa4test{\xa5hello\xa5world"]
with make_tempdir(mode="wb") as temp_dir:
file_path = temp_dir / "tmp.msg"
write_msgpack(file_path, data)
with Path(file_path).open("rb") as f:
assert f.read() in expected
@patch("srsly.msgpack._msgpack_numpy.np", None)
@patch("srsly.msgpack._msgpack_numpy.has_numpy", False)
def test_msgpack_without_numpy():
"""Test that msgpack works without numpy and raises correct errors (e.g.
when serializing datetime objects, the error should be msgpack's TypeError,
not a "'np' is not defined error")."""
with pytest.raises(TypeError):
msgpack_loads(msgpack_dumps(datetime.datetime.now()))
def test_msgpack_custom_encoder_decoder():
class CustomObject:
def __init__(self, value):
self.value = value
def serialize_obj(obj, chain=None):
if isinstance(obj, CustomObject):
return {"__custom__": obj.value}
return obj if chain is None else chain(obj)
def deserialize_obj(obj, chain=None):
if "__custom__" in obj:
return CustomObject(obj["__custom__"])
return obj if chain is None else chain(obj)
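# Note: the `chain` argument passed to the two hooks above is the previously
# registered handler (if any), so returning `chain(obj)` for unhandled values
# lets custom encoders/decoders compose with others, e.g. the built-in numpy
# support exercised at the end of this test.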
data = {"a": 123, "b": CustomObject({"foo": "bar"})}
with pytest.raises(TypeError):
msgpack_dumps(data)
# Register custom encoders/decoders to handle CustomObject
msgpack_encoders.register("custom_object", func=serialize_obj)
msgpack_decoders.register("custom_object", func=deserialize_obj)
bytes_data = msgpack_dumps(data)
new_data = msgpack_loads(bytes_data)
assert new_data["a"] == 123
assert isinstance(new_data["b"], CustomObject)
assert new_data["b"].value == {"foo": "bar"}
# Test that it also works with combinations of encoders/decoders (e.g. numpy)
data = {"a": numpy.zeros((1, 2, 3)), "b": CustomObject({"foo": "bar"})}
bytes_data = msgpack_dumps(data)
new_data = msgpack_loads(bytes_data)
assert isinstance(new_data["a"], numpy.ndarray)
assert isinstance(new_data["b"], CustomObject)
assert new_data["b"].value == {"foo": "bar"}
srsly-release-v2.5.1/srsly/tests/test_pickle_api.py 0000664 0000000 0000000 00000001546 14742310675 0022545 0 ustar 00root root 0000000 0000000 from .._pickle_api import pickle_dumps, pickle_loads
def test_pickle_dumps():
data = {"hello": "world", "test": 123}
expected = [
b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00}\x94(\x8c\x05hello\x94\x8c\x05world\x94\x8c\x04test\x94K{u.",
b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00}\x94(\x8c\x04test\x94K{\x8c\x05hello\x94\x8c\x05world\x94u.",
b"\x80\x02}q\x00(X\x04\x00\x00\x00testq\x01K{X\x05\x00\x00\x00helloq\x02X\x05\x00\x00\x00worldq\x03u.",
b"\x80\x05\x95\x1e\x00\x00\x00\x00\x00\x00\x00}\x94(\x8c\x05hello\x94\x8c\x05world\x94\x8c\x04test\x94K{u.",
]
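# The alternatives above cover different pickle protocols and key orders: the
# leading b"\x80\x04" / b"\x80\x02" / b"\x80\x05" bytes are the PROTO opcode
# followed by the protocol number.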
msg = pickle_dumps(data)
assert msg in expected
def test_pickle_loads():
msg = pickle_dumps({"hello": "world", "test": 123})
data = pickle_loads(msg)
assert len(data) == 2
assert data["hello"] == "world"
assert data["test"] == 123
srsly-release-v2.5.1/srsly/tests/test_yaml_api.py 0000664 0000000 0000000 00000006132 14742310675 0022234 0 ustar 00root root 0000000 0000000 from io import StringIO
from pathlib import Path
import pytest
from .._yaml_api import yaml_dumps, yaml_loads, read_yaml, write_yaml
from .._yaml_api import is_yaml_serializable
from ..ruamel_yaml.comments import CommentedMap
from .util import make_tempdir
def test_yaml_dumps():
data = {"a": [1, "hello"], "b": {"foo": "bar", "baz": [10.5, 120]}}
result = yaml_dumps(data)
expected = "a:\n - 1\n - hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n"
assert result == expected
def test_yaml_dumps_indent():
data = {"a": [1, "hello"], "b": {"foo": "bar", "baz": [10.5, 120]}}
result = yaml_dumps(data, indent_mapping=2, indent_sequence=2, indent_offset=0)
expected = "a:\n- 1\n- hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n"
assert result == expected
def test_yaml_loads():
data = "a:\n- 1\n- hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n"
result = yaml_loads(data)
# Check that correct loader is used and result is regular dict, not the
# custom ruamel.yaml "ordereddict" class
assert not isinstance(result, CommentedMap)
assert result == {"a": [1, "hello"], "b": {"foo": "bar", "baz": [10.5, 120]}}
def test_read_yaml_file():
file_contents = "a:\n- 1\n- hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n"
with make_tempdir({"tmp.yaml": file_contents}) as temp_dir:
file_path = temp_dir / "tmp.yaml"
assert file_path.exists()
data = read_yaml(file_path)
assert len(data) == 2
assert data["a"] == [1, "hello"]
def test_read_yaml_file_invalid():
file_contents = "a: - 1\n- hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n"
with make_tempdir({"tmp.yaml": file_contents}) as temp_dir:
file_path = temp_dir / "tmp.yaml"
assert file_path.exists()
with pytest.raises(ValueError):
read_yaml(file_path)
def test_read_yaml_stdin(monkeypatch):
input_data = "a:\n - 1\n - hello\nb:\n foo: bar\n baz:\n - 10.5\n - 120\n"
monkeypatch.setattr("sys.stdin", StringIO(input_data))
data = read_yaml("-")
assert len(data) == 2
assert data["a"] == [1, "hello"]
def test_write_yaml_file():
data = {"hello": "world", "test": [123, 456]}
expected = "hello: world\ntest:\n - 123\n - 456\n"
with make_tempdir() as temp_dir:
file_path = temp_dir / "tmp.yaml"
write_yaml(file_path, data)
with Path(file_path).open("r", encoding="utf8") as f:
assert f.read() == expected
def test_write_yaml_stdout(capsys):
data = {"hello": "world", "test": [123, 456]}
expected = "hello: world\ntest:\n - 123\n - 456\n\n"
write_yaml("-", data)
captured = capsys.readouterr()
assert captured.out == expected
@pytest.mark.parametrize(
"obj,expected",
[
(["a", "b", 1, 2], True),
({"a": "b", "c": 123}, True),
("hello", True),
(lambda x: x, False),
({"a": lambda x: x}, False),
],
)
def test_is_yaml_serializable(obj, expected):
assert is_yaml_serializable(obj) == expected
# Check again to be sure it's consistent
assert is_yaml_serializable(obj) == expected
srsly-release-v2.5.1/srsly/tests/ujson/ 0000775 0000000 0000000 00000000000 14742310675 0020164 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/ujson/334-reproducer.json 0000664 0000000 0000000 00000071241 14742310675 0023545 0 ustar 00root root 0000000 0000000 {
"ak.somestring.internal.Shadow": {
"id": 33300002,
"init_state": "(bk.action.array.Make, (bk.action.i32.Const, 0))",
"child": {
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Collection": {
"id": 33300001,
"snap": "center",
"direction": "row",
"children": [
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#2c8932"
}
}
}
},
"children": [
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"children": [
{
"ls.components.Image": {
"media_id": "10156403921218138",
"preview_url": "https://scontent.xx.whoaa.net/v/t1.0-9/51099660_10156403921233138_3677795704043995136_n.jpg?_nc_cat=102&_nc_log=1&_nc_oc=AQk3Td-w9KpopLL2N1jgZ4WDMuxUyuGY3ZvY4mDSCk8W9-GjsFPi2S4gVQk0Y3A5ZaaQf7ASvQ2s_eR85kTmFvr0&_nc_ad=z-m&_nc_cid=0&_nc_zor=9&_nc_ht=scontent.xx&oh=fb16b0d60b13817a505f583cc9dad1eb&oe=5CBCDB46",
"height": 278,
"width": 156
}
}
],
"_style": {
"flex": {
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"flex_direction": "row",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#ffffff"
}
}
}
},
"_style": {
"flex": {
"margin_right": "4dp",
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#ffffff"
}
}
}
},
"_style": {
"flex": {
"margin_right": "4dp",
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "row",
"align_items": "stretch",
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#ffffff"
}
}
}
},
"children": [
{
"ak.somestring.Flexbox": {
"id": 33300004,
"_style": {
"flex": {
"grow": 1
}
}
}
}
],
"_style": {
"flex": {
"margin_right": "4dp",
"grow": 1
}
}
}
}
],
"_style": {
"flex": {
"height": "2dp",
"margin_left": "4dp"
}
}
}
}
],
"_style": {
"flex": {
"position_type": "absolute",
"left": "0dp",
"top": "10dp",
"margin_top": "10dp",
"right": "0dp",
"height": "2dp",
"width": "100%"
}
}
}
}
],
"_style": {
"flex": {
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"align_items": "flex_start",
"children": [
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"corner_radius": "17dp"
}
},
"children": [
{
"ls.components.Image": {
"media_id": "10156403921218138",
"preview_url": "https://scontent.xx.whoaa.net/v/t1.0-9/51099660_10156403921233138_3677795704043995136_n.jpg?_nc_cat=102&_nc_log=1&_nc_oc=AQk3Td-w9KpopLL2N1jgZ4WDMuxUyuGY3ZvY4mDSCk8W9-GjsFPi2S4gVQk0Y3A5ZaaQf7ASvQ2s_eR85kTmFvr0&_nc_ad=z-m&_nc_cid=0&_nc_zor=9&_nc_ht=scontent.xx&oh=fb16b0d60b13817a505f583cc9dad1eb&oe=5CBCDB46",
"height": 34,
"width": 34,
"_style": {
"flex": {
"width": "34dp",
"height": "34dp"
}
}
}
}
],
"_style": {
"flex": {
"margin_right": "12dp",
"width": "34dp",
"height": "34dp"
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "flex_start",
"children": [
{
"ak.somestring.RichText": {
"children": [
{
"ak.somestring.TextSpan": {
"text": "eric",
"text_size": "15sp",
"text_style": "bold",
"text_color": "#ffffff"
}
}
],
"_style": {
"flex": {
"margin_bottom": "2dp",
"width": "100%"
}
}
}
},
{
"ak.somestring.RichText": {
"children": [
{
"ak.somestring.TextSpan": {
"text": "8h",
"text_size": "13sp",
"text_style": "normal",
"text_color": "#ffffff"
}
}
],
"_style": {
"flex": {
"width": "100%"
}
}
}
}
],
"_style": {
"flex": {
"width": "100%",
"height": "100%"
}
}
}
}
],
"_style": {
"flex": {
"position_type": "absolute",
"top": "30dp",
"left": "10dp",
"height": "48dp"
}
}
}
},
{
"ak.somestring.Flexbox": {
"children": [
{
"ls.components.StoriesReplyBar": {}
}
],
"_style": {
"flex": {
"width": "100%",
"height": "45dp",
"margin_top": "auto",
"margin_bottom": "auto"
}
}
}
}
],
"_style": {
"flex": {
"position_type": "absolute",
"width": "100%",
"height": "100%",
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"children": [
{
"ls.components.Image": {
"media_id": "10101230968216658",
"preview_url": "https://scontent.xx.whoaa.net/v/t1.0-9/50800535_10101230968226638_6755212111762161664_n.jpg?_nc_cat=101&_nc_log=1&_nc_oc=AQmKcqYvt6DI7aeGk3k_oF6RHSVZkUg7f9hnBCWilyaOGdCWO0-u9_zssC5qGvca6wqsrz3AP0y1RPLPiZj8ycCv&_nc_ad=z-m&_nc_cid=0&_nc_zor=9&_nc_ht=scontent.xx&oh=2fffbab8f0a102d196454ee0138c1850&oe=5CC15206",
"height": 278,
"width": 156
}
}
],
"_style": {
"flex": {
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"flex_direction": "row",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#ffffff"
}
}
}
},
"_style": {
"flex": {
"margin_right": "4dp",
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "row",
"align_items": "stretch",
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#ffffff"
}
}
}
},
"children": [
{
"ak.somestring.Flexbox": {
"id": 33300005,
"_style": {
"flex": {
"grow": 1
}
}
}
}
],
"_style": {
"flex": {
"margin_right": "4dp",
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#cccccc"
}
}
}
},
"_style": {
"flex": {
"margin_right": "4dp",
"grow": 1
}
}
}
}
],
"_style": {
"flex": {
"height": "2dp",
"margin_left": "4dp"
}
}
}
}
],
"_style": {
"flex": {
"position_type": "absolute",
"left": "0dp",
"top": "10dp",
"margin_top": "10dp",
"right": "0dp",
"height": "2dp",
"width": "100%"
}
}
}
}
],
"_style": {
"flex": {
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"align_items": "flex_start",
"children": [
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"corner_radius": "17dp"
}
},
"children": [
{
"ls.components.Image": {
"media_id": "10101230968216658",
"preview_url": "https://scontent.xx.whoaa.net/v/t1.0-9/50800535_10101230968226638_6755212111762161664_n.jpg?_nc_cat=101&_nc_log=1&_nc_oc=AQmKcqYvt6DI7aeGk3k_oF6RHSVZkUg7f9hnBCWilyaOGdCWO0-u9_zssC5qGvca6wqsrz3AP0y1RPLPiZj8ycCv&_nc_ad=z-m&_nc_cid=0&_nc_zor=9&_nc_ht=scontent.xx&oh=2fffbab8f0a102d196454ee0138c1850&oe=5CC15206",
"height": 34,
"width": 34,
"_style": {
"flex": {
"width": "34dp",
"height": "34dp"
}
}
}
}
],
"_style": {
"flex": {
"margin_right": "12dp",
"width": "34dp",
"height": "34dp"
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "flex_start",
"children": [
{
"ak.somestring.RichText": {
"children": [
{
"ak.somestring.TextSpan": {
"text": "eric",
"text_size": "15sp",
"text_style": "bold",
"text_color": "#ffffff"
}
}
],
"_style": {
"flex": {
"margin_bottom": "2dp",
"width": "100%"
}
}
}
},
{
"ak.somestring.RichText": {
"children": [
{
"ak.somestring.TextSpan": {
"text": "2h",
"text_size": "13sp",
"text_style": "normal",
"text_color": "#ffffff"
}
}
],
"_style": {
"flex": {
"width": "100%"
}
}
}
}
],
"_style": {
"flex": {
"width": "100%",
"height": "100%"
}
}
}
}
],
"_style": {
"flex": {
"position_type": "absolute",
"top": "30dp",
"left": "10dp",
"height": "48dp"
}
}
}
},
{
"ak.somestring.Flexbox": {
"children": [
{
"ls.components.StoriesReplyBar": {}
}
],
"_style": {
"flex": {
"width": "100%",
"height": "45dp",
"margin_top": "auto",
"margin_bottom": "auto"
}
}
}
}
],
"_style": {
"flex": {
"position_type": "absolute",
"width": "100%",
"height": "100%",
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"children": [
{
"ls.components.Video": {
"media_id": "10156395664922983",
"video_url": "https://video.xx.whoaa.net/v/t42.9040-2/51636103_316525608877874_407931582842667008_n.mp4?_nc_cat=109&efg=eyJ2ZW5jb2RlX3RhZyI6InN2ZV9oZCJ9&_nc_log=1&_nc_oc=AQm6aMctRAFdMe3C66upF2JulQP4mV3Hd4THkueZex952PR389F6Ay9XHm1S40dV1x7M1I-fAW5y3iH7JlQ3MgDM&_nc_ht=video.xx&oh=e17b1f7ec67619d57a5b1cda5e076fef&oe=5C587F7D",
"preview_url": "https://scontent.xx.whoaa.net/v/t15.5256-10/s960x960/51767715_10156395667952983_4168426706077483008_n.jpg?_nc_cat=104&_nc_log=1&_nc_oc=AQnVwEZk2vG8Q3TcoR0SxdXSi8rL_GaST2aH3i9auDcDnJNTRKvuYEFfd_qKGBhmD4-bo-f8BY5j9jHyit765O7P&_nc_ad=z-m&_nc_cid=0&_nc_zor=9&_nc_ht=scontent.xx&oh=9a17e4bcf8a2a9aabc21d2ecf9f8611b&oe=5CB3D14B",
"show_media_play_button": false,
"media_height": 960,
"media_width": 540
}
}
],
"_style": {
"flex": {
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"flex_direction": "row",
"align_items": "stretch",
"children": [
{
"ak.somestring.Flexbox": {
"flex_direction": "row",
"align_items": "stretch",
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#ffffff"
}
}
}
},
"children": [
{
"ak.somestring.Flexbox": {
"id": 33300006,
"_style": {
"flex": {
"grow": 1
}
}
}
}
],
"_style": {
"flex": {
"margin_right": "4dp",
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#cccccc"
}
}
}
},
"_style": {
"flex": {
"margin_right": "4dp",
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"background": {
"ak.somestring.ColorDrawable": {
"color": "#cccccc"
}
}
}
},
"_style": {
"flex": {
"margin_right": "4dp",
"grow": 1
}
}
}
}
],
"_style": {
"flex": {
"height": "2dp",
"margin_left": "4dp"
}
}
}
}
],
"_style": {
"flex": {
"position_type": "absolute",
"left": "0dp",
"top": "10dp",
"margin_top": "10dp",
"right": "0dp",
"height": "2dp",
"width": "100%"
}
}
}
}
],
"_style": {
"flex": {
"grow": 1
}
}
}
},
{
"ak.somestring.Flexbox": {
"align_items": "flex_start",
"children": [
{
"ak.somestring.Flexbox": {
"decoration": {
"ak.somestring.BoxDecoration": {
"corner_radius": "17dp"
}
},
"children": [
{
"ls.components.Image": {
"media_id": "10156395664922983",
"height": 34,
"width": 34,
"_style": {
"flex": {
"width": "34dp",
"height": "34dp"
}
}
}
}
],
"_style": {
"flex": {
"margin_right": "12dp",
"width": "34dp",
"height": "34dp"
}
}
}
},
{
"ak.somestring.Flexbox": {
"flex_direction": "column",
"align_items": "flex_start",
"children": [
{
"ak.somestring.RichText": {
"children": [
{
"ak.somestring.TextSpan": {
"text": "eric",
"text_size": "15sp",
"text_style": "bold",
"text_color": "#ffffff"
}
}
],
"_style": {
"flex": {
"margin_bottom": "2dp",
"width": "100%"
}
}
}
},
{
"ak.somestring.RichText": {
"children": [
{
"ak.somestring.TextSpan": {
"text": "20h",
"text_size": "13sp",
"text_style": "normal",
"text_color": "#ffffff"
}
}
],
"_style": {
"flex": {
"width": "100%"
}
}
}
}
],
"_style": {
"flex": {
"width": "100%",
"height": "100%"
}
}
}
}
],
"_style": {
"flex": {
"position_type": "absolute",
"top": "30dp",
"left": "10dp",
"height": "48dp"
}
}
}
},
{
"ak.somestring.Flexbox": {
"children": [
{
"ls.components.StoriesReplyBar": {}
}
],
"_style": {
"flex": {
"width": "100%",
"height": "45dp",
"margin_top": "auto",
"margin_bottom": "auto"
}
}
}
}
],
"_style": {
"flex": {
"position_type": "absolute",
"width": "100%",
"height": "100%",
"grow": 1
}
}
}
}
],
"_style": {
"flex": {
"width": "100%",
"height": "100%"
}
}
}
}
],
"_style": {
"flex": {
"height": "100%"
}
}
}
}
]
}
}
}
}
srsly-release-v2.5.1/srsly/tests/ujson/__init__.py 0000664 0000000 0000000 00000000000 14742310675 0022263 0 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/tests/ujson/test_ujson.py 0000664 0000000 0000000 00000101274 14742310675 0022740 0 ustar 00root root 0000000 0000000 import decimal
import json
import math
import sys
import unittest
import pytest
from io import StringIO
from pathlib import Path
from srsly import ujson
json_unicode = json.dumps
class UltraJSONTests(unittest.TestCase):
def test_encodeDecimal(self):
sut = decimal.Decimal("1337.1337")
encoded = ujson.encode(sut)
decoded = ujson.decode(encoded)
self.assertEqual(decoded, 1337.1337)
def test_encodeStringConversion(self):
input = "A string \\ / \b \f \n \r \t &"
not_html_encoded = '"A string \\\\ \\/ \\b \\f \\n \\r \\t <\\/script> &"'
html_encoded = (
'"A string \\\\ \\/ \\b \\f \\n \\r \\t \\u003c\\/script\\u003e \\u0026"'
)
not_slashes_escaped = '"A string \\\\ / \\b \\f \\n \\r \\t &"'
def helper(expected_output, **encode_kwargs):
output = ujson.encode(input, **encode_kwargs)
self.assertEqual(output, expected_output)
if encode_kwargs.get("escape_forward_slashes", True):
self.assertEqual(input, json.loads(output))
self.assertEqual(input, ujson.decode(output))
# Default behavior assumes encode_html_chars=False.
helper(not_html_encoded, ensure_ascii=True)
helper(not_html_encoded, ensure_ascii=False)
# Make sure explicit encode_html_chars=False works.
helper(not_html_encoded, ensure_ascii=True, encode_html_chars=False)
helper(not_html_encoded, ensure_ascii=False, encode_html_chars=False)
# Make sure explicit encode_html_chars=True does the encoding.
helper(html_encoded, ensure_ascii=True, encode_html_chars=True)
helper(html_encoded, ensure_ascii=False, encode_html_chars=True)
# Do not escape forward slashes if disabled.
helper(not_slashes_escaped, escape_forward_slashes=False)
def testWriteEscapedString(self):
self.assertEqual(
"\"\\u003cimg src='\\u0026amp;'\\/\\u003e\"",
ujson.dumps("
", encode_html_chars=True),
)
def test_doubleLongIssue(self):
sut = {"a": -4342969734183514}
encoded = json.dumps(sut)
decoded = json.loads(encoded)
self.assertEqual(sut, decoded)
encoded = ujson.encode(sut)
decoded = ujson.decode(encoded)
self.assertEqual(sut, decoded)
def test_doubleLongDecimalIssue(self):
sut = {"a": -12345678901234.56789012}
encoded = json.dumps(sut)
decoded = json.loads(encoded)
self.assertEqual(sut, decoded)
encoded = ujson.encode(sut)
decoded = ujson.decode(encoded)
self.assertEqual(sut, decoded)
def test_encodeDecodeLongDecimal(self):
sut = {"a": -528656961.4399388}
encoded = ujson.dumps(sut)
ujson.decode(encoded)
def test_decimalDecodeTest(self):
sut = {"a": 4.56}
encoded = ujson.encode(sut)
decoded = ujson.decode(encoded)
self.assertAlmostEqual(sut[u"a"], decoded[u"a"])
def test_encodeDictWithUnicodeKeys(self):
input = {
"key1": "value1",
"key1": "value1",
"key1": "value1",
"key1": "value1",
"key1": "value1",
"key1": "value1",
}
ujson.encode(input)
input = {
"بن": "value1",
"بن": "value1",
"بن": "value1",
"بن": "value1",
"بن": "value1",
"بن": "value1",
"بن": "value1",
}
ujson.encode(input)
def test_encodeDoubleConversion(self):
input = math.pi
output = ujson.encode(input)
self.assertEqual(round(input, 5), round(json.loads(output), 5))
self.assertEqual(round(input, 5), round(ujson.decode(output), 5))
def test_encodeWithDecimal(self):
input = 1.0
output = ujson.encode(input)
self.assertEqual(output, "1.0")
def test_encodeDoubleNegConversion(self):
input = -math.pi
output = ujson.encode(input)
self.assertEqual(round(input, 5), round(json.loads(output), 5))
self.assertEqual(round(input, 5), round(ujson.decode(output), 5))
def test_encodeArrayOfNestedArrays(self):
input = [[[[]]]] * 20
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
# self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeArrayOfDoubles(self):
input = [31337.31337, 31337.31337, 31337.31337, 31337.31337] * 10
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
# self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeStringConversion2(self):
input = "A string \\ / \b \f \n \r \t"
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, '"A string \\\\ \\/ \\b \\f \\n \\r \\t"')
self.assertEqual(input, ujson.decode(output))
def test_decodeUnicodeConversion(self):
pass
def test_encodeUnicodeConversion1(self):
input = "Räksmörgås اسامة بن محمد بن عوض بن لادن"
enc = ujson.encode(input)
dec = ujson.decode(enc)
self.assertEqual(enc, json_unicode(input))
self.assertEqual(dec, json.loads(enc))
def test_encodeControlEscaping(self):
input = "\x19"
enc = ujson.encode(input)
dec = ujson.decode(enc)
self.assertEqual(input, dec)
self.assertEqual(enc, json_unicode(input))
def test_encodeUnicodeConversion2(self):
input = "\xe6\x97\xa5\xd1\x88"
enc = ujson.encode(input)
dec = ujson.decode(enc)
self.assertEqual(enc, json_unicode(input))
self.assertEqual(dec, json.loads(enc))
def test_encodeUnicodeSurrogatePair(self):
input = "\xf0\x90\x8d\x86"
enc = ujson.encode(input)
dec = ujson.decode(enc)
self.assertEqual(enc, json_unicode(input))
self.assertEqual(dec, json.loads(enc))
def test_encodeUnicode4BytesUTF8(self):
input = "\xf0\x91\x80\xb0TRAILINGNORMAL"
enc = ujson.encode(input)
dec = ujson.decode(enc)
self.assertEqual(enc, json_unicode(input))
self.assertEqual(dec, json.loads(enc))
def test_encodeUnicode4BytesUTF8Highest(self):
input = "\xf3\xbf\xbf\xbfTRAILINGNORMAL"
enc = ujson.encode(input)
dec = ujson.decode(enc)
self.assertEqual(enc, json_unicode(input))
self.assertEqual(dec, json.loads(enc))
# Characters outside of the Basic Multilingual Plane (larger than
# 16 bits) are represented as \UXXXXXXXX in python but should be encoded
# as \uXXXX\uXXXX in json.
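# Worked example for the first code point below: U+1F42E - 0x10000 = 0xF42E,
# high surrogate = 0xD800 + (0xF42E >> 10) = 0xD83D,
# low surrogate = 0xDC00 + (0xF42E & 0x3FF) = 0xDC2E,
# so json.dumps("\U0001F42E") == '"\\ud83d\\udc2e"'.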
def testEncodeUnicodeBMP(self):
s = "\U0001f42e\U0001f42e\U0001F42D\U0001F42D" # 🐮🐮🐭🐭
encoded = ujson.dumps(s)
encoded_json = json.dumps(s)
if len(s) == 4:
self.assertEqual(len(encoded), len(s) * 12 + 2)
else:
self.assertEqual(len(encoded), len(s) * 6 + 2)
self.assertEqual(encoded, encoded_json)
decoded = ujson.loads(encoded)
self.assertEqual(s, decoded)
# ujson outputs a UTF-8 encoded str object
encoded = ujson.dumps(s, ensure_ascii=False)
# json outputs a unicode object
encoded_json = json.dumps(s, ensure_ascii=False)
self.assertEqual(len(encoded), len(s) + 2) # original length + quotes
self.assertEqual(encoded, encoded_json)
decoded = ujson.loads(encoded)
self.assertEqual(s, decoded)
def testEncodeSymbols(self):
s = "\u273f\u2661\u273f" # ✿♡✿
encoded = ujson.dumps(s)
encoded_json = json.dumps(s)
self.assertEqual(len(encoded), len(s) * 6 + 2) # 6 characters + quotes
self.assertEqual(encoded, encoded_json)
decoded = ujson.loads(encoded)
self.assertEqual(s, decoded)
# ujson outputs a UTF-8 encoded str object
encoded = ujson.dumps(s, ensure_ascii=False)
# json outputs a unicode object
encoded_json = json.dumps(s, ensure_ascii=False)
self.assertEqual(len(encoded), len(s) + 2) # original length + quotes
self.assertEqual(encoded, encoded_json)
decoded = ujson.loads(encoded)
self.assertEqual(s, decoded)
def test_encodeArrayInArray(self):
input = [[[[]]]]
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeIntConversion(self):
input = 31337
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeIntNegConversion(self):
input = -31337
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeLongNegConversion(self):
input = -9223372036854775808
output = ujson.encode(input)
json.loads(output)
ujson.decode(output)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeListConversion(self):
input = [1, 2, 3, 4]
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(input, ujson.decode(output))
def test_encodeDictConversion(self):
input = {"k1": 1, "k2": 2, "k3": 3, "k4": 4}
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(input, ujson.decode(output))
self.assertEqual(input, ujson.decode(output))
def test_encodeNoneConversion(self):
input = None
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeTrueConversion(self):
input = True
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeFalseConversion(self):
input = False
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeToUTF8(self):
input = b"\xe6\x97\xa5\xd1\x88"
input = input.decode("utf-8")
enc = ujson.encode(input, ensure_ascii=False)
dec = ujson.decode(enc)
self.assertEqual(enc, json.dumps(input, ensure_ascii=False))
self.assertEqual(dec, json.loads(enc))
def test_decodeFromUnicode(self):
input = '{"obj": 31337}'
dec1 = ujson.decode(input)
dec2 = ujson.decode(str(input))
self.assertEqual(dec1, dec2)
def test_encodeRecursionMax(self):
# 8 is the max recursion depth
class O2:
member = 0
def toDict(self):
return {"member": self.member}
class O1:
member = 0
def toDict(self):
return {"member": self.member}
input = O1()
input.member = O2()
input.member.member = input
self.assertRaises(OverflowError, ujson.encode, input)
def test_encodeDoubleNan(self):
input = float("nan")
self.assertRaises(OverflowError, ujson.encode, input)
def test_encodeDoubleInf(self):
input = float("inf")
self.assertRaises(OverflowError, ujson.encode, input)
def test_encodeDoubleNegInf(self):
input = -float("inf")
self.assertRaises(OverflowError, ujson.encode, input)
def test_encodeOrderedDict(self):
from collections import OrderedDict
input = OrderedDict([(1, 1), (0, 0), (8, 8), (2, 2)])
self.assertEqual('{"1":1,"0":0,"8":8,"2":2}', ujson.encode(input))
def test_decodeJibberish(self):
input = "fdsa sda v9sa fdsa"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeBrokenArrayStart(self):
input = "["
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeBrokenObjectStart(self):
input = "{"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeBrokenArrayEnd(self):
input = "]"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeArrayDepthTooBig(self):
input = "[" * (1024 * 1024)
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeBrokenObjectEnd(self):
input = "}"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeObjectTrailingCommaFail(self):
input = '{"one":1,}'
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeObjectDepthTooBig(self):
input = "{" * (1024 * 1024)
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeStringUnterminated(self):
input = '"TESTING'
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeStringUntermEscapeSequence(self):
input = '"TESTING\\"'
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeStringBadEscape(self):
input = '"TESTING\\"'
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeTrueBroken(self):
input = "tru"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeFalseBroken(self):
input = "fa"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeNullBroken(self):
input = "n"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeBrokenDictKeyTypeLeakTest(self):
input = '{{1337:""}}'
for x in range(1000):
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeBrokenDictLeakTest(self):
input = '{{"key":"}'
for x in range(1000):
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeBrokenListLeakTest(self):
input = "[[[true"
for x in range(1000):
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeDictWithNoKey(self):
input = "{{{{31337}}}}"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeDictWithNoColonOrValue(self):
input = '{{{{"key"}}}}'
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeDictWithNoValue(self):
input = '{{{{"key":}}}}'
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeNumericIntPos(self):
input = "31337"
self.assertEqual(31337, ujson.decode(input))
def test_decodeNumericIntNeg(self):
input = "-31337"
self.assertEqual(-31337, ujson.decode(input))
def test_encodeUnicode4BytesUTF8Fail(self):
input = b"\xfd\xbf\xbf\xbf\xbf\xbf"
self.assertRaises(OverflowError, ujson.encode, input)
def test_encodeNullCharacter(self):
input = "31337 \x00 1337"
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
input = "\x00"
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
self.assertEqual('" \\u0000\\r\\n "', ujson.dumps(" \u0000\r\n "))
def test_decodeNullCharacter(self):
input = '"31337 \\u0000 31337"'
self.assertEqual(ujson.decode(input), json.loads(input))
def test_encodeListLongConversion(self):
input = [
9223372036854775807,
9223372036854775807,
9223372036854775807,
9223372036854775807,
9223372036854775807,
9223372036854775807,
]
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(input, ujson.decode(output))
def test_encodeListLongUnsignedConversion(self):
input = [18446744073709551615, 18446744073709551615, 18446744073709551615]
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(input, ujson.decode(output))
def test_encodeLongConversion(self):
input = 9223372036854775807
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_encodeLongUnsignedConversion(self):
input = 18446744073709551615
output = ujson.encode(input)
self.assertEqual(input, json.loads(output))
self.assertEqual(output, json.dumps(input))
self.assertEqual(input, ujson.decode(output))
def test_numericIntExp(self):
input = "1337E40"
output = ujson.decode(input)
self.assertEqual(output, json.loads(input))
def test_numericIntFrcExp(self):
input = "1.337E40"
output = ujson.decode(input)
self.assertEqual(output, json.loads(input))
def test_decodeNumericIntExpEPLUS(self):
input = "1337E+9"
output = ujson.decode(input)
self.assertEqual(output, json.loads(input))
def test_decodeNumericIntExpePLUS(self):
input = "1.337e+40"
output = ujson.decode(input)
self.assertEqual(output, json.loads(input))
def test_decodeNumericIntExpE(self):
input = "1337E40"
output = ujson.decode(input)
self.assertEqual(output, json.loads(input))
def test_decodeNumericIntExpe(self):
input = "1337e40"
output = ujson.decode(input)
self.assertEqual(output, json.loads(input))
def test_decodeNumericIntExpEMinus(self):
input = "1.337E-4"
output = ujson.decode(input)
self.assertEqual(output, json.loads(input))
def test_decodeNumericIntExpeMinus(self):
input = "1.337e-4"
output = ujson.decode(input)
self.assertEqual(output, json.loads(input))
def test_dumpToFile(self):
f = StringIO()
ujson.dump([1, 2, 3], f)
self.assertEqual("[1,2,3]", f.getvalue())
def test_dumpToFileLikeObject(self):
class filelike:
def __init__(self):
self.bytes = ""
def write(self, bytes):
self.bytes += bytes
f = filelike()
ujson.dump([1, 2, 3], f)
self.assertEqual("[1,2,3]", f.bytes)
def test_dumpFileArgsError(self):
self.assertRaises(TypeError, ujson.dump, [], "")
def test_loadFile(self):
f = StringIO("[1,2,3,4]")
self.assertEqual([1, 2, 3, 4], ujson.load(f))
def test_loadFileLikeObject(self):
class filelike:
def read(self):
try:
self.end
except AttributeError:
self.end = True
return "[1,2,3,4]"
f = filelike()
self.assertEqual([1, 2, 3, 4], ujson.load(f))
def test_loadFileArgsError(self):
self.assertRaises(TypeError, ujson.load, "[]")
def test_encodeNumericOverflow(self):
self.assertRaises(OverflowError, ujson.encode, 12839128391289382193812939)
def test_decodeNumberWith32bitSignBit(self):
# Test that numbers that fit within 32 bits but would have the
# sign bit set (2**31 <= x < 2**32) are decoded properly.
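# 2**31 = 2147483648 and 2**32 - 1 = 4294967295; a parser that stored these in
# a signed 32-bit int would wrap e.g. 3590016419 to a negative number.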
docs = (
'{"id": 3590016419}',
'{"id": %s}' % 2 ** 31,
'{"id": %s}' % 2 ** 32,
'{"id": %s}' % ((2 ** 32) - 1),
)
results = (3590016419, 2 ** 31, 2 ** 32, 2 ** 32 - 1)
for doc, result in zip(docs, results):
self.assertEqual(ujson.decode(doc)["id"], result)
def test_encodeBigEscape(self):
for x in range(10):
base = "\u00e5".encode("utf-8")
input = base * 1024 * 1024 * 2
ujson.encode(input)
def test_decodeBigEscape(self):
for x in range(10):
base = "\u00e5".encode("utf-8")
quote = '"'.encode()
input = quote + (base * 1024 * 1024 * 2) + quote
ujson.decode(input)
def test_toDict(self):
d = {"key": 31337}
class DictTest:
def toDict(self):
return d
def __json__(self):
return '"json defined"' # Fallback and shouldn't be called.
o = DictTest()
output = ujson.encode(o)
dec = ujson.decode(output)
self.assertEqual(dec, d)
def test_object_with_json(self):
# If __json__ returns a string, then that string
# will be used as a raw JSON snippet in the object.
output_text = "this is the correct output"
class JSONTest:
def __json__(self):
return '"' + output_text + '"'
d = {u"key": JSONTest()}
output = ujson.encode(d)
dec = ujson.decode(output)
self.assertEqual(dec, {u"key": output_text})
def test_object_with_json_unicode(self):
# If __json__ returns a string, then that string
# will be used as a raw JSON snippet in the object.
output_text = u"this is the correct output"
class JSONTest:
def __json__(self):
return u'"' + output_text + u'"'
d = {u"key": JSONTest()}
output = ujson.encode(d)
dec = ujson.decode(output)
self.assertEqual(dec, {u"key": output_text})
def test_object_with_complex_json(self):
# If __json__ returns a string, then that string
# will be used as a raw JSON snippet in the object.
obj = {u"foo": [u"bar", u"baz"]}
class JSONTest:
def __json__(self):
return ujson.encode(obj)
d = {u"key": JSONTest()}
output = ujson.encode(d)
dec = ujson.decode(output)
self.assertEqual(dec, {u"key": obj})
def test_object_with_json_type_error(self):
# __json__ must return a string, otherwise it should raise an error.
for return_value in (None, 1234, 12.34, True, {}):
class JSONTest:
def __json__(self):
return return_value
d = {u"key": JSONTest()}
self.assertRaises(TypeError, ujson.encode, d)
def test_object_with_json_attribute_error(self):
# If __json__ raises an error, make sure python actually raises it.
class JSONTest:
def __json__(self):
raise AttributeError
d = {u"key": JSONTest()}
self.assertRaises(AttributeError, ujson.encode, d)
def test_decodeArrayTrailingCommaFail(self):
input = "[31337,]"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeArrayLeadingCommaFail(self):
input = "[,31337]"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeArrayOnlyCommaFail(self):
input = "[,]"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeArrayUnmatchedBracketFail(self):
input = "[]]"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeArrayEmpty(self):
input = "[]"
obj = ujson.decode(input)
self.assertEqual([], obj)
def test_decodeArrayOneItem(self):
input = "[31337]"
ujson.decode(input)
def test_decodeLongUnsignedValue(self):
input = "18446744073709551615"
ujson.decode(input)
def test_decodeBigValue(self):
input = "9223372036854775807"
ujson.decode(input)
def test_decodeSmallValue(self):
input = "-9223372036854775808"
ujson.decode(input)
def test_decodeTooBigValue(self):
input = "18446744073709551616"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeTooSmallValue(self):
input = "-90223372036854775809"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeVeryTooBigValue(self):
input = "18446744073709551616"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeVeryTooSmallValue(self):
input = "-90223372036854775809"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeWithTrailingWhitespaces(self):
input = "{}\n\t "
ujson.decode(input)
def test_decodeWithTrailingNonWhitespaces(self):
input = "{}\n\t a"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeArrayWithBigInt(self):
input = "[18446744073709551616]"
self.assertRaises(ValueError, ujson.decode, input)
def test_decodeFloatingPointAdditionalTests(self):
self.assertAlmostEqual(-1.1234567893, ujson.loads("-1.1234567893"))
self.assertAlmostEqual(-1.234567893, ujson.loads("-1.234567893"))
self.assertAlmostEqual(-1.34567893, ujson.loads("-1.34567893"))
self.assertAlmostEqual(-1.4567893, ujson.loads("-1.4567893"))
self.assertAlmostEqual(-1.567893, ujson.loads("-1.567893"))
self.assertAlmostEqual(-1.67893, ujson.loads("-1.67893"))
self.assertAlmostEqual(-1.7894, ujson.loads("-1.7894"))
self.assertAlmostEqual(-1.893, ujson.loads("-1.893"))
self.assertAlmostEqual(-1.3, ujson.loads("-1.3"))
self.assertAlmostEqual(1.1234567893, ujson.loads("1.1234567893"))
self.assertAlmostEqual(1.234567893, ujson.loads("1.234567893"))
self.assertAlmostEqual(1.34567893, ujson.loads("1.34567893"))
self.assertAlmostEqual(1.4567893, ujson.loads("1.4567893"))
self.assertAlmostEqual(1.567893, ujson.loads("1.567893"))
self.assertAlmostEqual(1.67893, ujson.loads("1.67893"))
self.assertAlmostEqual(1.7894, ujson.loads("1.7894"))
self.assertAlmostEqual(1.893, ujson.loads("1.893"))
self.assertAlmostEqual(1.3, ujson.loads("1.3"))
def test_ReadBadObjectSyntax(self):
input = '{"age", 44}'
self.assertRaises(ValueError, ujson.decode, input)
def test_ReadTrue(self):
self.assertEqual(True, ujson.loads("true"))
def test_ReadFalse(self):
self.assertEqual(False, ujson.loads("false"))
def test_ReadNull(self):
self.assertEqual(None, ujson.loads("null"))
def test_WriteTrue(self):
self.assertEqual("true", ujson.dumps(True))
def test_WriteFalse(self):
self.assertEqual("false", ujson.dumps(False))
def test_WriteNull(self):
self.assertEqual("null", ujson.dumps(None))
def test_ReadArrayOfSymbols(self):
self.assertEqual([True, False, None], ujson.loads(" [ true, false,null] "))
def test_WriteArrayOfSymbolsFromList(self):
self.assertEqual("[true,false,null]", ujson.dumps([True, False, None]))
def test_WriteArrayOfSymbolsFromTuple(self):
self.assertEqual("[true,false,null]", ujson.dumps((True, False, None)))
def test_encodingInvalidUnicodeCharacter(self):
s = "\udc7f"
self.assertRaises(UnicodeEncodeError, ujson.dumps, s)
def test_sortKeys(self):
data = {"a": 1, "c": 1, "b": 1, "e": 1, "f": 1, "d": 1}
sortedKeys = ujson.dumps(data, sort_keys=True)
self.assertEqual(sortedKeys, '{"a":1,"b":1,"c":1,"d":1,"e":1,"f":1}')
@unittest.skipIf(not hasattr(sys, "getrefcount"), reason="test requires sys.getrefcount")
def test_does_not_leak_dictionary_values(self):
import gc
gc.collect()
value = ["abc"]
data = {"1": value}
ref_count = sys.getrefcount(value)
ujson.dumps(data)
self.assertEqual(ref_count, sys.getrefcount(value))
@unittest.skipIf(not hasattr(sys, "getrefcount"), reason="test requires sys.getrefcount")
def test_does_not_leak_dictionary_keys(self):
import gc
gc.collect()
key1 = "1"
key2 = "1"
value1 = ["abc"]
value2 = [1, 2, 3]
data = {key1: value1, key2: value2}
ref_count1 = sys.getrefcount(key1)
ref_count2 = sys.getrefcount(key2)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))
self.assertEqual(ref_count2, sys.getrefcount(key2))
@unittest.skipIf(not hasattr(sys, "getrefcount"), reason="test requires sys.getrefcount")
def test_does_not_leak_dictionary_string_key(self):
import gc
gc.collect()
key1 = "1"
value1 = 1
data = {key1: value1}
ref_count1 = sys.getrefcount(key1)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))
@unittest.skipIf(not hasattr(sys, "getrefcount"), reason="test requires sys.getrefcount")
def test_does_not_leak_dictionary_tuple_key(self):
import gc
gc.collect()
key1 = ("a",)
value1 = 1
data = {key1: value1}
ref_count1 = sys.getrefcount(key1)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))
@unittest.skipIf(not hasattr(sys, "getrefcount"), reason="test requires sys.getrefcount")
def test_does_not_leak_dictionary_bytes_key(self):
import gc
gc.collect()
key1 = b"1"
value1 = 1
data = {key1: value1}
ref_count1 = sys.getrefcount(key1)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))
@unittest.skipIf(not hasattr(sys, "getrefcount"), reason="test requires sys.getrefcount")
def test_does_not_leak_dictionary_None_key(self):
import gc
gc.collect()
key1 = None
value1 = 1
data = {key1: value1}
ref_count1 = sys.getrefcount(key1)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))
"""
def test_decodeNumericIntFrcOverflow(self):
input = "X.Y"
raise NotImplementedError("Implement this test!")
def test_decodeStringUnicodeEscape(self):
input = "\u3131"
raise NotImplementedError("Implement this test!")
def test_decodeStringUnicodeBrokenEscape(self):
input = "\u3131"
raise NotImplementedError("Implement this test!")
def test_decodeStringUnicodeInvalidEscape(self):
input = "\u3131"
raise NotImplementedError("Implement this test!")
def test_decodeStringUTF8(self):
input = "someutfcharacters"
raise NotImplementedError("Implement this test!")
"""
if __name__ == "__main__":
unittest.main()
"""
# Use this to look for memory leaks
if __name__ == '__main__':
from guppy import hpy
hp = hpy()
hp.setrelheap()
while True:
try:
unittest.main()
except SystemExit:
pass
heap = hp.heapu()
print(heap)
"""
@pytest.mark.parametrize("indent", list(range(65537, 65542)))
def test_dump_huge_indent(indent):
ujson.encode({"a": True}, indent=indent)
@pytest.mark.parametrize("first_length", list(range(2, 7)))
@pytest.mark.parametrize("second_length", list(range(10919, 10924)))
def test_dump_long_string(first_length, second_length):
ujson.dumps(["a" * first_length, "\x00" * second_length])
def test_dump_indented_nested_list():
a = _a = []
for i in range(20):
_a.append(list(range(i)))
_a = _a[-1]
ujson.dumps(a, indent=i)
@pytest.mark.parametrize("indent", [0, 1, 2, 4, 5, 8, 49])
def test_issue_334(indent):
path = Path(__file__).with_name("334-reproducer.json")
a = ujson.loads(path.read_bytes())
ujson.dumps(a, indent=indent)
@pytest.mark.parametrize(
"test_input, expected",
[
# Normal cases
(r'"\uD83D\uDCA9"', "\U0001F4A9"),
(r'"a\uD83D\uDCA9b"', "a\U0001F4A9b"),
# Unpaired surrogates
(r'"\uD800"', "\uD800"),
(r'"a\uD800b"', "a\uD800b"),
(r'"\uDEAD"', "\uDEAD"),
(r'"a\uDEADb"', "a\uDEADb"),
(r'"\uD83D\uD83D\uDCA9"', "\uD83D\U0001F4A9"),
(r'"\uDCA9\uD83D\uDCA9"', "\uDCA9\U0001F4A9"),
(r'"\uD83D\uDCA9\uD83D"', "\U0001F4A9\uD83D"),
(r'"\uD83D\uDCA9\uDCA9"', "\U0001F4A9\uDCA9"),
(r'"\uD83D \uDCA9"', "\uD83D \uDCA9"),
# No decoding of actual surrogate characters (rather than escaped ones)
('"\uD800"', "\uD800"),
('"\uDEAD"', "\uDEAD"),
('"\uD800a\uDEAD"', "\uD800a\uDEAD"),
('"\uD83D\uDCA9"', "\uD83D\uDCA9"),
],
)
def test_decode_surrogate_characters(test_input, expected):
assert ujson.loads(test_input) == expected
assert ujson.loads(test_input.encode("utf-8", "surrogatepass")) == expected
# Ensure that this matches stdlib's behaviour
assert json.loads(test_input) == expected
srsly-release-v2.5.1/srsly/tests/util.py 0000664 0000000 0000000 00000000637 14742310675 0020363 0 ustar 00root root 0000000 0000000 import tempfile
from pathlib import Path
from contextlib import contextmanager
import shutil
@contextmanager
def make_tempdir(files={}, mode="w"):
temp_dir_str = tempfile.mkdtemp()
temp_dir = Path(temp_dir_str)
for name, content in files.items():
path = temp_dir / name
with path.open(mode) as file_:
file_.write(content)
yield temp_dir
shutil.rmtree(temp_dir_str)
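# Illustrative usage (not part of the original helper):
#     with make_tempdir({"tmp.json": '{"hello": "world"}'}) as temp_dir:
#         path = temp_dir / "tmp.json"
# Binary fixtures pass mode="wb", as the msgpack tests above do.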
srsly-release-v2.5.1/srsly/ujson/ 0000775 0000000 0000000 00000000000 14742310675 0017022 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/ujson/JSONtoObj.c 0000664 0000000 0000000 00000014331 14742310675 0020737 0 ustar 00root root 0000000 0000000 /*
Developed by ESN, an Electronic Arts Inc. studio.
Copyright (c) 2014, Electronic Arts Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of ESN, Electronic Arts Inc. nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
http://code.google.com/p/stringencoders/
Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
Numeric decoder derived from from TCL library
http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
* Copyright (c) 1988-1993 The Regents of the University of California.
* Copyright (c) 1994 Sun Microsystems, Inc.
*/
#include "py_defines.h"
#include <ultrajson.h>
//#define PRINTMARK() fprintf(stderr, "%s: MARK(%d)\n", __FILE__, __LINE__)
#define PRINTMARK()
void Object_objectAddKey(void *prv, JSOBJ obj, JSOBJ name, JSOBJ value)
{
PyDict_SetItem (obj, name, value);
Py_DECREF( (PyObject *) name);
Py_DECREF( (PyObject *) value);
return;
}
void Object_arrayAddItem(void *prv, JSOBJ obj, JSOBJ value)
{
PyList_Append(obj, value);
Py_DECREF( (PyObject *) value);
return;
}
/*
Check that Py_UCS4 is the same as JSUINT32, else Object_newString will fail.
Based on Linux's check in vbox_vmmdev_types.h.
This should be replaced with
_Static_assert(sizeof(Py_UCS4) == sizeof(JSUINT32));
when C11 is made mandatory (CPython 3.11+, PyPy ?).
*/
typedef char assert_py_ucs4_is_jsuint32[1 - 2*!(sizeof(Py_UCS4) == sizeof(JSUINT32))];
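/* If the sizes ever differ the array length above evaluates to 1 - 2*1 = -1, which is
   rejected at compile time; when they match it is 1 - 2*0 = 1 and the typedef is harmless. */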
static JSOBJ Object_newString(void *prv, JSUINT32 *start, JSUINT32 *end)
{
return PyUnicode_FromKindAndData (PyUnicode_4BYTE_KIND, (Py_UCS4 *) start, (end - start));
}
JSOBJ Object_newTrue(void *prv)
{
Py_RETURN_TRUE;
}
JSOBJ Object_newFalse(void *prv)
{
Py_RETURN_FALSE;
}
JSOBJ Object_newNull(void *prv)
{
Py_RETURN_NONE;
}
JSOBJ Object_newObject(void *prv)
{
return PyDict_New();
}
JSOBJ Object_newArray(void *prv)
{
return PyList_New(0);
}
JSOBJ Object_newInteger(void *prv, JSINT32 value)
{
return PyInt_FromLong( (long) value);
}
JSOBJ Object_newLong(void *prv, JSINT64 value)
{
return PyLong_FromLongLong (value);
}
JSOBJ Object_newUnsignedLong(void *prv, JSUINT64 value)
{
return PyLong_FromUnsignedLongLong (value);
}
JSOBJ Object_newDouble(void *prv, double value)
{
return PyFloat_FromDouble(value);
}
static void Object_releaseObject(void *prv, JSOBJ obj)
{
Py_DECREF( ((PyObject *)obj));
}
static char *g_kwlist[] = {"obj", "precise_float", NULL};
PyObject* JSONToObj(PyObject* self, PyObject *args, PyObject *kwargs)
{
PyObject *ret;
PyObject *sarg;
PyObject *arg;
PyObject *opreciseFloat = NULL;
JSONObjectDecoder decoder =
{
Object_newString,
Object_objectAddKey,
Object_arrayAddItem,
Object_newTrue,
Object_newFalse,
Object_newNull,
Object_newObject,
Object_newArray,
Object_newInteger,
Object_newLong,
Object_newUnsignedLong,
Object_newDouble,
Object_releaseObject,
PyObject_Malloc,
PyObject_Free,
PyObject_Realloc
};
decoder.preciseFloat = 0;
decoder.prv = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O", g_kwlist, &arg, &opreciseFloat))
{
return NULL;
}
if (opreciseFloat && PyObject_IsTrue(opreciseFloat))
{
decoder.preciseFloat = 1;
}
if (PyString_Check(arg))
{
sarg = arg;
}
else
if (PyUnicode_Check(arg))
{
sarg = PyUnicode_AsEncodedString(arg, NULL, "surrogatepass");
if (sarg == NULL)
{
//Exception raised above us by codec according to docs
return NULL;
}
}
else
{
PyErr_Format(PyExc_TypeError, "Expected String or Unicode");
return NULL;
}
decoder.errorStr = NULL;
decoder.errorOffset = NULL;
ret = JSON_DecodeObject(&decoder, PyString_AS_STRING(sarg), PyString_GET_SIZE(sarg));
if (sarg != arg)
{
Py_DECREF(sarg);
}
if (decoder.errorStr)
{
/*
FIXME: It's possible to give a much nicer error message here with actual failing element in input etc*/
PyErr_Format (PyExc_ValueError, "%s", decoder.errorStr);
if (ret)
{
Py_DECREF( (PyObject *) ret);
}
return NULL;
}
return ret;
}
PyObject* JSONFileToObj(PyObject* self, PyObject *args, PyObject *kwargs)
{
PyObject *read;
PyObject *string;
PyObject *result;
PyObject *file = NULL;
PyObject *argtuple;
if (!PyArg_ParseTuple (args, "O", &file))
{
return NULL;
}
if (!PyObject_HasAttrString (file, "read"))
{
PyErr_Format (PyExc_TypeError, "expected file");
return NULL;
}
read = PyObject_GetAttrString (file, "read");
if (!PyCallable_Check (read)) {
Py_XDECREF(read);
PyErr_Format (PyExc_TypeError, "expected file");
return NULL;
}
string = PyObject_CallObject (read, NULL);
Py_XDECREF(read);
if (string == NULL)
{
return NULL;
}
argtuple = PyTuple_Pack(1, string);
result = JSONToObj (self, argtuple, kwargs);
Py_XDECREF(argtuple);
Py_XDECREF(string);
if (result == NULL) {
return NULL;
}
return result;
}
srsly-release-v2.5.1/srsly/ujson/__init__.py 0000664 0000000 0000000 00000000112 14742310675 0021125 0 ustar 00root root 0000000 0000000 from .ujson import decode, encode, dump, dumps, load, loads # noqa: F401
srsly-release-v2.5.1/srsly/ujson/lib/ 0000775 0000000 0000000 00000000000 14742310675 0017570 5 ustar 00root root 0000000 0000000 srsly-release-v2.5.1/srsly/ujson/lib/dconv_wrapper.cc 0000664 0000000 0000000 00000003766 14742310675 0022764 0 ustar 00root root 0000000 0000000 #include "double-conversion.h"
namespace double_conversion
{
static StringToDoubleConverter* s2d_instance = NULL;
static DoubleToStringConverter* d2s_instance = NULL;
extern "C"
{
void dconv_d2s_init(int flags,
const char* infinity_symbol,
const char* nan_symbol,
char exponent_character,
int decimal_in_shortest_low,
int decimal_in_shortest_high,
int max_leading_padding_zeroes_in_precision_mode,
int max_trailing_padding_zeroes_in_precision_mode)
{
d2s_instance = new DoubleToStringConverter(flags, infinity_symbol, nan_symbol,
exponent_character, decimal_in_shortest_low,
decimal_in_shortest_high, max_leading_padding_zeroes_in_precision_mode,
max_trailing_padding_zeroes_in_precision_mode);
}
int dconv_d2s(double value, char* buf, int buflen, int* strlength)
{
StringBuilder sb(buf, buflen);
int success = static_cast<int>(d2s_instance->ToShortest(value, &sb));
*strlength = success ? sb.position() : -1;
return success;
}
void dconv_d2s_free()
{
delete d2s_instance;
d2s_instance = NULL;
}
void dconv_s2d_init(int flags, double empty_string_value,
double junk_string_value, const char* infinity_symbol,
const char* nan_symbol)
{
s2d_instance = new StringToDoubleConverter(flags, empty_string_value,
junk_string_value, infinity_symbol, nan_symbol);
}
double dconv_s2d(const char* buffer, int length, int* processed_characters_count)
{
return s2d_instance->StringToDouble(buffer, length, processed_characters_count);
}
void dconv_s2d_free()
{
delete s2d_instance;
s2d_instance = NULL;
}
}
}
srsly-release-v2.5.1/srsly/ujson/lib/ultrajson.h 0000664 0000000 0000000 00000023100 14742310675 0021756 0 ustar 00root root 0000000 0000000 /*
Developed by ESN, an Electronic Arts Inc. studio.
Copyright (c) 2014, Electronic Arts Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of ESN, Electronic Arts Inc. nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
http://code.google.com/p/stringencoders/
Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
Numeric decoder derived from from TCL library
http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
* Copyright (c) 1988-1993 The Regents of the University of California.
* Copyright (c) 1994 Sun Microsystems, Inc.
*/
/*
Ultra fast JSON encoder and decoder
Developed by Jonas Tarnstrom (jonas@esn.me).
Encoder notes:
------------------
:: Cyclic references ::
Cyclic referenced objects are not detected.
Set JSONObjectEncoder.recursionMax to suitable value or make sure input object
tree doesn't have cyclic references.
*/
#ifndef __ULTRAJSON_H__
#define __ULTRAJSON_H__
#include <stdio.h>
// Max decimals to encode double floating point numbers with
#ifndef JSON_DOUBLE_MAX_DECIMALS
#define JSON_DOUBLE_MAX_DECIMALS 15
#endif
// Max recursion depth, default for encoder
#ifndef JSON_MAX_RECURSION_DEPTH
#define JSON_MAX_RECURSION_DEPTH 1024
#endif
// Max recursion depth, default for decoder
#ifndef JSON_MAX_OBJECT_DEPTH
#define JSON_MAX_OBJECT_DEPTH 1024
#endif
/*
Dictates and limits how much stack space for buffers UltraJSON will use before resorting to provided heap functions */
#ifndef JSON_MAX_STACK_BUFFER_SIZE
#define JSON_MAX_STACK_BUFFER_SIZE 131072
#endif
#ifdef _WIN32
typedef __int64 JSINT64;
typedef unsigned __int64 JSUINT64;
typedef __int32 JSINT32;
typedef unsigned __int32 JSUINT32;
typedef unsigned __int8 JSUINT8;
typedef unsigned __int16 JSUTF16;
typedef unsigned __int32 JSUTF32;
typedef __int64 JSLONG;
#define EXPORTFUNCTION __declspec(dllexport)
#define FASTCALL_MSVC __fastcall
#define FASTCALL_ATTR
#define INLINE_PREFIX __inline
#else
#include <stdint.h>
typedef int64_t JSINT64;
typedef uint64_t JSUINT64;
typedef int32_t JSINT32;
typedef uint32_t JSUINT32;
#define FASTCALL_MSVC
#if !defined __x86_64__
#define FASTCALL_ATTR __attribute__((fastcall))
#else
#define FASTCALL_ATTR
#endif
#define INLINE_PREFIX inline
typedef uint8_t JSUINT8;
typedef uint16_t JSUTF16;
typedef uint32_t JSUTF32;
typedef int64_t JSLONG;
#define EXPORTFUNCTION
#endif
#if !(defined(__LITTLE_ENDIAN__) || defined(__BIG_ENDIAN__))
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
#define __LITTLE_ENDIAN__
#else
#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define __BIG_ENDIAN__
#endif
#endif
#endif
#if !defined(__LITTLE_ENDIAN__) && !defined(__BIG_ENDIAN__)
#error "Endianness not supported"
#endif
enum JSTYPES
{
JT_NULL, // NULL
JT_TRUE, // boolean true
JT_FALSE, // boolean false
JT_INT, // (JSINT32 (signed 32-bit))
JT_LONG, // (JSINT64 (signed 64-bit))
JT_ULONG, // (JSUINT64 (unsigned 64-bit))
JT_DOUBLE, // (double)
JT_UTF8, // (char 8-bit)
JT_RAW, // (raw char 8-bit)
JT_ARRAY, // Array structure
JT_OBJECT, // Key/Value structure
JT_INVALID, // Internal, do not return nor expect
};
typedef void * JSOBJ;
typedef void * JSITER;
typedef struct __JSONTypeContext
{
int type;
void *prv;
void *encoder_prv;
} JSONTypeContext;
/*
Function pointer declarations, suitable for implementing UltraJSON */
typedef int (*JSPFN_ITERNEXT)(JSOBJ obj, JSONTypeContext *tc);
typedef void (*JSPFN_ITEREND)(JSOBJ obj, JSONTypeContext *tc);
typedef JSOBJ (*JSPFN_ITERGETVALUE)(JSOBJ obj, JSONTypeContext *tc);
typedef char *(*JSPFN_ITERGETNAME)(JSOBJ obj, JSONTypeContext *tc, size_t *outLen);
typedef void *(*JSPFN_MALLOC)(size_t size);
typedef void (*JSPFN_FREE)(void *pptr);
typedef void *(*JSPFN_REALLOC)(void *base, size_t size);
struct __JSONObjectEncoder;
typedef struct __JSONObjectEncoder
{
void (*beginTypeContext)(JSOBJ obj, JSONTypeContext *tc, struct __JSONObjectEncoder *enc);
void (*endTypeContext)(JSOBJ obj, JSONTypeContext *tc);
const char *(*getStringValue)(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen);
JSINT64 (*getLongValue)(JSOBJ obj, JSONTypeContext *tc);
JSUINT64 (*getUnsignedLongValue)(JSOBJ obj, JSONTypeContext *tc);
JSINT32 (*getIntValue)(JSOBJ obj, JSONTypeContext *tc);
double (*getDoubleValue)(JSOBJ obj, JSONTypeContext *tc);
/*
Retrieve next object in an iteration. Should return 0 to indicate iteration has reached end or 1 if there are more items.
Implementor is responsible for keeping state of the iteration. Use ti->prv fields for this
*/
JSPFN_ITERNEXT iterNext;
/*
Ends the iteration of an iterable object.
Any iteration state stored in ti->prv can be freed here
*/
JSPFN_ITEREND iterEnd;
/*
Returns a reference to the value object of an iterator
The implementor is responsible for the life-cycle of the returned value. Use iterNext/iterEnd and ti->prv to keep track of the current object
*/
JSPFN_ITERGETVALUE iterGetValue;
/*
Return name of iterator.
The implementor is responsible for the life-cycle of the returned string. Use iterNext/iterEnd and ti->prv to keep track of the current object
*/
JSPFN_ITERGETNAME iterGetName;
/*
Release a value as indicated by setting ti->release = 1 in the previous getValue call.
The ti->prv array should contain the necessary context to release the value
*/
void (*releaseObject)(JSOBJ obj);
/* Library functions
Set to NULL to use STDLIB malloc,realloc,free */
JSPFN_MALLOC malloc;
JSPFN_REALLOC realloc;
JSPFN_FREE free;
/*
Configuration for max recursion, set to 0 to use default (see JSON_MAX_RECURSION_DEPTH)*/
int recursionMax;
/*
Configuration for max decimals of double floating point numbers to encode (0-15) */
int doublePrecision;
/*
If true output will be ASCII with all characters above 127 encoded as \uXXXX. If false output will be UTF-8 or whatever charset the strings are supplied in. */
int forceASCII;
/*
If true, '<', '>', and '&' characters will be encoded as \u003c, \u003e, and \u0026, respectively. If false, no special encoding will be used. */
int encodeHTMLChars;
/*
If true, '/' will be encoded as \/. If false, no escaping. */
int escapeForwardSlashes;
/*
If true, dictionaries are iterated through in sorted key order. */
int sortKeys;
/*
Configuration for spaces of indent */
int indent;
/*
Private pointer to be used by the caller. Passed as encoder_prv in JSONTypeContext */
void *prv;
/*
Set to an error message if an error occurred */
const char *errorMsg;
JSOBJ errorObj;
/* Buffer stuff */
char *start;
char *offset;
char *end;
int heap;
int level;
} JSONObjectEncoder;
/*
Encode an object structure into JSON.
Arguments:
obj - An anonymous type representing the object
enc - Function definitions for querying JSOBJ type
buffer - Preallocated buffer to store result in. If NULL function allocates own buffer
cbBuffer - Length of buffer (ignored if buffer is NULL)
Returns:
Encoded JSON object as a null terminated char string.
NOTE:
If the supplied buffer wasn't enough to hold the result the function will allocate a new buffer.
Life cycle of the provided buffer must still be handled by caller.
If the return value doesn't equal the specified buffer caller must release the memory using
JSONObjectEncoder.free or free() as specified when calling this function.
*/
EXPORTFUNCTION char *JSON_EncodeObject(JSOBJ obj, JSONObjectEncoder *enc, char *buffer, size_t cbBuffer);
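/*
Minimal usage sketch (illustrative only, not part of this header; "my_obj" and the
callback assignments are placeholders the caller must provide):

    char stack[65536];
    JSONObjectEncoder enc = {0};   // fill in beginTypeContext, iterNext, ... for your object model
    char *out = JSON_EncodeObject(my_obj, &enc, stack, sizeof(stack));
    if (out != NULL && out != stack)
    {
        enc.free(out);             // encoder fell back to the heap, so the caller releases it
    }
*/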
typedef struct __JSONObjectDecoder
{
JSOBJ (*newString)(void *prv, JSUINT32 *start, JSUINT32 *end);
void (*objectAddKey)(void *prv, JSOBJ obj, JSOBJ name, JSOBJ value);
void (*arrayAddItem)(void *prv, JSOBJ obj, JSOBJ value);
JSOBJ (*newTrue)(void *prv);
JSOBJ (*newFalse)(void *prv);
JSOBJ (*newNull)(void *prv);
JSOBJ (*newObject)(void *prv);
JSOBJ (*newArray)(void *prv);
JSOBJ (*newInt)(void *prv, JSINT32 value);
JSOBJ (*newLong)(void *prv, JSINT64 value);
JSOBJ (*newUnsignedLong)(void *prv, JSUINT64 value);
JSOBJ (*newDouble)(void *prv, double value);
void (*releaseObject)(void *prv, JSOBJ obj);
JSPFN_MALLOC malloc;
JSPFN_FREE free;
JSPFN_REALLOC realloc;
char *errorStr;
char *errorOffset;
int preciseFloat;
void *prv;
} JSONObjectDecoder;
EXPORTFUNCTION JSOBJ JSON_DecodeObject(JSONObjectDecoder *dec, const char *buffer, size_t cbBuffer);
#endif
srsly-release-v2.5.1/srsly/ujson/lib/ultrajsondec.c 0000664 0000000 0000000 00000052314 14742310675 0022436 0 ustar 00root root 0000000 0000000 /*
Developed by ESN, an Electronic Arts Inc. studio.
Copyright (c) 2014, Electronic Arts Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of ESN, Electronic Arts Inc. nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
http://code.google.com/p/stringencoders/
Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
Numeric decoder derived from from TCL library
http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
* Copyright (c) 1988-1993 The Regents of the University of California.
* Copyright (c) 1994 Sun Microsystems, Inc.
*/
#include "ultrajson.h"
#include <math.h>
#include <string.h>
#include <limits.h>
#include <stdlib.h>
#include <errno.h>
#include <stdint.h>
#ifndef TRUE
#define TRUE 1
#define FALSE 0
#endif
#ifndef NULL
#define NULL 0
#endif
struct DecoderState
{
char *start;
char *end;
JSUINT32 *escStart;
JSUINT32 *escEnd;
int escHeap;
int lastType;
JSUINT32 objDepth;
void *prv;
JSONObjectDecoder *dec;
};
JSOBJ FASTCALL_MSVC decode_any( struct DecoderState *ds) FASTCALL_ATTR;
typedef JSOBJ (*PFN_DECODER)( struct DecoderState *ds);
static JSOBJ SetError( struct DecoderState *ds, int offset, const char *message)
{
ds->dec->errorOffset = ds->start + offset;
ds->dec->errorStr = (char *) message;
return NULL;
}
double createDouble(double intNeg, double intValue, double frcValue, int frcDecimalCount)
{
static const double g_pow10[] = {1.0, 0.1, 0.01, 0.001, 0.0001, 0.00001, 0.000001,0.0000001, 0.00000001, 0.000000001, 0.0000000001, 0.00000000001, 0.000000000001, 0.0000000000001, 0.00000000000001, 0.000000000000001};
return (intValue + (frcValue * g_pow10[frcDecimalCount])) * intNeg;
}
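/* Example: decoding "-12.34" reaches this helper with intNeg = -1, intValue = 12,
   frcValue = 34 and frcDecimalCount = 2, giving (12 + 34 * 0.01) * -1 = -12.34. */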
FASTCALL_ATTR JSOBJ FASTCALL_MSVC decodePreciseFloat(struct DecoderState *ds)
{
char *end;
double value;
errno = 0;
value = strtod(ds->start, &end);
if (errno == ERANGE)
{
return SetError(ds, -1, "Range error when decoding numeric as double");
}
ds->start = end;
return ds->dec->newDouble(ds->prv, value);
}
FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_numeric (struct DecoderState *ds)
{
int intNeg = 1;
int mantSize = 0;
JSUINT64 intValue;
JSUINT64 prevIntValue;
int chr;
int decimalCount = 0;
double frcValue = 0.0;
double expNeg;
double expValue;
char *offset = ds->start;
JSUINT64 overflowLimit = LLONG_MAX;
if (*(offset) == '-')
{
offset ++;
intNeg = -1;
overflowLimit = LLONG_MIN;
}
// Scan integer part
intValue = 0;
while (1)
{
chr = (int) (unsigned char) *(offset);
switch (chr)
{
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
{
//PERF: Don't do 64-bit arithmetic here unless we know we have to
prevIntValue = intValue;
intValue = intValue * 10ULL + (JSLONG) (chr - 48);
if (intNeg == 1 && prevIntValue > intValue)
{
return SetError(ds, -1, "Value is too big!");
}
else if (intNeg == -1 && intValue > overflowLimit)
{
return SetError(ds, -1, overflowLimit == LLONG_MAX ? "Value is too big!" : "Value is too small");
}
offset ++;
mantSize ++;
break;
}
case '.':
{
offset ++;
goto DECODE_FRACTION;
break;
}
case 'e':
case 'E':
{
offset ++;
goto DECODE_EXPONENT;
break;
}
default:
{
goto BREAK_INT_LOOP;
break;
}
}
}
BREAK_INT_LOOP:
ds->lastType = JT_INT;
ds->start = offset;
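// Dispatch on the parsed magnitude: positive values with the top bit set only fit an
// unsigned 64-bit integer, values that need more than 31 bits become a signed 64-bit
// long, and everything else fits a 32-bit int.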
if (intNeg == 1 && (intValue & 0x8000000000000000ULL) != 0)
{
return ds->dec->newUnsignedLong(ds->prv, intValue);
}
else if ((intValue >> 31))
{
return ds->dec->newLong(ds->prv, (JSINT64) (intValue * (JSINT64) intNeg));
}
else
{
return ds->dec->newInt(ds->prv, (JSINT32) (intValue * intNeg));
}
DECODE_FRACTION:
if (ds->dec->preciseFloat)
{
return decodePreciseFloat(ds);
}
// Scan fraction part
frcValue = 0.0;
for (;;)
{
chr = (int) (unsigned char) *(offset);
switch (chr)
{
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
{
if (decimalCount < JSON_DOUBLE_MAX_DECIMALS)
{
frcValue = frcValue * 10.0 + (double) (chr - 48);
decimalCount ++;
}
offset ++;
break;
}
case 'e':
case 'E':
{
offset ++;
goto DECODE_EXPONENT;
break;
}
default:
{
goto BREAK_FRC_LOOP;
}
}
}
BREAK_FRC_LOOP:
//FIXME: Check for arithmetic overflow here
ds->lastType = JT_DOUBLE;
ds->start = offset;
return ds->dec->newDouble (ds->prv, createDouble( (double) intNeg, (double) intValue, frcValue, decimalCount));
DECODE_EXPONENT:
if (ds->dec->preciseFloat)
{
return decodePreciseFloat(ds);
}
expNeg = 1.0;
if (*(offset) == '-')
{
expNeg = -1.0;
offset ++;
}
else
if (*(offset) == '+')
{
expNeg = +1.0;
offset ++;
}
expValue = 0.0;
for (;;)
{
chr = (int) (unsigned char) *(offset);
switch (chr)
{
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
{
expValue = expValue * 10.0 + (double) (chr - 48);
offset ++;
break;
}
default:
{
goto BREAK_EXP_LOOP;
}
}
}
BREAK_EXP_LOOP:
//FIXME: Check for arithmetic overflow here
ds->lastType = JT_DOUBLE;
ds->start = offset;
return ds->dec->newDouble (ds->prv, createDouble( (double) intNeg, (double) intValue , frcValue, decimalCount) * pow(10.0, expValue * expNeg));
}
FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_true ( struct DecoderState *ds)
{
char *offset = ds->start;
offset ++;
if (*(offset++) != 'r')
goto SETERROR;
if (*(offset++) != 'u')
goto SETERROR;
if (*(offset++) != 'e')
goto SETERROR;
ds->lastType = JT_TRUE;
ds->start = offset;
return ds->dec->newTrue(ds->prv);
SETERROR:
return SetError(ds, -1, "Unexpected character found when decoding 'true'");
}
FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_false ( struct DecoderState *ds)
{
char *offset = ds->start;
offset ++;
if (*(offset++) != 'a')
goto SETERROR;
if (*(offset++) != 'l')
goto SETERROR;
if (*(offset++) != 's')
goto SETERROR;
if (*(offset++) != 'e')
goto SETERROR;
ds->lastType = JT_FALSE;
ds->start = offset;
return ds->dec->newFalse(ds->prv);
SETERROR:
return SetError(ds, -1, "Unexpected character found when decoding 'false'");
}
FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_null ( struct DecoderState *ds)
{
char *offset = ds->start;
offset ++;
if (*(offset++) != 'u')
goto SETERROR;
if (*(offset++) != 'l')
goto SETERROR;
if (*(offset++) != 'l')
goto SETERROR;
ds->lastType = JT_NULL;
ds->start = offset;
return ds->dec->newNull(ds->prv);
SETERROR:
return SetError(ds, -1, "Unexpected character found when decoding 'null'");
}
FASTCALL_ATTR void FASTCALL_MSVC SkipWhitespace(struct DecoderState *ds)
{
char *offset = ds->start;
for (;;)
{
switch (*offset)
{
case ' ':
case '\t':
case '\r':
case '\n':
offset ++;
break;
default:
ds->start = offset;
return;
}
}
}
enum DECODESTRINGSTATE
{
DS_ISNULL = 0x32,
DS_ISQUOTE,
DS_ISESCAPE,
DS_UTFLENERROR,
};
static const JSUINT8 g_decoderLookup[256] =
{
/* 0x00 */ DS_ISNULL, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x10 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x20 */ 1, 1, DS_ISQUOTE, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x30 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x40 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x50 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, DS_ISESCAPE, 1, 1, 1,
/* 0x60 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x70 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x80 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x90 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0xa0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0xb0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0xc0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
/* 0xd0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
/* 0xe0 */ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
/* 0xf0 */ 4, 4, 4, 4, 4, 4, 4, 4, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR, DS_UTFLENERROR,
};
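/*
Each entry is either the byte length of the UTF-8 sequence introduced by that lead byte
(1-4) or one of the DECODESTRINGSTATE markers above: DS_ISNULL for the terminating NUL,
DS_ISQUOTE for '"', DS_ISESCAPE for '\\' and DS_UTFLENERROR for the 0xf8-0xff lead bytes
that would start an invalid 5 or 6 byte sequence.
*/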
FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_string ( struct DecoderState *ds)
{
int index;
JSUINT32 *escOffset;
JSUINT32 *escStart;
size_t escLen = (ds->escEnd - ds->escStart);
JSUINT8 *inputOffset;
JSUTF16 ch = 0;
JSUINT8 *lastHighSurrogate = NULL;
JSUINT8 oct;
JSUTF32 ucs;
ds->lastType = JT_INVALID;
ds->start ++;
if ( (size_t) (ds->end - ds->start) > escLen)
{
size_t newSize = (ds->end - ds->start);
if (ds->escHeap)
{
if (newSize > (SIZE_MAX / sizeof(JSUINT32)))
{
return SetError(ds, -1, "Could not reserve memory block");
}
escStart = (JSUINT32 *)ds->dec->realloc(ds->escStart, newSize * sizeof(JSUINT32));
if (!escStart)
{
ds->dec->free(ds->escStart);
return SetError(ds, -1, "Could not reserve memory block");
}
ds->escStart = escStart;
}
else
{
JSUINT32 *oldStart = ds->escStart;
if (newSize > (SIZE_MAX / sizeof(JSUINT32)))
{
return SetError(ds, -1, "Could not reserve memory block");
}
ds->escStart = (JSUINT32 *) ds->dec->malloc(newSize * sizeof(JSUINT32));
if (!ds->escStart)
{
return SetError(ds, -1, "Could not reserve memory block");
}
ds->escHeap = 1;
memcpy(ds->escStart, oldStart, escLen * sizeof(JSUINT32));
}
ds->escEnd = ds->escStart + newSize;
}
escOffset = ds->escStart;
inputOffset = (JSUINT8 *) ds->start;
for (;;)
{
switch (g_decoderLookup[(JSUINT8)(*inputOffset)])
{
case DS_ISNULL:
{
return SetError(ds, -1, "Unmatched ''\"' when when decoding 'string'");
}
case DS_ISQUOTE:
{
ds->lastType = JT_UTF8;
inputOffset ++;
ds->start += ( (char *) inputOffset - (ds->start));
return ds->dec->newString(ds->prv, ds->escStart, escOffset);
}
case DS_UTFLENERROR:
{
return SetError (ds, -1, "Invalid UTF-8 sequence length when decoding 'string'");
}
case DS_ISESCAPE:
inputOffset ++;
switch (*inputOffset)
{
case '\\': *(escOffset++) = '\\'; inputOffset++; continue;
case '\"': *(escOffset++) = '\"'; inputOffset++; continue;
case '/': *(escOffset++) = '/'; inputOffset++; continue;
case 'b': *(escOffset++) = '\b'; inputOffset++; continue;
case 'f': *(escOffset++) = '\f'; inputOffset++; continue;
case 'n': *(escOffset++) = '\n'; inputOffset++; continue;
case 'r': *(escOffset++) = '\r'; inputOffset++; continue;
case 't': *(escOffset++) = '\t'; inputOffset++; continue;
case 'u':
{
int index;
inputOffset ++;
for (index = 0; index < 4; index ++)
{
switch (*inputOffset)
{
case '\0': return SetError (ds, -1, "Unterminated unicode escape sequence when decoding 'string'");
default: return SetError (ds, -1, "Unexpected character in unicode escape sequence when decoding 'string'");
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
ch = (ch << 4) + (JSUTF16) (*inputOffset - '0');
break;
case 'a':
case 'b':
case 'c':
case 'd':
case 'e':
case 'f':
ch = (ch << 4) + 10 + (JSUTF16) (*inputOffset - 'a');
break;
case 'A':
case 'B':
case 'C':
case 'D':
case 'E':
case 'F':
ch = (ch << 4) + 10 + (JSUTF16) (*inputOffset - 'A');
break;
}
inputOffset ++;
}
if ((ch & 0xfc00) == 0xdc00 && lastHighSurrogate == inputOffset - 6 * sizeof(*inputOffset))
{
// Low surrogate immediately following a high surrogate
// Overwrite existing high surrogate with combined character
*(escOffset-1) = (((*(escOffset-1) - 0xd800) <<10) | (ch - 0xdc00)) + 0x10000;
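// Example: "\uD83D\uDCA9" combines to ((0xD83D - 0xD800) << 10 | (0xDCA9 - 0xDC00)) + 0x10000 = 0x1F4A9.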
}
else
{
*(escOffset++) = (JSUINT32) ch;
}
if ((ch & 0xfc00) == 0xd800)
{
lastHighSurrogate = inputOffset;
}
break;
}
case '\0': return SetError(ds, -1, "Unterminated escape sequence when decoding 'string'");
default: return SetError(ds, -1, "Unrecognized escape sequence when decoding 'string'");
}
break;
case 1:
{
*(escOffset++) = (JSUINT32) (*inputOffset++);
break;
}
case 2:
{
ucs = (*inputOffset++) & 0x1f;
ucs <<= 6;
if (((*inputOffset) & 0x80) != 0x80)
{
return SetError(ds, -1, "Invalid octet in UTF-8 sequence when decoding 'string'");
}
ucs |= (*inputOffset++) & 0x3f;
if (ucs < 0x80) return SetError (ds, -1, "Overlong 2 byte UTF-8 sequence detected when decoding 'string'");
*(escOffset++) = (JSUINT32) ucs;
break;
}
case 3:
{
JSUTF32 ucs = 0;
ucs |= (*inputOffset++) & 0x0f;
for (index = 0; index < 2; index ++)
{
ucs <<= 6;
oct = (*inputOffset++);
if ((oct & 0x80) != 0x80)
{
return SetError(ds, -1, "Invalid octet in UTF-8 sequence when decoding 'string'");
}
ucs |= oct & 0x3f;
}
if (ucs < 0x800) return SetError (ds, -1, "Overlong 3 byte UTF-8 sequence detected when encoding string");
*(escOffset++) = (JSUINT32) ucs;
break;
}
case 4:
{
JSUTF32 ucs = 0;
ucs |= (*inputOffset++) & 0x07;
for (index = 0; index < 3; index ++)
{
ucs <<= 6;
oct = (*inputOffset++);
if ((oct & 0x80) != 0x80)
{
return SetError(ds, -1, "Invalid octet in UTF-8 sequence when decoding 'string'");
}
ucs |= oct & 0x3f;
}
if (ucs < 0x10000) return SetError (ds, -1, "Overlong 4 byte UTF-8 sequence detected when decoding 'string'");
*(escOffset++) = (JSUINT32) ucs;
break;
}
}
}
}
FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_array(struct DecoderState *ds)
{
JSOBJ itemValue;
JSOBJ newObj;
int len;
ds->objDepth++;
if (ds->objDepth > JSON_MAX_OBJECT_DEPTH) {
return SetError(ds, -1, "Reached object decoding depth limit");
}
newObj = ds->dec->newArray(ds->prv);
len = 0;
ds->lastType = JT_INVALID;
ds->start ++;
for (;;)
{
SkipWhitespace(ds);
if ((*ds->start) == ']')
{
ds->objDepth--;
if (len == 0)
{
ds->start ++;
return newObj;
}
ds->dec->releaseObject(ds->prv, newObj);
return SetError(ds, -1, "Unexpected character found when decoding array value (1)");
}
itemValue = decode_any(ds);
if (itemValue == NULL)
{
ds->dec->releaseObject(ds->prv, newObj);
return NULL;
}
ds->dec->arrayAddItem (ds->prv, newObj, itemValue);
SkipWhitespace(ds);
switch (*(ds->start++))
{
case ']':
{
ds->objDepth--;
return newObj;
}
case ',':
break;
default:
ds->dec->releaseObject(ds->prv, newObj);
return SetError(ds, -1, "Unexpected character found when decoding array value (2)");
}
len ++;
}
}
FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_object( struct DecoderState *ds)
{
JSOBJ itemName;
JSOBJ itemValue;
JSOBJ newObj;
int len;
ds->objDepth++;
if (ds->objDepth > JSON_MAX_OBJECT_DEPTH) {
return SetError(ds, -1, "Reached object decoding depth limit");
}
newObj = ds->dec->newObject(ds->prv);
len = 0;
ds->start ++;
for (;;)
{
SkipWhitespace(ds);
if ((*ds->start) == '}')
{
ds->objDepth--;
if (len == 0)
{
ds->start ++;
return newObj;
}
ds->dec->releaseObject(ds->prv, newObj);
return SetError(ds, -1, "Unexpected character in found when decoding object value");
}
ds->lastType = JT_INVALID;
itemName = decode_any(ds);
if (itemName == NULL)
{
ds->dec->releaseObject(ds->prv, newObj);
return NULL;
}
if (ds->lastType != JT_UTF8)
{
ds->dec->releaseObject(ds->prv, newObj);
ds->dec->releaseObject(ds->prv, itemName);
return SetError(ds, -1, "Key name of object must be 'string' when decoding 'object'");
}
SkipWhitespace(ds);
if (*(ds->start++) != ':')
{
ds->dec->releaseObject(ds->prv, newObj);
ds->dec->releaseObject(ds->prv, itemName);
return SetError(ds, -1, "No ':' found when decoding object value");
}
SkipWhitespace(ds);
itemValue = decode_any(ds);
if (itemValue == NULL)
{
ds->dec->releaseObject(ds->prv, newObj);
ds->dec->releaseObject(ds->prv, itemName);
return NULL;
}
ds->dec->objectAddKey (ds->prv, newObj, itemName, itemValue);
SkipWhitespace(ds);
switch (*(ds->start++))
{
case '}':
{
ds->objDepth--;
return newObj;
}
case ',':
break;
default:
ds->dec->releaseObject(ds->prv, newObj);
return SetError(ds, -1, "Unexpected character in found when decoding object value");
}
len++;
}
}
FASTCALL_ATTR JSOBJ FASTCALL_MSVC decode_any(struct DecoderState *ds)
{
for (;;)
{
switch (*ds->start)
{
case '\"':
return decode_string (ds);
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
case '-':
return decode_numeric (ds);
case '[': return decode_array (ds);
case '{': return decode_object (ds);
case 't': return decode_true (ds);
case 'f': return decode_false (ds);
case 'n': return decode_null (ds);
case ' ':
case '\t':
case '\r':
case '\n':
// White space
ds->start ++;
break;
default:
return SetError(ds, -1, "Expected object or value");
}
}
}
JSOBJ JSON_DecodeObject(JSONObjectDecoder *dec, const char *buffer, size_t cbBuffer)
{
/*
FIXME: Base the size of escBuffer on that of cbBuffer so that the unicode escaping doesn't run into the wall each time */
struct DecoderState ds;
JSUINT32 escBuffer[(JSON_MAX_STACK_BUFFER_SIZE / sizeof(JSUINT32))];
JSOBJ ret;
ds.start = (char *) buffer;
ds.end = ds.start + cbBuffer;
ds.escStart = escBuffer;
ds.escEnd = ds.escStart + (JSON_MAX_STACK_BUFFER_SIZE / sizeof(JSUINT32));
ds.escHeap = 0;
ds.prv = dec->prv;
ds.dec = dec;
ds.dec->errorStr = NULL;
ds.dec->errorOffset = NULL;
ds.objDepth = 0;
ds.dec = dec;
ret = decode_any (&ds);
if (ds.escHeap)
{
dec->free(ds.escStart);
}
if (!(dec->errorStr))
{
if ((ds.end - ds.start) > 0)
{
SkipWhitespace(&ds);
}
if (ds.start != ds.end && ret)
{
dec->releaseObject(ds.prv, ret);
return SetError(&ds, -1, "Trailing data");
}
}
return ret;
}
srsly-release-v2.5.1/srsly/ujson/lib/ultrajsonenc.c 0000664 0000000 0000000 00000063335 14742310675 0022455 0 ustar 00root root 0000000 0000000 /*
Developed by ESN, an Electronic Arts Inc. studio.
Copyright (c) 2014, Electronic Arts Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of ESN, Electronic Arts Inc. nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
http://code.google.com/p/stringencoders/
Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
Numeric decoder derived from from TCL library
http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
* Copyright (c) 1988-1993 The Regents of the University of California.
* Copyright (c) 1994 Sun Microsystems, Inc.
*/
#include "ultrajson.h"
#include <stdio.h>
#include <assert.h>
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <stdint.h>
#include <float.h>
#ifndef TRUE
#define TRUE 1
#endif
#ifndef FALSE
#define FALSE 0
#endif
#if ( (defined(_WIN32) || defined(WIN32) ) && ( defined(_MSC_VER) ) )
#define snprintf sprintf_s
#endif
/*
Worst cases being:
Control characters (ASCII < 32)
0x00 (1 byte) input => \u0000 output (6 bytes)
1 * 6 => 6 (6 bytes required)
or UTF-16 surrogate pairs
4 bytes input in UTF-8 => \uXXXX\uYYYY (12 bytes).
4 * 6 => 24 bytes (12 bytes required)
The extra 2 bytes are for the quotes around the string
*/
#define RESERVE_STRING(_len) (2 + ((_len) * 6))
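/* Example: a 3 byte input consisting only of control characters expands to at most
   3 * 6 = 18 bytes of \u00XX escapes plus 2 bytes for the surrounding quotes,
   i.e. RESERVE_STRING(3) == 20. */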
static const double g_pow10[] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000, 10000000000, 100000000000, 1000000000000, 10000000000000, 100000000000000, 1000000000000000};
static const char g_hexChars[] = "0123456789abcdef";
static const char g_escapeChars[] = "0123456789\\b\\t\\n\\f\\r\\\"\\\\\\/";
/*
FIXME: While this is fine and dandy and working, it's a magic value mess which probably only the author understands.
Needs a cleanup and more documentation */
/*
Table for pure ascii output escaping all characters above 127 to \uXXXX */
static const JSUINT8 g_asciiOutputTable[256] =
{
/* 0x00 */ 0, 30, 30, 30, 30, 30, 30, 30, 10, 12, 14, 30, 16, 18, 30, 30,
/* 0x10 */ 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
/* 0x20 */ 1, 1, 20, 1, 1, 1, 29, 1, 1, 1, 1, 1, 1, 1, 1, 24,
/* 0x30 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 29, 1, 29, 1,
/* 0x40 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x50 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 22, 1, 1, 1,
/* 0x60 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x70 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x80 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0x90 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0xa0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0xb0 */ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
/* 0xc0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
/* 0xd0 */ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
/* 0xe0 */ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
/* 0xf0 */ 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 1, 1
};
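/*
Key to the values above as consumed by Buffer_EscapeStringValidated():
  0     terminating NUL
  1     copy the byte through unchanged
  2-4   lead byte of a UTF-8 sequence of that many bytes
  5, 6  invalid 5/6 byte UTF-8 lead bytes, reported as an error
  29    '&', '<' and '>', escaped as \u00XX only when encodeHTMLChars is set
  30    control characters that are always escaped as \u00XX
  10-24 even offsets into g_escapeChars selecting the two character escapes
        \b, \t, \n, \f, \r, \", \\ and (when escapeForwardSlashes is set) \/
*/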
static void SetError (JSOBJ obj, JSONObjectEncoder *enc, const char *message)
{
enc->errorMsg = message;
enc->errorObj = obj;
}
/*
FIXME: Keep track of how big these get across several encoder calls and try to make an estimate
That way we won't run our head into the wall each call */
void Buffer_Realloc (JSONObjectEncoder *enc, size_t cbNeeded)
{
size_t free_space = enc->end - enc->offset;
if (free_space >= cbNeeded)
{
return;
}
size_t curSize = enc->end - enc->start;
size_t newSize = curSize;
size_t offset = enc->offset - enc->start;
#ifdef DEBUG
// In debug mode, allocate only what is requested so that any miscalculation
// shows up plainly as a crash.
newSize = (enc->offset - enc->start) + cbNeeded;
#else
while (newSize < curSize + cbNeeded)
{
newSize *= 2;
}
#endif
if (enc->heap)
{
enc->start = (char *) enc->realloc (enc->start, newSize);
if (!enc->start)
{
SetError (NULL, enc, "Could not reserve memory block");
return;
}
}
else
{
char *oldStart = enc->start;
enc->heap = 1;
enc->start = (char *) enc->malloc (newSize);
if (!enc->start)
{
SetError (NULL, enc, "Could not reserve memory block");
return;
}
memcpy (enc->start, oldStart, offset);
}
enc->offset = enc->start + offset;
enc->end = enc->start + newSize;
}
#define Buffer_Reserve(__enc, __len) \
if ( (size_t) ((__enc)->end - (__enc)->offset) < (size_t) (__len)) \
{ \
Buffer_Realloc((__enc), (__len));\
} \
FASTCALL_ATTR INLINE_PREFIX void FASTCALL_MSVC Buffer_AppendShortHexUnchecked (char *outputOffset, unsigned short value)
{
*(outputOffset++) = g_hexChars[(value & 0xf000) >> 12];
*(outputOffset++) = g_hexChars[(value & 0x0f00) >> 8];
*(outputOffset++) = g_hexChars[(value & 0x00f0) >> 4];
*(outputOffset++) = g_hexChars[(value & 0x000f) >> 0];
}
int Buffer_EscapeStringUnvalidated (JSONObjectEncoder *enc, const char *io, const char *end)
{
char *of = (char *) enc->offset;
for (;;)
{
switch (*io)
{
case 0x00:
{
if (io < end)
{
*(of++) = '\\';
*(of++) = 'u';
*(of++) = '0';
*(of++) = '0';
*(of++) = '0';
*(of++) = '0';
break;
}
else
{
enc->offset += (of - enc->offset);
return TRUE;
}
}
case '\"': (*of++) = '\\'; (*of++) = '\"'; break;
case '\\': (*of++) = '\\'; (*of++) = '\\'; break;
case '\b': (*of++) = '\\'; (*of++) = 'b'; break;
case '\f': (*of++) = '\\'; (*of++) = 'f'; break;
case '\n': (*of++) = '\\'; (*of++) = 'n'; break;
case '\r': (*of++) = '\\'; (*of++) = 'r'; break;
case '\t': (*of++) = '\\'; (*of++) = 't'; break;
case '/':
{
if (enc->escapeForwardSlashes)
{
(*of++) = '\\';
(*of++) = '/';
}
else
{
// Same as default case below.
(*of++) = (*io);
}
break;
}
case 0x26: // '&'
case 0x3c: // '<'
case 0x3e: // '>'
{
if (enc->encodeHTMLChars)
{
// Fall through to \u00XX case below.
}
else
{
// Same as default case below.
(*of++) = (*io);
break;
}
}
case 0x01:
case 0x02:
case 0x03:
case 0x04:
case 0x05:
case 0x06:
case 0x07:
case 0x0b:
case 0x0e:
case 0x0f:
case 0x10:
case 0x11:
case 0x12:
case 0x13:
case 0x14:
case 0x15:
case 0x16:
case 0x17:
case 0x18:
case 0x19:
case 0x1a:
case 0x1b:
case 0x1c:
case 0x1d:
case 0x1e:
case 0x1f:
{
*(of++) = '\\';
*(of++) = 'u';
*(of++) = '0';
*(of++) = '0';
*(of++) = g_hexChars[ (unsigned char) (((*io) & 0xf0) >> 4)];
*(of++) = g_hexChars[ (unsigned char) ((*io) & 0x0f)];
break;
}
default: (*of++) = (*io); break;
}
io++;
}
}
int Buffer_EscapeStringValidated (JSOBJ obj, JSONObjectEncoder *enc, const char *io, const char *end)
{
JSUTF32 ucs;
char *of = (char *) enc->offset;
for (;;)
{
#ifdef DEBUG
// 6 is the maximum length of a single character (cf. RESERVE_STRING).
if ((io < end) && (enc->end - of < 6)) {
fprintf(stderr, "Ran out of buffer space during Buffer_EscapeStringValidated()\n");
abort();
}
#endif
JSUINT8 utflen = g_asciiOutputTable[(unsigned char) *io];
switch (utflen)
{
case 0:
{
if (io < end)
{
*(of++) = '\\';
*(of++) = 'u';
*(of++) = '0';
*(of++) = '0';
*(of++) = '0';
*(of++) = '0';
io ++;
continue;
}
else
{
enc->offset += (of - enc->offset);
return TRUE;
}
}
case 1:
{
*(of++)= (*io++);
continue;
}
case 2:
{
JSUTF32 in;
JSUTF16 in16;
if (end - io < 1)
{
enc->offset += (of - enc->offset);
SetError (obj, enc, "Unterminated UTF-8 sequence when encoding string");
return FALSE;
}
memcpy(&in16, io, sizeof(JSUTF16));
in = (JSUTF32) in16;
#ifdef __LITTLE_ENDIAN__
ucs = ((in & 0x1f) << 6) | ((in >> 8) & 0x3f);
#else
ucs = ((in & 0x1f00) >> 2) | (in & 0x3f);
#endif
if (ucs < 0x80)
{
enc->offset += (of - enc->offset);
SetError (obj, enc, "Overlong 2 byte UTF-8 sequence detected when encoding string");
return FALSE;
}
io += 2;
break;
}
case 3:
{
JSUTF32 in;
JSUTF16 in16;
JSUINT8 in8;
if (end - io < 2)
{
enc->offset += (of - enc->offset);
SetError (obj, enc, "Unterminated UTF-8 sequence when encoding string");
return FALSE;
}
memcpy(&in16, io, sizeof(JSUTF16));
memcpy(&in8, io + 2, sizeof(JSUINT8));
#ifdef __LITTLE_ENDIAN__
in = (JSUTF32) in16;
in |= in8 << 16;
ucs = ((in & 0x0f) << 12) | ((in & 0x3f00) >> 2) | ((in & 0x3f0000) >> 16);
#else
in = in16 << 8;
in |= in8;
ucs = ((in & 0x0f0000) >> 4) | ((in & 0x3f00) >> 2) | (in & 0x3f);
#endif
if (ucs < 0x800)
{
enc->offset += (of - enc->offset);
SetError (obj, enc, "Overlong 3 byte UTF-8 sequence detected when encoding string");
return FALSE;
}
io += 3;
break;
}
case 4:
{
JSUTF32 in;
if (end - io < 3)
{
enc->offset += (of - enc->offset);
SetError (obj, enc, "Unterminated UTF-8 sequence when encoding string");
return FALSE;
}
memcpy(&in, io, sizeof(JSUTF32));
#ifdef __LITTLE_ENDIAN__
ucs = ((in & 0x07) << 18) | ((in & 0x3f00) << 4) | ((in & 0x3f0000) >> 10) | ((in & 0x3f000000) >> 24);
#else
ucs = ((in & 0x07000000) >> 6) | ((in & 0x3f0000) >> 4) | ((in & 0x3f00) >> 2) | (in & 0x3f);
#endif
if (ucs < 0x10000)
{
enc->offset += (of - enc->offset);
SetError (obj, enc, "Overlong 4 byte UTF-8 sequence detected when encoding string");
return FALSE;
}
io += 4;
break;
}
case 5:
case 6:
{
enc->offset += (of - enc->offset);
SetError (obj, enc, "Unsupported UTF-8 sequence length when encoding string");
return FALSE;
}
case 29:
{
if (enc->encodeHTMLChars)
{
// Fall through to \u00XX case 30 below.
}
else
{
// Same as case 1 above.
*(of++) = (*io++);
continue;
}
}
case 30:
{
// \uXXXX encode
*(of++) = '\\';
*(of++) = 'u';
*(of++) = '0';
*(of++) = '0';
*(of++) = g_hexChars[ (unsigned char) (((*io) & 0xf0) >> 4)];
*(of++) = g_hexChars[ (unsigned char) ((*io) & 0x0f)];
io ++;
continue;
}
case 10:
case 12:
case 14:
case 16:
case 18:
case 20:
case 22:
{
*(of++) = *( (char *) (g_escapeChars + utflen + 0));
*(of++) = *( (char *) (g_escapeChars + utflen + 1));
io ++;
continue;
}
case 24:
{
if (enc->escapeForwardSlashes)
{
*(of++) = *( (char *) (g_escapeChars + utflen + 0));
*(of++) = *( (char *) (g_escapeChars + utflen + 1));
io ++;
}
else
{
// Same as case 1 above.
*(of++) = (*io++);
}
continue;
}
// This can never happen, it's here to make L4 VC++ happy
default:
{
ucs = 0;
break;
}
}
/*
If the character is a UTF8 sequence of length > 1 we end up here */
if (ucs >= 0x10000)
{
ucs -= 0x10000;
*(of++) = '\\';
*(of++) = 'u';
Buffer_AppendShortHexUnchecked(of, (unsigned short) (ucs >> 10) + 0xd800);
of += 4;
*(of++) = '\\';
*(of++) = 'u';
Buffer_AppendShortHexUnchecked(of, (unsigned short) (ucs & 0x3ff) + 0xdc00);
of += 4;
}
else
{
*(of++) = '\\';
*(of++) = 'u';
Buffer_AppendShortHexUnchecked(of, (unsigned short) ucs);
of += 4;
}
}
}
static FASTCALL_ATTR INLINE_PREFIX void FASTCALL_MSVC Buffer_AppendCharUnchecked(JSONObjectEncoder *enc, char chr)
{
#ifdef DEBUG
if (enc->end <= enc->offset)
{
fprintf(stderr, "Overflow writing byte %d '%c'. The last few characters were:\n'''", chr, chr);
char * recent = enc->offset - 1000;
if (enc->start > recent)
{
recent = enc->start;
}
for (; recent < enc->offset; recent++)
{
fprintf(stderr, "%c", *recent);
}
fprintf(stderr, "'''\n");
abort();
}
#endif
*(enc->offset++) = chr;
}
FASTCALL_ATTR INLINE_PREFIX void FASTCALL_MSVC strreverse(char* begin, char* end)
{
char aux;
while (end > begin)
aux = *end, *end-- = *begin, *begin++ = aux;
}
void Buffer_AppendIndentNewlineUnchecked(JSONObjectEncoder *enc)
{
if (enc->indent > 0) Buffer_AppendCharUnchecked(enc, '\n');
}
void Buffer_AppendIndentUnchecked(JSONObjectEncoder *enc, JSINT32 value)
{
int i;
if (enc->indent > 0)
while (value-- > 0)
for (i = 0; i < enc->indent; i++)
Buffer_AppendCharUnchecked(enc, ' ');
}
void Buffer_AppendIntUnchecked(JSONObjectEncoder *enc, JSINT32 value)
{
char* wstr;
JSUINT32 uvalue = (value < 0) ? -value : value;
wstr = enc->offset;
// Conversion. Number is reversed.
do *wstr++ = (char)(48 + (uvalue % 10)); while(uvalue /= 10);
if (value < 0) *wstr++ = '-';
// Reverse string
strreverse(enc->offset,wstr - 1);
enc->offset += (wstr - (enc->offset));
}
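/* Example: value = -203 first emits the digits in reverse order ("302"), then the '-'
   sign, and strreverse() turns the buffer contents into "-203". */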
void Buffer_AppendLongUnchecked(JSONObjectEncoder *enc, JSINT64 value)
{
char* wstr;
JSUINT64 uvalue = (value < 0) ? -value : value;
wstr = enc->offset;
// Conversion. Number is reversed.
do *wstr++ = (char)(48 + (uvalue % 10ULL)); while(uvalue /= 10ULL);
if (value < 0) *wstr++ = '-';
// Reverse string
strreverse(enc->offset,wstr - 1);
enc->offset += (wstr - (enc->offset));
}
void Buffer_AppendUnsignedLongUnchecked(JSONObjectEncoder *enc, JSUINT64 value)
{
char* wstr;
JSUINT64 uvalue = value;
wstr = enc->offset;
// Conversion. Number is reversed.
do *wstr++ = (char)(48 + (uvalue % 10ULL)); while(uvalue /= 10ULL);
// Reverse string
strreverse(enc->offset,wstr - 1);
enc->offset += (wstr - (enc->offset));
}
int Buffer_AppendDoubleUnchecked(JSOBJ obj, JSONObjectEncoder *enc, double value)
{
/* if input is larger than thres_max, revert to exponential */
const double thres_max = (double) 1e16 - 1;
int count;
double diff = 0.0;
char* str = enc->offset;
char* wstr = str;
unsigned long long whole;
double tmp;
unsigned long long frac;
int neg;
double pow10;
if (value == HUGE_VAL || value == -HUGE_VAL)
{
SetError (obj, enc, "Invalid Inf value when encoding double");
return FALSE;
}
if (!(value == value))
{
SetError (obj, enc, "Invalid Nan value when encoding double");
return FALSE;
}
/* we'll work in positive values and deal with the
negative sign issue later */
neg = 0;
if (value < 0)
{
neg = 1;
value = -value;
}
pow10 = g_pow10[enc->doublePrecision];
whole = (unsigned long long) value;
tmp = (value - whole) * pow10;
frac = (unsigned long long)(tmp);
diff = tmp - frac;
if (diff > 0.5)
{
++frac;
/* handle rollover, e.g. case 0.99 with prec 1 is 1.0 */
if (frac >= pow10)
{
frac = 0;
++whole;
}
}
else
if (diff == 0.5 && ((frac == 0) || (frac & 1)))
{
/* if halfway, round up if odd, OR
if last digit is 0. That last part is strange */
++frac;
}
/* for very large numbers switch back to native sprintf for exponentials.
anyone want to write code to replace this? */
/*
normal printf behavior is to print EVERY whole number digit
which can be 100s of characters overflowing your buffers == bad
*/
if (value > thres_max)
{
enc->offset += snprintf(str, enc->end - enc->offset, "%.15e", neg ? -value : value);
return TRUE;
}
if (enc->doublePrecision == 0)
{
diff = value - whole;
if (diff > 0.5)
{
/* greater than 0.5, round up, e.g. 1.6 -> 2 */
++whole;
}
else
if (diff == 0.5 && (whole & 1))
{
/* exactly 0.5 and ODD, then round up */
/* 1.5 -> 2, but 2.5 -> 2 */
++whole;
}
//vvvvvvvvvvvvvvvvvvv Diff from modp_dto2
}
else
if (frac)
{
count = enc->doublePrecision;
// now do fractional part, as an unsigned number
// we know it is not 0 but we can have leading zeros, these
// should be removed
while (!(frac % 10))
{
--count;
frac /= 10;
}
//^^^^^^^^^^^^^^^^^^^ Diff from modp_dto2
// now do fractional part, as an unsigned number
do
{
--count;
*wstr++ = (char)(48 + (frac % 10));
} while (frac /= 10);
// add extra 0s
while (count-- > 0)
{
*wstr++ = '0';
}
// add decimal
*wstr++ = '.';
}
else
{
*wstr++ = '0';
*wstr++ = '.';
}
// do whole part
// Take care of sign
// Conversion. Number is reversed.
do *wstr++ = (char)(48 + (whole % 10)); while (whole /= 10);
if (neg)
{
*wstr++ = '-';
}
strreverse(str, wstr-1);
enc->offset += (wstr - (enc->offset));
return TRUE;
}
/*
FIXME:
Handle integration functions returning NULL here */
/*
FIXME:
Perhaps implement recursion detection */
void encode(JSOBJ obj, JSONObjectEncoder *enc, const char *name, size_t cbName)
{
const char *value;
char *objName;
int count;
JSOBJ iterObj;
size_t szlen;
JSONTypeContext tc;
if (enc->level > enc->recursionMax)
{
SetError (obj, enc, "Maximum recursion level reached");
return;
}
if (enc->errorMsg)
{
return;
}
if (name)
{
// 2 extra for the colon and optional space after it
Buffer_Reserve(enc, RESERVE_STRING(cbName) + 2);
Buffer_AppendCharUnchecked(enc, '\"');
if (enc->forceASCII)
{
if (!Buffer_EscapeStringValidated(obj, enc, name, name + cbName))
{
return;
}
}
else
{
if (!Buffer_EscapeStringUnvalidated(enc, name, name + cbName))
{
return;
}
}
Buffer_AppendCharUnchecked(enc, '\"');
Buffer_AppendCharUnchecked (enc, ':');
}
tc.encoder_prv = enc->prv;
enc->beginTypeContext(obj, &tc, enc);
/*
This reservation covers any additions on non-variable parts below, specifically:
- Opening brackets for JT_ARRAY and JT_OBJECT
- Number representation for JT_LONG, JT_ULONG, JT_INT, and JT_DOUBLE
- Constant value for JT_TRUE, JT_FALSE, JT_NULL
The length of 128 is the worst case length of the Buffer_AppendDoubleDconv addition.
The other types above all have smaller representations.
*/
Buffer_Reserve (enc, 128);
switch (tc.type)
{
case JT_INVALID:
{
return;
}
case JT_ARRAY:
{
count = 0;
Buffer_AppendCharUnchecked (enc, '[');
Buffer_AppendIndentNewlineUnchecked (enc);
while (enc->iterNext(obj, &tc))
{
// The extra 2 bytes cover the comma and (optional) newline.
Buffer_Reserve (enc, enc->indent * (enc->level + 1) + 2);
if (count > 0)
{
Buffer_AppendCharUnchecked (enc, ',');
Buffer_AppendIndentNewlineUnchecked (enc);
}
iterObj = enc->iterGetValue(obj, &tc);
enc->level ++;
Buffer_AppendIndentUnchecked (enc, enc->level);
encode (iterObj, enc, NULL, 0);
count ++;
}
enc->iterEnd(obj, &tc);
// Reserve space for the indentation plus the newline.
Buffer_Reserve (enc, enc->indent * enc->level + 1);
Buffer_AppendIndentNewlineUnchecked (enc);
Buffer_AppendIndentUnchecked (enc, enc->level);
Buffer_Reserve (enc, 1);
Buffer_AppendCharUnchecked (enc, ']');
break;
}
case JT_OBJECT:
{
count = 0;
Buffer_AppendCharUnchecked (enc, '{');
Buffer_AppendIndentNewlineUnchecked (enc);
while (enc->iterNext(obj, &tc))
{
// The extra 2 bytes cover the comma and optional newline.
Buffer_Reserve (enc, enc->indent * (enc->level + 1) + 2);
if (count > 0)
{
Buffer_AppendCharUnchecked (enc, ',');
Buffer_AppendIndentNewlineUnchecked (enc);
}
iterObj = enc->iterGetValue(obj, &tc);
objName = enc->iterGetName(obj, &tc, &szlen);
enc->level ++;
Buffer_AppendIndentUnchecked (enc, enc->level);
encode (iterObj, enc, objName, szlen);
count ++;
}
enc->iterEnd(obj, &tc);
Buffer_Reserve (enc, enc->indent * enc->level + 1);
Buffer_AppendIndentNewlineUnchecked (enc);
Buffer_AppendIndentUnchecked (enc, enc->level);
Buffer_Reserve (enc, 1);
Buffer_AppendCharUnchecked (enc, '}');
break;
}
case JT_LONG:
{
Buffer_AppendLongUnchecked (enc, enc->getLongValue(obj, &tc));
break;
}
case JT_ULONG:
{
Buffer_AppendUnsignedLongUnchecked (enc, enc->getUnsignedLongValue(obj, &tc));
break;
}
case JT_INT:
{
Buffer_AppendIntUnchecked (enc, enc->getIntValue(obj, &tc));
break;
}
case JT_TRUE:
{
Buffer_AppendCharUnchecked (enc, 't');
Buffer_AppendCharUnchecked (enc, 'r');
Buffer_AppendCharUnchecked (enc, 'u');
Buffer_AppendCharUnchecked (enc, 'e');
break;
}
case JT_FALSE:
{
Buffer_AppendCharUnchecked (enc, 'f');
Buffer_AppendCharUnchecked (enc, 'a');
Buffer_AppendCharUnchecked (enc, 'l');
Buffer_AppendCharUnchecked (enc, 's');
Buffer_AppendCharUnchecked (enc, 'e');
break;
}
case JT_NULL:
{
Buffer_AppendCharUnchecked (enc, 'n');
Buffer_AppendCharUnchecked (enc, 'u');
Buffer_AppendCharUnchecked (enc, 'l');
Buffer_AppendCharUnchecked (enc, 'l');
break;
}
case JT_DOUBLE:
{
if (!Buffer_AppendDoubleUnchecked (obj, enc, enc->getDoubleValue(obj, &tc)))
{
enc->endTypeContext(obj, &tc);
enc->level --;
return;
}
break;
}
case JT_UTF8:
{
value = enc->getStringValue(obj, &tc, &szlen);
if(!value)
{
SetError(obj, enc, "utf-8 encoding error");
return;
}
Buffer_Reserve(enc, RESERVE_STRING(szlen));
if (enc->errorMsg)
{
enc->endTypeContext(obj, &tc);
return;
}
Buffer_AppendCharUnchecked (enc, '\"');
if (enc->forceASCII)
{
if (!Buffer_EscapeStringValidated(obj, enc, value, value + szlen))
{
enc->endTypeContext(obj, &tc);
enc->level --;
return;
}
}
else
{
if (!Buffer_EscapeStringUnvalidated(enc, value, value + szlen))
{
enc->endTypeContext(obj, &tc);
enc->level --;
return;
}
}
Buffer_AppendCharUnchecked (enc, '\"');
break;
}
case JT_RAW:
{
value = enc->getStringValue(obj, &tc, &szlen);
if(!value)
{
SetError(obj, enc, "utf-8 encoding error");
return;
}
Buffer_Reserve(enc, szlen);
if (enc->errorMsg)
{
enc->endTypeContext(obj, &tc);
return;
}
memcpy(enc->offset, value, szlen);
enc->offset += szlen;
break;
}
}
enc->endTypeContext(obj, &tc);
enc->level --;
}
char *JSON_EncodeObject(JSOBJ obj, JSONObjectEncoder *enc, char *_buffer, size_t _cbBuffer)
{
enc->malloc = enc->malloc ? enc->malloc : malloc;
enc->free = enc->free ? enc->free : free;
enc->realloc = enc->realloc ? enc->realloc : realloc;
enc->errorMsg = NULL;
enc->errorObj = NULL;
enc->level = 0;
if (enc->recursionMax < 1)
{
enc->recursionMax = JSON_MAX_RECURSION_DEPTH;
}
if (enc->doublePrecision < 0 ||
enc->doublePrecision > JSON_DOUBLE_MAX_DECIMALS)
{
enc->doublePrecision = JSON_DOUBLE_MAX_DECIMALS;
}
if (_buffer == NULL)
{
_cbBuffer = 32768;
enc->start = (char *) enc->malloc (_cbBuffer);
if (!enc->start)
{
SetError(obj, enc, "Could not reserve memory block");
return NULL;
}
enc->heap = 1;
}
else
{
enc->start = _buffer;
enc->heap = 0;
}
enc->end = enc->start + _cbBuffer;
enc->offset = enc->start;
encode (obj, enc, NULL, 0);
Buffer_Reserve(enc, 1);
if (enc->errorMsg)
{
return NULL;
}
Buffer_AppendCharUnchecked(enc, '\0');
return enc->start;
}
srsly-release-v2.5.1/srsly/ujson/objToJSON.c 0000664 0000000 0000000 00000054661 14742310675 0020751 0 ustar 00root root 0000000 0000000 /*
Developed by ESN, an Electronic Arts Inc. studio.
Copyright (c) 2014, Electronic Arts Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of ESN, Electronic Arts Inc. nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
http://code.google.com/p/stringencoders/
Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
Numeric decoder derived from from TCL library
http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
* Copyright (c) 1988-1993 The Regents of the University of California.
* Copyright (c) 1994 Sun Microsystems, Inc.
*/
#include "py_defines.h"
#include <stdio.h>
#include <datetime.h>
#include <ultrajson.h>
#define EPOCH_ORD 719163
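/* 719163 == datetime.date(1970, 1, 1).toordinal(): the proleptic Gregorian
   ordinal of the Unix epoch, used below to turn dates and datetimes into
   seconds since the epoch. */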
static PyObject* type_decimal = NULL;
typedef void *(*PFN_PyTypeToJSON)(JSOBJ obj, JSONTypeContext *ti, void *outValue, size_t *_outLen);
#if (PY_VERSION_HEX < 0x02050000)
typedef ssize_t Py_ssize_t;
#endif
typedef struct __TypeContext
{
JSPFN_ITEREND iterEnd;
JSPFN_ITERNEXT iterNext;
JSPFN_ITERGETNAME iterGetName;
JSPFN_ITERGETVALUE iterGetValue;
PFN_PyTypeToJSON PyTypeToJSON;
PyObject *newObj;
PyObject *dictObj;
Py_ssize_t index;
Py_ssize_t size;
PyObject *itemValue;
PyObject *itemName;
PyObject *attrList;
PyObject *iterator;
union
{
PyObject *rawJSONValue;
JSINT64 longValue;
JSUINT64 unsignedLongValue;
};
} TypeContext;
#define GET_TC(__ptrtc) ((TypeContext *)((__ptrtc)->prv))
struct PyDictIterState
{
PyObject *keys;
size_t i;
size_t sz;
};
//#define PRINTMARK() fprintf(stderr, "%s: MARK(%d)\n", __FILE__, __LINE__)
#define PRINTMARK()
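// Called once at module import (see the module init in ujson.c): caches the
// decimal.Decimal type so Decimals can share the JT_DOUBLE path with floats,
// and imports the datetime C API.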
void initObjToJSON(void)
{
PyObject* mod_decimal = PyImport_ImportModule("decimal");
if (mod_decimal)
{
type_decimal = PyObject_GetAttrString(mod_decimal, "Decimal");
Py_INCREF(type_decimal);
Py_DECREF(mod_decimal);
}
else
PyErr_Clear();
PyDateTime_IMPORT;
}
#ifdef _LP64
static void *PyIntToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
PyObject *obj = (PyObject *) _obj;
*((JSINT64 *) outValue) = PyInt_AS_LONG (obj);
return NULL;
}
#else
static void *PyIntToINT32(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
PyObject *obj = (PyObject *) _obj;
*((JSINT32 *) outValue) = PyInt_AS_LONG (obj);
return NULL;
}
#endif
static void *PyLongToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
*((JSINT64 *) outValue) = GET_TC(tc)->longValue;
return NULL;
}
static void *PyLongToUINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
*((JSUINT64 *) outValue) = GET_TC(tc)->unsignedLongValue;
return NULL;
}
static void *PyFloatToDOUBLE(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
PyObject *obj = (PyObject *) _obj;
*((double *) outValue) = PyFloat_AsDouble (obj);
return NULL;
}
static void *PyStringToUTF8(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
PyObject *obj = (PyObject *) _obj;
*_outLen = PyString_GET_SIZE(obj);
return PyString_AS_STRING(obj);
}
static void *PyUnicodeToUTF8(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
PyObject *obj = (PyObject *) _obj;
PyObject *newObj;
#if (PY_VERSION_HEX >= 0x03030000)
if(PyUnicode_IS_COMPACT_ASCII(obj))
{
Py_ssize_t len;
char *data = PyUnicode_AsUTF8AndSize(obj, &len);
*_outLen = len;
return data;
}
#endif
newObj = PyUnicode_AsUTF8String(obj);
if(!newObj)
{
return NULL;
}
GET_TC(tc)->newObj = newObj;
*_outLen = PyString_GET_SIZE(newObj);
return PyString_AS_STRING(newObj);
}
static void *PyRawJSONToUTF8(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
PyObject *obj = GET_TC(tc)->rawJSONValue;
if (PyUnicode_Check(obj)) {
return PyUnicodeToUTF8(obj, tc, outValue, _outLen);
}
else {
return PyStringToUTF8(obj, tc, outValue, _outLen);
}
}
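// datetime and date objects are serialized as integer seconds since the Unix
// epoch: the ordinal-day arithmetic below relies on EPOCH_ORD defined above,
// and aware datetimes are first shifted to UTC by subtracting utcoffset().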
static void *PyDateTimeToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
PyObject *obj = (PyObject *) _obj;
PyObject *date, *ord, *utcoffset;
int y, m, d, h, mn, s, days;
utcoffset = PyObject_CallMethod(obj, "utcoffset", NULL);
if(utcoffset != Py_None){
obj = PyNumber_Subtract(obj, utcoffset);
}
y = PyDateTime_GET_YEAR(obj);
m = PyDateTime_GET_MONTH(obj);
d = PyDateTime_GET_DAY(obj);
h = PyDateTime_DATE_GET_HOUR(obj);
mn = PyDateTime_DATE_GET_MINUTE(obj);
s = PyDateTime_DATE_GET_SECOND(obj);
date = PyDate_FromDate(y, m, 1);
ord = PyObject_CallMethod(date, "toordinal", NULL);
days = PyInt_AS_LONG(ord) - EPOCH_ORD + d - 1;
Py_DECREF(date);
Py_DECREF(ord);
*( (JSINT64 *) outValue) = (((JSINT64) ((days * 24 + h) * 60 + mn)) * 60 + s);
return NULL;
}
static void *PyDateToINT64(JSOBJ _obj, JSONTypeContext *tc, void *outValue, size_t *_outLen)
{
PyObject *obj = (PyObject *) _obj;
PyObject *date, *ord;
int y, m, d, days;
y = PyDateTime_GET_YEAR(obj);
m = PyDateTime_GET_MONTH(obj);
d = PyDateTime_GET_DAY(obj);
date = PyDate_FromDate(y, m, 1);
ord = PyObject_CallMethod(date, "toordinal", NULL);
days = PyInt_AS_LONG(ord) - EPOCH_ORD + d - 1;
Py_DECREF(date);
Py_DECREF(ord);
*( (JSINT64 *) outValue) = ((JSINT64) days * 86400);
return NULL;
}
int Tuple_iterNext(JSOBJ obj, JSONTypeContext *tc)
{
PyObject *item;
if (GET_TC(tc)->index >= GET_TC(tc)->size)
{
return 0;
}
item = PyTuple_GET_ITEM (obj, GET_TC(tc)->index);
GET_TC(tc)->itemValue = item;
GET_TC(tc)->index ++;
return 1;
}
void Tuple_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{
}
JSOBJ Tuple_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
{
return GET_TC(tc)->itemValue;
}
char *Tuple_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
{
return NULL;
}
int List_iterNext(JSOBJ obj, JSONTypeContext *tc)
{
if (GET_TC(tc)->index >= GET_TC(tc)->size)
{
PRINTMARK();
return 0;
}
GET_TC(tc)->itemValue = PyList_GET_ITEM (obj, GET_TC(tc)->index);
GET_TC(tc)->index ++;
return 1;
}
void List_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{
}
JSOBJ List_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
{
return GET_TC(tc)->itemValue;
}
char *List_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
{
return NULL;
}
//=============================================================================
// Dict iteration functions
// itemName might be converted to a string (via PyObject_Str), so it needs refcounting.
// itemValue is borrowed from the dict object, so no refcounting is needed.
//=============================================================================
int Dict_iterNext(JSOBJ obj, JSONTypeContext *tc)
{
PyObject* itemNameTmp;
if (GET_TC(tc)->itemName)
{
Py_DECREF(GET_TC(tc)->itemName);
GET_TC(tc)->itemName = NULL;
}
if (!(GET_TC(tc)->itemName = PyIter_Next(GET_TC(tc)->iterator)))
{
PRINTMARK();
return 0;
}
if (GET_TC(tc)->itemValue) {
Py_DECREF(GET_TC(tc)->itemValue);
GET_TC(tc)->itemValue = NULL;
}
if (!(GET_TC(tc)->itemValue = PyObject_GetItem(GET_TC(tc)->dictObj, GET_TC(tc)->itemName))) {
PRINTMARK();
return 0;
}
if (PyUnicode_Check(GET_TC(tc)->itemName))
{
itemNameTmp = GET_TC(tc)->itemName;
GET_TC(tc)->itemName = PyUnicode_AsUTF8String (itemNameTmp);
Py_DECREF(itemNameTmp);
}
else
if (!PyString_Check(GET_TC(tc)->itemName))
{
itemNameTmp = GET_TC(tc)->itemName;
GET_TC(tc)->itemName = PyObject_Str(itemNameTmp);
Py_DECREF(itemNameTmp);
#if PY_MAJOR_VERSION >= 3
itemNameTmp = GET_TC(tc)->itemName;
GET_TC(tc)->itemName = PyUnicode_AsUTF8String (itemNameTmp);
Py_DECREF(itemNameTmp);
#endif
}
PRINTMARK();
return 1;
}
void Dict_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{
if (GET_TC(tc)->itemName) {
Py_DECREF(GET_TC(tc)->itemName);
GET_TC(tc)->itemName = NULL;
}
if (GET_TC(tc)->itemValue) {
Py_DECREF(GET_TC(tc)->itemValue);
GET_TC(tc)->itemValue = NULL;
}
Py_CLEAR(GET_TC(tc)->iterator);
Py_DECREF(GET_TC(tc)->dictObj);
PRINTMARK();
}
JSOBJ Dict_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
{
return GET_TC(tc)->itemValue;
}
char *Dict_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
{
*outLen = PyString_GET_SIZE(GET_TC(tc)->itemName);
return PyString_AS_STRING(GET_TC(tc)->itemName);
}
int SortedDict_iterNext(JSOBJ obj, JSONTypeContext *tc)
{
PyObject *items = NULL, *item = NULL, *key = NULL, *value = NULL;
Py_ssize_t i, nitems;
#if PY_MAJOR_VERSION >= 3
PyObject* keyTmp;
#endif
// Upon first call, obtain a list of the keys and sort them. This follows the same logic as the
// standard library's _json.c sort_keys handler.
if (GET_TC(tc)->newObj == NULL)
{
// Obtain the list of keys from the dictionary.
items = PyMapping_Keys(GET_TC(tc)->dictObj);
if (items == NULL)
{
goto error;
}
else if (!PyList_Check(items))
{
PyErr_SetString(PyExc_ValueError, "keys must return list");
goto error;
}
// Sort the list.
if (PyList_Sort(items) < 0)
{
goto error;
}
// Obtain the value for each key, and pack a list of (key, value) 2-tuples.
nitems = PyList_GET_SIZE(items);
for (i = 0; i < nitems; i++)
{
key = PyList_GET_ITEM(items, i);
value = PyDict_GetItem(GET_TC(tc)->dictObj, key);
// Subject the key to the same type restrictions and conversions as in Dict_iterNext.
if (PyUnicode_Check(key))
{
key = PyUnicode_AsUTF8String(key);
}
else if (!PyString_Check(key))
{
key = PyObject_Str(key);
#if PY_MAJOR_VERSION >= 3
keyTmp = key;
key = PyUnicode_AsUTF8String(key);
Py_DECREF(keyTmp);
#endif
}
else
{
Py_INCREF(key);
}
item = PyTuple_Pack(2, key, value);
if (item == NULL)
{
goto error;
}
if (PyList_SetItem(items, i, item))
{
goto error;
}
Py_DECREF(key);
}
// Store the sorted list of tuples in the newObj slot.
GET_TC(tc)->newObj = items;
GET_TC(tc)->size = nitems;
}
if (GET_TC(tc)->index >= GET_TC(tc)->size)
{
PRINTMARK();
return 0;
}
item = PyList_GET_ITEM(GET_TC(tc)->newObj, GET_TC(tc)->index);
GET_TC(tc)->itemName = PyTuple_GET_ITEM(item, 0);
GET_TC(tc)->itemValue = PyTuple_GET_ITEM(item, 1);
GET_TC(tc)->index++;
return 1;
error:
Py_XDECREF(item);
Py_XDECREF(key);
Py_XDECREF(value);
Py_XDECREF(items);
return -1;
}
void SortedDict_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{
GET_TC(tc)->itemName = NULL;
GET_TC(tc)->itemValue = NULL;
Py_DECREF(GET_TC(tc)->newObj);
Py_DECREF(GET_TC(tc)->dictObj);
PRINTMARK();
}
JSOBJ SortedDict_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
{
return GET_TC(tc)->itemValue;
}
char *SortedDict_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
{
*outLen = PyString_GET_SIZE(GET_TC(tc)->itemName);
return PyString_AS_STRING(GET_TC(tc)->itemName);
}
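// Chooses between the two dict iteration strategies above: sorted (key, value)
// tuples when the encoder was configured with sort_keys, plain PyObject_GetIter
// order otherwise.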
void SetupDictIter(PyObject *dictObj, TypeContext *pc, JSONObjectEncoder *enc)
{
pc->dictObj = dictObj;
if (enc->sortKeys)
{
pc->iterEnd = SortedDict_iterEnd;
pc->iterNext = SortedDict_iterNext;
pc->iterGetValue = SortedDict_iterGetValue;
pc->iterGetName = SortedDict_iterGetName;
pc->index = 0;
}
else
{
pc->iterEnd = Dict_iterEnd;
pc->iterNext = Dict_iterNext;
pc->iterGetValue = Dict_iterGetValue;
pc->iterGetName = Dict_iterGetName;
pc->iterator = PyObject_GetIter(dictObj);
}
}
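// Maps a Python object onto one of the JT_* type tags understood by the generic
// encode() routine, and wires the matching iteration/value callbacks into the
// per-object TypeContext.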
void Object_beginTypeContext (JSOBJ _obj, JSONTypeContext *tc, JSONObjectEncoder *enc)
{
PyObject *obj, *objRepr, *exc;
TypeContext *pc;
PRINTMARK();
if (!_obj)
{
tc->type = JT_INVALID;
return;
}
obj = (PyObject*) _obj;
tc->prv = PyObject_Malloc(sizeof(TypeContext));
pc = (TypeContext *) tc->prv;
if (!pc)
{
tc->type = JT_INVALID;
PyErr_NoMemory();
return;
}
pc->newObj = NULL;
pc->dictObj = NULL;
pc->itemValue = NULL;
pc->itemName = NULL;
pc->iterator = NULL;
pc->attrList = NULL;
pc->index = 0;
pc->size = 0;
pc->longValue = 0;
pc->rawJSONValue = NULL;
if (PyIter_Check(obj))
{
PRINTMARK();
goto ISITERABLE;
}
if (PyBool_Check(obj))
{
PRINTMARK();
tc->type = (obj == Py_True) ? JT_TRUE : JT_FALSE;
return;
}
else
if (PyLong_Check(obj))
{
PRINTMARK();
pc->PyTypeToJSON = PyLongToINT64;
tc->type = JT_LONG;
GET_TC(tc)->longValue = PyLong_AsLongLong(obj);
exc = PyErr_Occurred();
if (!exc)
{
return;
}
if (exc && PyErr_ExceptionMatches(PyExc_OverflowError))
{
PyErr_Clear();
pc->PyTypeToJSON = PyLongToUINT64;
tc->type = JT_ULONG;
GET_TC(tc)->unsignedLongValue = PyLong_AsUnsignedLongLong(obj);
exc = PyErr_Occurred();
if (exc && PyErr_ExceptionMatches(PyExc_OverflowError))
{
PRINTMARK();
goto INVALID;
}
}
return;
}
else
if (PyInt_Check(obj))
{
PRINTMARK();
#ifdef _LP64
pc->PyTypeToJSON = PyIntToINT64; tc->type = JT_LONG;
#else
pc->PyTypeToJSON = PyIntToINT32; tc->type = JT_INT;
#endif
return;
}
else
if (PyString_Check(obj) && !PyObject_HasAttrString(obj, "__json__"))
{
PRINTMARK();
pc->PyTypeToJSON = PyStringToUTF8; tc->type = JT_UTF8;
return;
}
else
if (PyUnicode_Check(obj))
{
PRINTMARK();
pc->PyTypeToJSON = PyUnicodeToUTF8; tc->type = JT_UTF8;
return;
}
else
if (PyFloat_Check(obj) || (type_decimal && PyObject_IsInstance(obj, type_decimal)))
{
PRINTMARK();
pc->PyTypeToJSON = PyFloatToDOUBLE; tc->type = JT_DOUBLE;
return;
}
else
if (PyDateTime_Check(obj))
{
PRINTMARK();
pc->PyTypeToJSON = PyDateTimeToINT64; tc->type = JT_LONG;
return;
}
else
if (PyDate_Check(obj))
{
PRINTMARK();
pc->PyTypeToJSON = PyDateToINT64; tc->type = JT_LONG;
return;
}
else
if (obj == Py_None)
{
PRINTMARK();
tc->type = JT_NULL;
return;
}
ISITERABLE:
if (PyDict_Check(obj))
{
PRINTMARK();
tc->type = JT_OBJECT;
SetupDictIter(obj, pc, enc);
Py_INCREF(obj);
return;
}
else
if (PyList_Check(obj))
{
PRINTMARK();
tc->type = JT_ARRAY;
pc->iterEnd = List_iterEnd;
pc->iterNext = List_iterNext;
pc->iterGetValue = List_iterGetValue;
pc->iterGetName = List_iterGetName;
GET_TC(tc)->index = 0;
GET_TC(tc)->size = PyList_GET_SIZE( (PyObject *) obj);
return;
}
else
if (PyTuple_Check(obj))
{
PRINTMARK();
tc->type = JT_ARRAY;
pc->iterEnd = Tuple_iterEnd;
pc->iterNext = Tuple_iterNext;
pc->iterGetValue = Tuple_iterGetValue;
pc->iterGetName = Tuple_iterGetName;
GET_TC(tc)->index = 0;
GET_TC(tc)->size = PyTuple_GET_SIZE( (PyObject *) obj);
GET_TC(tc)->itemValue = NULL;
return;
}
if (PyObject_HasAttrString(obj, "toDict"))
{
PyObject* toDictFunc = PyObject_GetAttrString(obj, "toDict");
PyObject* tuple = PyTuple_New(0);
PyObject* toDictResult = PyObject_Call(toDictFunc, tuple, NULL);
Py_DECREF(tuple);
Py_DECREF(toDictFunc);
if (toDictResult == NULL)
{
goto INVALID;
}
if (!PyDict_Check(toDictResult))
{
Py_DECREF(toDictResult);
tc->type = JT_NULL;
return;
}
PRINTMARK();
tc->type = JT_OBJECT;
SetupDictIter(toDictResult, pc, enc);
return;
}
else
if (PyObject_HasAttrString(obj, "__json__"))
{
PyObject* toJSONFunc = PyObject_GetAttrString(obj, "__json__");
PyObject* tuple = PyTuple_New(0);
PyObject* toJSONResult = PyObject_Call(toJSONFunc, tuple, NULL);
Py_DECREF(tuple);
Py_DECREF(toJSONFunc);
if (toJSONResult == NULL)
{
goto INVALID;
}
if (PyErr_Occurred())
{
Py_DECREF(toJSONResult);
goto INVALID;
}
if (!PyString_Check(toJSONResult) && !PyUnicode_Check(toJSONResult))
{
Py_DECREF(toJSONResult);
PyErr_Format (PyExc_TypeError, "expected string");
goto INVALID;
}
PRINTMARK();
pc->PyTypeToJSON = PyRawJSONToUTF8;
tc->type = JT_RAW;
GET_TC(tc)->rawJSONValue = toJSONResult;
return;
}
PRINTMARK();
PyErr_Clear();
objRepr = PyObject_Repr(obj);
#if PY_MAJOR_VERSION >= 3
PyObject* str = PyUnicode_AsEncodedString(objRepr, "utf-8", "~E~");
PyErr_Format (PyExc_TypeError, "%s is not JSON serializable", PyString_AS_STRING(str));
Py_XDECREF(str);
#else
PyErr_Format (PyExc_TypeError, "%s is not JSON serializable", PyString_AS_STRING(objRepr));
#endif
Py_DECREF(objRepr);
INVALID:
PRINTMARK();
tc->type = JT_INVALID;
PyObject_Free(tc->prv);
tc->prv = NULL;
return;
}
void Object_endTypeContext(JSOBJ obj, JSONTypeContext *tc)
{
Py_XDECREF(GET_TC(tc)->newObj);
if (tc->type == JT_RAW)
{
Py_XDECREF(GET_TC(tc)->rawJSONValue);
}
PyObject_Free(tc->prv);
tc->prv = NULL;
}
const char *Object_getStringValue(JSOBJ obj, JSONTypeContext *tc, size_t *_outLen)
{
return GET_TC(tc)->PyTypeToJSON (obj, tc, NULL, _outLen);
}
JSINT64 Object_getLongValue(JSOBJ obj, JSONTypeContext *tc)
{
JSINT64 ret;
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}
JSUINT64 Object_getUnsignedLongValue(JSOBJ obj, JSONTypeContext *tc)
{
JSUINT64 ret;
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}
JSINT32 Object_getIntValue(JSOBJ obj, JSONTypeContext *tc)
{
JSINT32 ret;
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}
double Object_getDoubleValue(JSOBJ obj, JSONTypeContext *tc)
{
double ret;
GET_TC(tc)->PyTypeToJSON (obj, tc, &ret, NULL);
return ret;
}
static void Object_releaseObject(JSOBJ _obj)
{
Py_DECREF( (PyObject *) _obj);
}
int Object_iterNext(JSOBJ obj, JSONTypeContext *tc)
{
return GET_TC(tc)->iterNext(obj, tc);
}
void Object_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{
GET_TC(tc)->iterEnd(obj, tc);
}
JSOBJ Object_iterGetValue(JSOBJ obj, JSONTypeContext *tc)
{
return GET_TC(tc)->iterGetValue(obj, tc);
}
char *Object_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)
{
return GET_TC(tc)->iterGetName(obj, tc, outLen);
}
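/* Python-facing encoder entry point (exposed as ujson.encode / ujson.dumps in
   ujson.c). Keyword arguments mirror kwlist below: ensure_ascii,
   double_precision, encode_html_chars, escape_forward_slashes, sort_keys and
   indent. */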
PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs)
{
static char *kwlist[] = { "obj", "ensure_ascii", "double_precision", "encode_html_chars", "escape_forward_slashes", "sort_keys", "indent", NULL };
char buffer[65536];
char *ret;
PyObject *newobj;
PyObject *oinput = NULL;
PyObject *oensureAscii = NULL;
PyObject *oencodeHTMLChars = NULL;
PyObject *oescapeForwardSlashes = NULL;
PyObject *osortKeys = NULL;
JSONObjectEncoder encoder =
{
Object_beginTypeContext,
Object_endTypeContext,
Object_getStringValue,
Object_getLongValue,
Object_getUnsignedLongValue,
Object_getIntValue,
Object_getDoubleValue,
Object_iterNext,
Object_iterEnd,
Object_iterGetValue,
Object_iterGetName,
Object_releaseObject,
PyObject_Malloc,
PyObject_Realloc,
PyObject_Free,
-1, //recursionMax
10, // default double precision setting
1, //forceAscii
0, //encodeHTMLChars
1, //escapeForwardSlashes
0, //sortKeys
0, //indent
NULL, //prv
};
PRINTMARK();
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OiOOOi", kwlist, &oinput, &oensureAscii, &encoder.doublePrecision, &oencodeHTMLChars, &oescapeForwardSlashes, &osortKeys, &encoder.indent))
{
return NULL;
}
if (oensureAscii != NULL && !PyObject_IsTrue(oensureAscii))
{
encoder.forceASCII = 0;
}
if (oencodeHTMLChars != NULL && PyObject_IsTrue(oencodeHTMLChars))
{
encoder.encodeHTMLChars = 1;
}
if (oescapeForwardSlashes != NULL && !PyObject_IsTrue(oescapeForwardSlashes))
{
encoder.escapeForwardSlashes = 0;
}
if (osortKeys != NULL && PyObject_IsTrue(osortKeys))
{
encoder.sortKeys = 1;
}
PRINTMARK();
ret = JSON_EncodeObject (oinput, &encoder, buffer, sizeof (buffer));
PRINTMARK();
if (PyErr_Occurred())
{
return NULL;
}
if (encoder.errorMsg)
{
if (ret != buffer)
{
encoder.free (ret);
}
PyErr_Format (PyExc_OverflowError, "%s", encoder.errorMsg);
return NULL;
}
newobj = PyString_FromString (ret);
if (ret != buffer)
{
encoder.free (ret);
}
PRINTMARK();
return newobj;
}
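/* ujson.dump: encodes `data` by delegating to objToJSON above, then passes the
   resulting string to file.write(). The file argument only needs a callable
   `write` attribute. */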
PyObject* objToJSONFile(PyObject* self, PyObject *args, PyObject *kwargs)
{
PyObject *data;
PyObject *file;
PyObject *string;
PyObject *write;
PyObject *argtuple;
PyObject *write_result;
PRINTMARK();
if (!PyArg_ParseTuple (args, "OO", &data, &file))
{
return NULL;
}
if (!PyObject_HasAttrString (file, "write"))
{
PyErr_Format (PyExc_TypeError, "expected file");
return NULL;
}
write = PyObject_GetAttrString (file, "write");
if (!PyCallable_Check (write))
{
Py_XDECREF(write);
PyErr_Format (PyExc_TypeError, "expected file");
return NULL;
}
argtuple = PyTuple_Pack(1, data);
string = objToJSON (self, argtuple, kwargs);
if (string == NULL)
{
Py_XDECREF(write);
Py_XDECREF(argtuple);
return NULL;
}
Py_XDECREF(argtuple);
argtuple = PyTuple_Pack (1, string);
if (argtuple == NULL)
{
Py_XDECREF(write);
return NULL;
}
write_result = PyObject_CallObject (write, argtuple);
if (write_result == NULL)
{
Py_XDECREF(write);
Py_XDECREF(argtuple);
return NULL;
}
Py_DECREF(write_result);
Py_XDECREF(write);
Py_DECREF(argtuple);
Py_XDECREF(string);
PRINTMARK();
Py_RETURN_NONE;
}
srsly-release-v2.5.1/srsly/ujson/py_defines.h 0000664 0000000 0000000 00000004475 14742310675 0021332 0 ustar 00root root 0000000 0000000 /*
Developed by ESN, an Electronic Arts Inc. studio.
Copyright (c) 2014, Electronic Arts Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of ESN, Electronic Arts Inc. nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
http://code.google.com/p/stringencoders/
Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
Numeric decoder derived from from TCL library
http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
* Copyright (c) 1988-1993 The Regents of the University of California.
* Copyright (c) 1994 Sun Microsystems, Inc.
*/
#include <Python.h>
#if PY_MAJOR_VERSION >= 3
#define PyInt_Check PyLong_Check
#define PyInt_AS_LONG PyLong_AsLong
#define PyInt_FromLong PyLong_FromLong
#define PyString_Check PyBytes_Check
#define PyString_GET_SIZE PyBytes_GET_SIZE
#define PyString_AS_STRING PyBytes_AS_STRING
#define PyString_FromString PyUnicode_FromString
#endif
srsly-release-v2.5.1/srsly/ujson/ujson.c 0000664 0000000 0000000 00000011214 14742310675 0020323 0 ustar 00root root 0000000 0000000 /*
Developed by ESN, an Electronic Arts Inc. studio.
Copyright (c) 2014, Electronic Arts Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of ESN, Electronic Arts Inc. nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
http://code.google.com/p/stringencoders/
Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
Numeric decoder derived from from TCL library
http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
* Copyright (c) 1988-1993 The Regents of the University of California.
* Copyright (c) 1994 Sun Microsystems, Inc.
*/
#include "py_defines.h"
#include "version.h"
/* objToJSON */
PyObject* objToJSON(PyObject* self, PyObject *args, PyObject *kwargs);
void initObjToJSON(void);
/* JSONToObj */
PyObject* JSONToObj(PyObject* self, PyObject *args, PyObject *kwargs);
/* objToJSONFile */
PyObject* objToJSONFile(PyObject* self, PyObject *args, PyObject *kwargs);
/* JSONFileToObj */
PyObject* JSONFileToObj(PyObject* self, PyObject *args, PyObject *kwargs);
#define ENCODER_HELP_TEXT "Use ensure_ascii=False to output UTF-8. Pass in double_precision to alter the maximum digit precision of doubles. Set encode_html_chars=True to encode < > & as unicode escape sequences. Set escape_forward_slashes=False to prevent escaping / characters."
static PyMethodDef ujsonMethods[] = {
{"encode", (PyCFunction) objToJSON, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON. " ENCODER_HELP_TEXT},
{"decode", (PyCFunction) JSONToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as string to dict object structure. Use precise_float=True to use high precision float decoder."},
{"dumps", (PyCFunction) objToJSON, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON. " ENCODER_HELP_TEXT},
{"loads", (PyCFunction) JSONToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as string to dict object structure. Use precise_float=True to use high precision float decoder."},
{"dump", (PyCFunction) objToJSONFile, METH_VARARGS | METH_KEYWORDS, "Converts arbitrary object recursively into JSON file. " ENCODER_HELP_TEXT},
{"load", (PyCFunction) JSONFileToObj, METH_VARARGS | METH_KEYWORDS, "Converts JSON as file to dict object structure. Use precise_float=True to use high precision float decoder."},
{NULL, NULL, 0, NULL} /* Sentinel */
};
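/*
  Illustrative use from Python, assuming the bindings are importable as
  srsly.ujson (names follow the method table above):

      from srsly import ujson
      s = ujson.dumps({"a": 1.5}, double_precision=2, sort_keys=True)
      data = ujson.loads(s)
*/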
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"ujson",
0, /* m_doc */
-1, /* m_size */
ujsonMethods, /* m_methods */
NULL, /* m_reload */
NULL, /* m_traverse */
NULL, /* m_clear */
NULL /* m_free */
};
#define PYMODINITFUNC PyObject *PyInit_ujson(void)
#define PYMODULE_CREATE() PyModule_Create(&moduledef)
#define MODINITERROR return NULL
#else
#define PYMODINITFUNC PyMODINIT_FUNC initujson(void)
#define PYMODULE_CREATE() Py_InitModule("ujson", ujsonMethods)
#define MODINITERROR return
#endif
PYMODINITFUNC
{
PyObject *module;
PyObject *version_string;
initObjToJSON();
module = PYMODULE_CREATE();
if (module == NULL)
{
MODINITERROR;
}
version_string = PyString_FromString (UJSON_VERSION);
PyModule_AddObject (module, "__version__", version_string);
#if PY_MAJOR_VERSION >= 3
return module;
#endif
}
srsly-release-v2.5.1/srsly/ujson/version.h 0000664 0000000 0000000 00000003717 14742310675 0020670 0 ustar 00root root 0000000 0000000 /*
Developed by ESN, an Electronic Arts Inc. studio.
Copyright (c) 2014, Electronic Arts Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of ESN, Electronic Arts Inc. nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL ELECTRONIC ARTS INC. BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Portions of code from MODP_ASCII - Ascii transformations (upper/lower, etc)
http://code.google.com/p/stringencoders/
Copyright (c) 2007 Nick Galbreath -- nickg [at] modp [dot] com. All rights reserved.
Numeric decoder derived from from TCL library
http://www.opensource.apple.com/source/tcl/tcl-14/tcl/license.terms
* Copyright (c) 1988-1993 The Regents of the University of California.
* Copyright (c) 1994 Sun Microsystems, Inc.
*/
#define UJSON_VERSION "1.35"
srsly-release-v2.5.1/srsly/util.py 0000664 0000000 0000000 00000002124 14742310675 0017212 0 ustar 00root root 0000000 0000000 from pathlib import Path
from typing import Union, Dict, Any, List, Tuple
from collections import OrderedDict
# fmt: off
FilePath = Union[str, Path]
# Superficial JSON input/output types
# https://github.com/python/typing/issues/182#issuecomment-186684288
JSONOutput = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
JSONOutputBin = Union[bytes, str, int, float, bool, None, Dict[str, Any], List[Any]]
# For input, we also accept tuples, ordered dicts etc.
JSONInput = Union[str, int, float, bool, None, Dict[str, Any], List[Any], Tuple[Any, ...], OrderedDict]
JSONInputBin = Union[bytes, str, int, float, bool, None, Dict[str, Any], List[Any], Tuple[Any, ...], OrderedDict]
YAMLInput = JSONInput
YAMLOutput = JSONOutput
# fmt: on
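# force_path coerces a str/Path location to a pathlib.Path (optionally requiring
# that it exists); force_string does the reverse, returning a plain string path.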
def force_path(location, require_exists=True):
if not isinstance(location, Path):
location = Path(location)
if require_exists and not location.exists():
raise ValueError(f"Can't read file: {location}")
return location
def force_string(location):
if isinstance(location, str):
return location
return str(location)