pax_global_header 0000666 0000000 0000000 00000000064 14551576331 0014524 g ustar 00root root 0000000 0000000 52 comment=8abdd60755e5b6ebfa2b383cf6eaa0d9315335bd
abseil-py-2.1.0/ 0000775 0000000 0000000 00000000000 14551576331 0013411 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/.github/ 0000775 0000000 0000000 00000000000 14551576331 0014751 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/.github/workflows/ 0000775 0000000 0000000 00000000000 14551576331 0017006 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/.github/workflows/test.yml 0000664 0000000 0000000 00000002517 14551576331 0020515 0 ustar 00root root 0000000 0000000 name: Test
on: [push, pull_request]
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.run_id }}
cancel-in-progress: true
jobs:
test:
if:
github.event_name == 'push' || github.event.pull_request.head.repo.full_name !=
github.repository
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
os: [ubuntu-latest, macOS-latest, windows-latest]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
id: setup_python
with:
python-version: ${{ matrix.python-version }}
allow-prereleases: true
- name: Install virtualenv
run: |
python -m pip install --upgrade pip
python -m pip install --upgrade virtualenv
- name: Run tests
env:
ABSL_EXPECTED_PYTHON_VERSION: ${{ matrix.python-version }}
ABSL_COPY_TESTLOGS_TO: ci-artifacts
shell: bash
run: ci/run_tests.sh
- name: Upload bazel test logs
uses: actions/upload-artifact@v3
with:
name: bazel-testlogs-${{ matrix.os }}-${{ matrix.python-version }}
path: ci-artifacts
abseil-py-2.1.0/.readthedocs.yaml 0000664 0000000 0000000 00000000743 14551576331 0016644 0 ustar 00root root 0000000 0000000 # .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Set the version of Python and other tools you might need
python:
version: "3"
install:
- method: pip
path: .
extra_requirements:
- m2r2
- sphinxcontrib-apidoc
# Build documentation in the docs/ directory with Sphinx
sphinx:
builder: html
configuration: docs/source/conf.py
abseil-py-2.1.0/AUTHORS 0000664 0000000 0000000 00000000450 14551576331 0014460 0 ustar 00root root 0000000 0000000 # This is the list of Abseil authors for copyright purposes.
#
# This does not necessarily list everyone who has contributed code, since in
# some cases, their employer may be the copyright holder. To see the full list
# of contributors, see the revision history in source control.
Google Inc.
abseil-py-2.1.0/BUILD.bazel 0000664 0000000 0000000 00000001270 14551576331 0015267 0 ustar 00root root 0000000 0000000 # Copyright 2021 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
package(default_visibility = ["//visibility:public"])
licenses(["notice"])
exports_files([
"LICENSE",
])
abseil-py-2.1.0/CHANGELOG.md 0000664 0000000 0000000 00000037566 14551576331 0015243 0 ustar 00root root 0000000 0000000 # Python Absl Changelog
All notable changes to Python Absl are recorded here.
The format is based on [Keep a Changelog](https://keepachangelog.com).
## Unreleased
Nothing notable unreleased.
## 2.1.0 (2024-01-16)
### Added
* (flags) Added `absl.flags.override_value` function to provide `FlagHolder`
with a construct to modify values. The new interface parallels
`absl.flags.FlagValues.__setattr__` but checks that the provided value
conforms to the flag's expected type.
* (testing) Added a new method `absltest.TestCase.assertDataclassEqual` that
tests equality of `dataclass.dataclass` objects with better error messages
when the assert fails.
### Changed
* (flags) `absl.flags.argparse_flags.ArgumentParser` now correctly inherits
an empty instance of `FlagValues` to ensure that absl flags, such as
`--flagfile`, `--undefok` are supported.
* (testing) Do not exit with status code 5 if tests were skipped on Python
3.12. This follows the CPython change in
https://github.com/python/cpython/pull/113856.
### Fixed
* (flags) The flag `foo` no longer retains the value `bar` after
`FLAGS.foo = bar` fails due to a validation error.
* (testing) Fixed an issue caused by
[this Python 3.12.1 change](https://github.com/python/cpython/pull/109725)
where the test reporter crashes when all tests are skipped.
## 2.0.0 (2023-09-19)
### Changed
* `absl-py` no longer supports Python 3.6. It reached end-of-life more than a
year ago.
* Support Python 3.12.
* (logging) `logging.exception` can now take `exc_info` as argument, with
default value `True`. Prior to this change, setting `exc_info` would raise
`KeyError`; this change fixes that behaviour.
* (testing) For Python 3.11+, the calls to `absltest.TestCase.enter_context`
are forwarded to `unittest.TestCase.enterContext` (when called via instance)
or `unittest.TestCase.enterClassContext` (when called via class) now. As a
result, on Python 3.11+, the private `_cls_exit_stack` attribute is not
defined on `absltest.TestCase` and `_exit_stack` attribute is not defined on
its instances.
* (testing) `absltest.TestCase.assertSameStructure()` now uses the test case's
equality functions (registered with `TestCase.addTypeEqualityFunc()`) for
comparing leaves of the structure.
* (testing) `absltest.TestCase.fail()` now names its arguments `(self,
msg=None, user_msg=None)`, and not `(self, msg=None, prefix=None)`, better
reflecting the behavior and usage of the two message arguments.
* `DEFINE_enum`, `DEFINE_multi_enum`, and `EnumParser` now raise errors when
`enum_values` is provided as a single string value. Additionally,
`EnumParser.enum_values` is now stored as a list copy of the provided
`enum_values` parameter.
* (testing) Updated `parameterized.CoopTestCase()` to use Python 3 metaclass
idioms. Most uses of this function continued working during the Python 3
migration because a Python 2 compatibility `__metaclass__` variable also
existed. Now pure Python 3 base classes without backwards compatibility will
work as intended.
* (testing) `absltest.TestCase.assertSequenceStartsWith` now explicitly fails
when passed a `Mapping` or `Set` object as the `whole` argument.
## 1.4.0 (2023-01-11)
### New
* (testing) Added `@flagsaver.as_parsed`: this allows saving/restoring flags
using string values as if parsed from the command line and will also reflect
other flag states after command line parsing, e.g. `.present` is set.
### Changed
* (logging) If no log dir is specified `logging.find_log_dir()` now falls back
to `tempfile.gettempdir()` instead of `/tmp/`.
### Fixed
* (flags) Additional kwargs (e.g. `short_name=`) to `DEFINE_multi_enum_class`
are now correctly passed to the underlying `Flag` object.
## 1.3.0 (2022-10-11)
### Added
* (flags) Added a new `absl.flags.set_default` function that updates the flag
default for a provided `FlagHolder`. This parallels the
`absl.flags.FlagValues.set_default` interface which takes a flag name.
* (flags) The following functions now also accept `FlagHolder` instance(s) in
addition to flag name(s) as their first positional argument:
- `flags.register_validator`
- `flags.validator`
- `flags.register_multi_flags_validator`
- `flags.multi_flags_validator`
- `flags.mark_flag_as_required`
- `flags.mark_flags_as_required`
- `flags.mark_flags_as_mutual_exclusive`
- `flags.mark_bool_flags_as_mutual_exclusive`
- `flags.declare_key_flag`
### Changed
* (testing) Assertions `assertRaisesWithPredicateMatch` and
`assertRaisesWithLiteralMatch` now capture the raised `Exception` for
further analysis when used as a context manager.
* (testing) TextAndXMLTestRunner now produces time duration values with
millisecond precision in XML test result output.
* (flags) Keyword access to `flag_name` arguments in the following functions
is deprecated. This parameter will be renamed in a future 2.0.0 release.
- `flags.register_validator`
- `flags.validator`
- `flags.register_multi_flags_validator`
- `flags.multi_flags_validator`
- `flags.mark_flag_as_required`
- `flags.mark_flags_as_required`
- `flags.mark_flags_as_mutual_exclusive`
- `flags.mark_bool_flags_as_mutual_exclusive`
- `flags.declare_key_flag`
## 1.2.0 (2022-07-18)
### Fixed
* Fixed a crash in Python 3.11 when `TempFileCleanup.SUCCESS` is used.
## 1.1.0 (2022-06-01)
* `Flag` instances now raise an error if used in a bool context. This prevents
the occasional mistake of testing an instance for truthiness rather than
testing `flag.value`.
* `absl-py` no longer depends on `six`.
## 1.0.0 (2021-11-09)
### Changed
* `absl-py` no longer supports Python 2.7, 3.4, 3.5. All of these versions
reached end-of-life more than a year ago.
* New releases will be tagged as `vX.Y.Z` instead of `pypi-vX.Y.Z` in the git
repo going forward.
## 0.15.0 (2021-10-19)
### Changed
* (testing) #128: When running bazel with its `--test_filter=` flag, `absltest`
now treats the filters like `unittest`'s `-k` flag in Python 3.7+.
## 0.14.1 (2021-09-30)
### Fixed
* Top-level `LICENSE` file is now exported in bazel.
## 0.14.0 (2021-09-21)
### Fixed
* #171: Creating `argparse_flags.ArgumentParser` with `argument_default=` no
longer raises an exception when other `absl.flags` flags are defined.
* #173: `absltest` now correctly sets up test filtering and fail fast flags
when an explicit `argv=` parameter is passed to `absltest.main`.
## 0.13.0 (2021-06-14)
### Added
* (app) Type annotations for public `app` interfaces.
* (testing) Added new decorator `@absltest.skipThisClass` to indicate a class
contains shared functionality to be used as a base class for other
TestCases, and therefore should be skipped.
### Changed
* (app) Annotated the `flag_parser` parameter of `run` as keyword-only. This
keyword-only constraint will be enforced at runtime in a future release.
* (app, flags) Flag validations now include all errors from disjoint flag
sets, instead of failing fast on the first error from any validator. Multiple
validators on the same flag still fail fast.
## 0.12.0 (2021-03-08)
### Added
* (flags) Made `EnumClassSerializer` and `EnumClassListSerializer` public.
* (flags) Added a `required: Optional[bool] = False` parameter to `DEFINE_*`
functions.
* (testing) flagsaver overrides can now be specified in terms of FlagHolder.
* (testing) `parameterized.product`: Allows testing a method over the
Cartesian product of parameter values, specified as sequences of values for
each parameter or as kwargs-like dicts of parameter values.
* (testing) Added public flag holders for `--test_srcdir` and `--test_tmpdir`.
Users should use `absltest.TEST_SRCDIR.value` and
`absltest.TEST_TMPDIR.value` instead of `FLAGS.test_srcdir` and
`FLAGS.test_tmpdir`.
### Fixed
* (flags) Made `CsvListSerializer` respect its delimiter argument.
## 0.11.0 (2020-10-27)
### Changed
* (testing) Surplus entries in AssertionError stack traces from absltest are
now suppressed and no longer reported in the xml_reporter.
* (logging) An exception is now raised instead of `logging.fatal` when logging
directories cannot be found.
* (testing) Multiple flags are now set together before their validators run.
This resolves an issue where multi-flag validators rely on specific flag
combinations.
* (flags) As a deterrent for misuse, FlagHolder objects will now raise a
TypeError exception when used in a conditional statement or equality
expression.
## 0.10.0 (2020-08-19)
### Added
* (testing) `_TempDir` and `_TempFile` now implement `__fspath__` to satisfy
`os.PathLike`
* (logging) `--logger_levels`: allows specifying the log levels of loggers.
* (flags) `FLAGS.validate_all_flags`: a new method that validates all flags
and raises an exception if one fails.
* (flags) `FLAGS.get_flags_for_module`: Allows fetching the flags a module
defines.
* (testing) `parameterized.TestCase`: Supports async test definitions.
* (testing,app) Added `--pdb` flag: When true, uncaught exceptions will be
handled by `pdb.post_mortem`. This is an alias for `--pdb_post_mortem`.
### Changed
* (testing) Failed tests output a copy/pastable test id to make it easier to
copy the failing test to the command line.
* (testing) `@parameterized.parameters` now treats a single `abc.Mapping` as a
single test case, consistent with `named_parameters`. Previously the
`abc.Mapping` was treated as if only its keys were passed as a list of test
cases. If you were relying on the old inconsistent behavior, explicitly
convert the `abc.Mapping` to a `list`.
* (flags) `DEFINE_enum_class` and `DEFINE_multi_enum_class` accept a
`case_sensitive` argument. When `False` (the default), strings are mapped to
enum member names without case sensitivity, and member names are serialized
in lowercase form. Flag definitions for enums whose members include
duplicates when case is ignored must now explicitly pass
`case_sensitive=True`.
### Fixed
* (flags) Defining an alias no longer marks the aliased flag as always present
on the command line.
* (flags) Aliasing a multi flag no longer causes the default value to be
appended to.
* (flags) Alias default values now match the aliased default value.
* (flags) Alias `present` counter now correctly reflects command line usage.
## 0.9.0 (2019-12-17)
### Added
* (testing) `TestCase.enter_context`: Allows using context managers in setUp
and having them automatically exited when a test finishes.
### Fixed
* #126: calling `logging.debug(msg, stack_info=...)` no longer throws an
exception in Python 3.8.
## 0.8.1 (2019-10-08)
### Fixed
* (testing) `absl.testing`'s pretty print reporter no longer buffers
RUN/OK/FAILED messages.
* (testing) `create_tempfile` will overwrite pre-existing read-only files.
## 0.8.0 (2019-08-26)
### Added
* (testing) `absltest.expectedFailureIf`: a variant of
`unittest.expectedFailure` that allows a condition to be given.
### Changed
* (bazel) Tests now pass when bazel
`--incompatible_allow_python_version_transitions=true` is set.
* (bazel) Both Python 2 and Python 3 versions of tests are now created. To
only run one major Python version, use `bazel test
--test_tag_filters=-python[23]` to ignore the other version.
* (testing) `assertTotallyOrdered` no longer requires objects to implement
`__hash__`.
* (testing) `absltest` now integrates better with `--pdb_post_mortem`.
* (testing) `xml_reporter` now includes timestamps to testcases, test_suite,
test_suites elements.
### Fixed
* #99: `absl.logging` no longer registers itself to `logging.root` at import
time.
* #108: Tests now pass with Bazel 0.28.0 on macOS.
## 0.7.1 (2019-03-12)
### Added
* (flags) `flags.mark_bool_flags_as_mutual_exclusive`: convenience function to
check that only one, or at most one, flag among a set of boolean flags is
True.
### Changed
* (bazel) Bazel 0.23+ or 0.22+ is now required for building/testing.
Specifically, a Bazel version that supports
`@bazel_tools//tools/python:python_version` for selecting the Python
version.
### Fixed
* #94: LICENSE files are now included in sdist.
* #93: Change log added.
## 0.7.0 (2019-01-11)
### Added
* (bazel) testonly=1 has been removed from the testing libraries, which allows
their use outside of testing contexts.
* (flags) Multi-flags now accept any Iterable type for the default value
instead of only lists. Strings are still special cased as before. This
allows sets, generators, views, etc to be used naturally.
* (flags) DEFINE_multi_enum_class: a multi flag variant of enum_class.
* (testing) Most of absltest is now type-annotated.
* (testing) Made AbslTest.assertRegex available under Python 2. This allows
Python 2 code to write more natural Python 3 compatible code. (Note: this
was actually released in 0.6.1, but unannounced)
* (logging) logging.vlog_is_on: helper to tell if a vlog() call will actually
log anything. This allows avoiding computing expensive inputs to a logging
call when logging isn't enabled for that level.
### Fixed
* (flags) Pickling flags now raises a clear error instead of a cryptic one.
Pickling flags isn't supported; instead use flags_into_string to serialize
flags.
* (flags) Flags serialization works better: the resulting serialized value,
when deserialized, won't cause --help to be invoked, thus ending the
process.
* (flags) Several flag fixes to make them behave more like the Absl C++ flags:
empty --flagfile is allowed; --nohelp and --help=false don't display help
* (flags) An empty --flagfile value (e.g. "--flagfile=" or "--flagfile=''")
doesn't raise an error; it's now just ignored. This matches Abseil C++
behavior.
* (bazel) Building with Bazel 0.2.0 works without extra incompatibility
disable build flags.
### Changed
* (flags) Flag serialization is now deterministic: this improves Bazel build
caching for tools that are affected by flag serialization.
## 0.6.0 (2018-10-22)
### Added
* Tempfile management APIs for tests: read/write/manage tempfiles for test
purposes easily and correctly. See TestCase.create_temp{file/dir} and the
corresponding commit for more info.
## 0.5.0 (2018-09-17)
### Added
* Flags enum support: flags.DEFINE_enum_class allows using an `Enum` derived
class to define the allowed values for a flag.
## 0.4.1 (2018-08-28)
### Fixed
* Flags no longer allow spaces in their names.
### Changed
* XML test output is written at the end of all test execution.
* If the current user's username can't be determined, fall back to the uid,
and failing that, to a generic 'unknown' string.
## 0.4.0 (2018-08-14)
### Added
* argparse integration: absl-registered flags can now be accessed via argparse
using absl.flags.argparse_flags: see that module for more information.
* TestCase.assertSameStructure now allows mixed set types.
### Changed
* Test output now includes start/end markers for each test run. This helps
distinguish the output of individual tests clearly.
## 0.3.0 (2018-07-25)
### Added
* `app.call_after_init`: Register functions to be called after app.run() is
called. Useful for program-wide initialization that library code may need.
* `logging.log_every_n_seconds`: like log_every_n, but based on elapsed time
between logging calls.
* `absltest.mock`: alias to unittest.mock (PY3) for better unittest drop-in
replacement. For PY2, it will be available if mock is importable.
### Fixed
* `ABSLLogger.findCaller()`: allow stack_info arg and return value for PY2
* Make stopTest locking reentrant: this prevents deadlocks for test frameworks
that customize unittest.TextTestResult.stopTest.
* Make --helpfull work with unicode flag help strings.
abseil-py-2.1.0/CONTRIBUTING.md 0000664 0000000 0000000 00000006305 14551576331 0015646 0 ustar 00root root 0000000 0000000 # How to Contribute
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
NOTE: If you are new to GitHub, please start by reading the [Pull Request
howto](https://help.github.com/articles/about-pull-requests/).
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution,
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Coding Style
To keep the source consistent, readable, diffable and easy to merge, we use a
fairly rigid coding style, as defined by the
[google-styleguide](https://github.com/google/styleguide) project. All patches
will be expected to conform to the Python style outlined
[here](https://google.github.io/styleguide/pyguide.html).
## Guidelines for Pull Requests
* Create **small PRs** that are narrowly focused on **addressing a single
concern**. We often receive PRs that are trying to fix several things at a
time, but if only one fix is considered acceptable, nothing gets merged and
both the author's and the reviewer's time is wasted. Create more PRs to address
different concerns and everyone will be happy.
* For speculative changes, consider opening an
[issue](https://github.com/abseil/abseil-py/issues) and discussing it first.
* Provide a good **PR description** as a record of **what** change is being
made and **why** it was made. Link to a GitHub issue if it exists.
* Don't fix code style and formatting unless you are already changing that
line to address an issue. PRs with irrelevant changes won't be merged. If
you do want to fix formatting or style, do that in a separate PR.
* Unless your PR is trivial, you should expect there will be reviewer comments
that you'll need to address before merging. We expect you to be reasonably
responsive to those comments, otherwise the PR will be closed after 2-3
weeks of inactivity.
* Maintain **clean commit history** and use **meaningful commit messages**.
PRs with messy commit history are difficult to review and won't be merged.
Use `rebase -i upstream/main` to curate your commit history and/or to
bring in latest changes from main (but avoid rebasing in the middle of a
code review).
* Keep your PR up to date with upstream/main (if there are merge conflicts,
we can't really merge your change).
* **All tests need to be passing** before your change can be merged. We
recommend you **run tests locally** (see
[Running Tests](README.md#running-tests)).
* Exceptions to the rules can be made if there's a compelling reason for doing
so. That is - the rules are here to serve us, not the other way around, and
the rules need to be serving their intended purpose to be valuable.
* All submissions, including submissions by project members, require review.
abseil-py-2.1.0/LICENSE 0000664 0000000 0000000 00000026136 14551576331 0014426 0 ustar 00root root 0000000 0000000
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
abseil-py-2.1.0/MANIFEST.in 0000664 0000000 0000000 00000000020 14551576331 0015137 0 ustar 00root root 0000000 0000000 include LICENSE
abseil-py-2.1.0/README.md 0000664 0000000 0000000 00000002513 14551576331 0014671 0 ustar 00root root 0000000 0000000 # Abseil Python Common Libraries
This repository is a collection of Python library code for building Python
applications. The code is collected from Google's own Python code base, and has
been extensively tested and used in production.
## Features
* Simple application startup
* Distributed command-line flags system
* Custom logging module with additional features
* Testing utilities
## Getting Started
### Installation
To install the package, simply run:
```bash
pip install absl-py
```
Or install from source, from the root of a checkout:
```bash
pip install .
```
### Running Tests
To run Abseil tests, you can clone the git repo and run
[bazel](https://bazel.build/):
```bash
git clone https://github.com/abseil/abseil-py.git
cd abseil-py
bazel test absl/...
```
### Example Code
Please refer to
[smoke_tests/sample_app.py](https://github.com/abseil/abseil-py/blob/main/smoke_tests/sample_app.py)
as an example to get started.
## Documentation
See the [Abseil Python Developer Guide](https://abseil.io/docs/python/).
## Future Releases
The current repository includes an initial set of libraries for early adoption.
More components and interoperability with Abseil C++ Common Libraries
will come in future releases.
## License
The Abseil Python library is licensed under the terms of the Apache
License, Version 2.0. See [LICENSE](LICENSE) for more information.
abseil-py-2.1.0/WORKSPACE 0000664 0000000 0000000 00000002030 14551576331 0014665 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
workspace(name = "io_abseil_py")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# rules_python provides the py_library / py_test / py_binary rules loaded by
# the BUILD files.
http_archive(
    name = "rules_python",
    urls = ["https://github.com/bazelbuild/rules_python/releases/download/0.22.0/rules_python-0.22.0.tar.gz"],
    strip_prefix = "rules_python-0.22.0",
    sha256 = "863ba0fa944319f7e3d695711427d9ad80ba92c6edd0b7c7443b84e904689539",
)

load("@rules_python//python:repositories.bzl", "py_repositories")

py_repositories()
abseil-py-2.1.0/absl/ 0000775 0000000 0000000 00000000000 14551576331 0014332 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/absl/BUILD 0000664 0000000 0000000 00000003547 14551576331 0015125 0 ustar 00root root 0000000 0000000 load("@rules_python//python:py_library.bzl", "py_library")
load("@rules_python//python:py_test.bzl", "py_test")
load("@rules_python//python:py_binary.bzl", "py_binary")
licenses(["notice"])
# absl.app: the public application entry point (app.run, usage, help flags).
# NOTE(review): libraries here still declare srcs_version = "PY2AND3" while the
# tests and binaries declare "PY3"; consider aligning them.
py_library(
name = "app",
srcs = [
"app.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":command_name",
"//absl/flags",
"//absl/logging",
],
)
# Standalone helper that sets the kernel process name; used by app.py.
py_library(
name = "command_name",
srcs = ["command_name.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
)
# Test-only library form of tests/app_test_helper.py, imported by app_test.
py_library(
name = "tests/app_test_helper",
testonly = 1,
srcs = ["tests/app_test_helper.py"],
srcs_version = "PY2AND3",
deps = [
":app",
"//absl/flags",
],
)
# Binary form of the same helper. app_test lists it in `data`, presumably so
# the test can launch it as a separate process.
py_binary(
name = "tests/app_test_helper_pure_python",
testonly = 1,
srcs = ["tests/app_test_helper.py"],
main = "tests/app_test_helper.py",
python_version = "PY3",
srcs_version = "PY3",
deps = [
":app",
"//absl/flags",
],
)
# Tests for the targets above.
py_test(
name = "tests/app_test",
srcs = ["tests/app_test.py"],
data = [":tests/app_test_helper_pure_python"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":app",
":tests/app_test_helper",
"//absl/flags",
"//absl/testing:_bazelize_command",
"//absl/testing:absltest",
"//absl/testing:flagsaver",
],
)
py_test(
name = "tests/command_name_test",
srcs = ["tests/command_name_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":command_name",
"//absl/testing:absltest",
],
)
py_test(
name = "tests/python_version_test",
srcs = ["tests/python_version_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
"//absl/flags",
"//absl/testing:absltest",
],
)
abseil-py-2.1.0/absl/__init__.py 0000664 0000000 0000000 00000001110 14551576331 0016434 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
abseil-py-2.1.0/absl/app.py 0000664 0000000 0000000 00000036016 14551576331 0015472 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic entry point for Abseil Python applications.
To use this module, define a ``main`` function with a single ``argv`` argument
and call ``app.run(main)``. For example::
from absl import app
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
if __name__ == '__main__':
app.run(main)
"""
import collections
import errno
import os
import pdb
import sys
import textwrap
import traceback
from absl import command_name
from absl import flags
from absl import logging
try:
import faulthandler
except ImportError:
faulthandler = None
FLAGS = flags.FLAGS
# Flags owned by this module that control how main() is run. They are read by
# _run_main(), run() and _register_and_parse_flags_with_usage() below.
flags.DEFINE_boolean('run_with_pdb', False, 'Set to true for PDB debug mode')
# Read by run(); post-mortem debugging only starts when stdout is a TTY.
flags.DEFINE_boolean('pdb_post_mortem', False,
'Set to true to handle uncaught exceptions with PDB '
'post mortem.')
# --pdb is a shorter spelling of --pdb_post_mortem.
flags.DEFINE_alias('pdb', 'pdb_post_mortem')
flags.DEFINE_boolean('run_with_profiling', False,
'Set to true for profiling the script. '
'Execution will be slower, and the output format might '
'change over time.')
# When set, _run_main() profiles main() and dumps the stats to this file at
# interpreter exit instead of printing them.
flags.DEFINE_string('profile_file', None,
'Dump profile information to a file (for python -m '
'pstats). Implies --run_with_profiling.')
# NOTE(review): _run_main() also honors this flag when only --profile_file is
# set; the help text below mentions only --run_with_profiling.
flags.DEFINE_boolean('use_cprofile_for_profiling', True,
'Use cProfile instead of the profile module for '
'profiling. This has no effect unless '
'--run_with_profiling is set.')
# Checked right after flag parsing: when true the program exits with status 0
# without running main().
flags.DEFINE_boolean('only_check_args', False,
'Set to true to validate args and exit.',
allow_hide_cpp=True)
# Handlers registered through install_exception_handler(). run() hands any
# Exception escaping main() to each of them before re-raising it.
EXCEPTION_HANDLERS = list()
class Error(Exception):
  """Base class for the exceptions defined in this module."""
class UsageError(Error):
  """Signals that the user's arguments are invalid for this application.

  Raise it for application-level argument problems. Examples are two mutually
  exclusive flags given together, or too few non-flag arguments. Low-level
  parsing and validation of individual flags is reported through flags.Error
  instead.

  Attributes:
    exitcode: int, the status that app.run() passes to usage() (and hence to
      sys.exit) after printing usage information.
  """

  def __init__(self, message, exitcode=1):
    super().__init__(message)
    self.exitcode = exitcode
class HelpFlag(flags.BooleanFlag):
  """Boolean flag that prints short usage to stdout and exits when set."""

  NAME = 'help'
  SHORT_NAME = '?'

  def __init__(self):
    super().__init__(
        self.NAME,
        False,
        'show this help',
        short_name=self.SHORT_NAME,
        allow_hide_cpp=True,
    )

  def parse(self, arg):
    if not self._parse(arg):
      return
    usage(shorthelp=True, writeto_stdout=True)
    # usage() wrote to stdout, so point at --helpfull on stdout as well.
    print()
    print('Try --helpfull to get a list of all flags.')
    sys.exit(1)
class HelpshortFlag(HelpFlag):
  """Same behavior as --help, registered as --helpshort with no short name."""

  SHORT_NAME = None
  NAME = 'helpshort'
class HelpfullFlag(flags.BooleanFlag):
  """--helpfull: prints help for main-module and dependent-module flags."""

  def __init__(self):
    super().__init__('helpfull', False, 'show full help', allow_hide_cpp=True)

  def parse(self, arg):
    if not self._parse(arg):
      return
    usage(writeto_stdout=True)
    sys.exit(1)
class HelpXMLFlag(flags.BooleanFlag):
  """--helpxml: writes the full flag help to stdout as XML and exits."""

  def __init__(self):
    super().__init__(
        'helpxml',
        False,
        'like --helpfull, but generates XML output',
        allow_hide_cpp=True,
    )

  def parse(self, arg):
    if not self._parse(arg):
      return
    flags.FLAGS.write_help_in_xml_format(sys.stdout)
    sys.exit(1)
def parse_flags_with_usage(args):
  """Parses ``args`` into FLAGS; on a flags error, reports it and exits(1).

  Args:
    args: [str], a non-empty list of the command line arguments including
      program name.

  Returns:
    [str], a non-empty list of remaining command line arguments after parsing
    flags, including program name.
  """
  try:
    remaining = FLAGS(args)
  except flags.Error as error:
    text = str(error)
    if '\n' not in text:
      header = 'FATAL Flags parsing error: %s\n' % text
    else:
      # Multi-line errors go on their own indented lines under the header.
      header = 'FATAL Flags parsing error:\n%s\n' % textwrap.indent(text, ' ')
    sys.stderr.write(header)
    sys.stderr.write('Pass --helpshort or --helpfull to see help on flags.\n')
    sys.exit(1)
  return remaining
_define_help_flags_called = False


def define_help_flags():
  """Registers the --help, --helpshort, --helpfull and --helpxml flags.

  Safe to call repeatedly: a module-level sentinel makes every call after the
  first a no-op.
  """
  global _define_help_flags_called
  if _define_help_flags_called:
    return
  for help_flag_class in (HelpFlag, HelpshortFlag, HelpfullFlag, HelpXMLFlag):
    flags.DEFINE_flag(help_flag_class())
  _define_help_flags_called = True
def _register_and_parse_flags_with_usage(
    argv=None,
    flags_parser=parse_flags_with_usage,
):
  """Defines the help flags, parses ``argv`` and applies post-parse defaults.

  Exits with status 0 right after parsing when --only_check_args is set.

  Args:
    argv: [str], a non-empty list of the command line arguments including
      program name; sys.argv is used if None.
    flags_parser: Callable[[List[Text]], Any], the function used to parse flags.
      Its return value is handed to `main` untouched, and FLAGS must be parsed
      once it returns.

  Returns:
    Whatever `flags_parser` returned. For the default parser, that is
    [str], a non-empty list of remaining command line arguments after parsing
    flags, including program name.

  Raises:
    Error: `flags_parser` returned without leaving FLAGS parsed.
    SystemError: a previous call already completed successfully.
  """
  if _register_and_parse_flags_with_usage.done:
    raise SystemError('Flag registration can be done only once.')

  define_help_flags()

  if argv is None:
    argv = sys.argv
  parsed = flags_parser(argv)
  if not FLAGS.is_parsed():
    raise Error('FLAGS must be parsed after flags_parser is called.')

  # --only_check_args: arguments validated, nothing left to do.
  if FLAGS.only_check_args:
    sys.exit(0)

  # Raise the default verbosity to INFO unless --verbosity was given explicitly.
  if FLAGS['verbosity'].using_default_value:
    FLAGS.verbosity = 0

  _register_and_parse_flags_with_usage.done = True
  return parsed


_register_and_parse_flags_with_usage.done = False
def _run_main(main, argv):
  """Calls ``main(argv)`` and passes its result to ``sys.exit``.

  With --run_with_pdb, ``main`` runs under pdb. With --run_with_profiling or
  --profile_file, it runs under a profiler. Otherwise it is called directly.
  """
  if FLAGS.run_with_pdb:
    sys.exit(pdb.runcall(main, argv))
  elif FLAGS.run_with_profiling or FLAGS.profile_file:
    sys.exit(_run_main_with_profiler(main, argv))
  else:
    sys.exit(main(argv))


def _run_main_with_profiler(main, argv):
  """Runs ``main(argv)`` under a profiler and returns ``main``'s result.

  The report is produced at interpreter exit. It is dumped to --profile_file
  when that flag is set, and printed otherwise.
  """
  # Imported lazily: most apps (including performance-sensitive ones) never
  # profile, so they should not pay for these imports.
  # pylint: disable=g-import-not-at-top
  import atexit
  if FLAGS.use_cprofile_for_profiling:
    import cProfile as profile
  else:
    import profile
  profiler = profile.Profile()
  if FLAGS.profile_file:
    atexit.register(profiler.dump_stats, FLAGS.profile_file)
  else:
    atexit.register(profiler.print_stats)
  return profiler.runcall(main, argv)
def _call_exception_handlers(exception):
  """Offers ``exception`` to every installed handler that wants it.

  A failure inside a handler is logged (best effort) and never propagated,
  so one broken handler cannot prevent the others from running.
  """
  for handler in EXCEPTION_HANDLERS:
    try:
      wanted = handler.wants(exception)
      if wanted:
        handler.handle(exception)
    except:  # pylint: disable=bare-except
      try:
        # Surface the handler's own failure instead of hiding it.
        logging.error(traceback.format_exc())
      except:  # pylint: disable=bare-except
        pass  # Even logging failed; there is nothing more to do.
def run(
    main,
    argv=None,
    flags_parser=parse_flags_with_usage,
):
  """Begins executing the program.

  The overall flow is:

  - Parses command line flags with the flag module.
  - If there are any errors, prints usage().
  - Calls main() with the remaining arguments.
  - If main() raises a UsageError, prints usage and the error message.

  Args:
    main: The main function to execute. It takes a single argument "argv",
      which is a list of command line arguments with parsed flags removed.
      The return value is passed to `sys.exit`, and so for example
      a return value of 0 or None results in a successful termination, whereas
      a return value of 1 results in abnormal termination.
      For more details, see https://docs.python.org/3/library/sys.html#sys.exit
    argv: A non-empty list of the command line arguments including program name,
      sys.argv is used if None.
    flags_parser: Callable[[List[Text]], Any], the function used to parse flags.
      The return value of this function is passed to `main` untouched.
      It must guarantee FLAGS is parsed after this function is called.
      Should be passed as a keyword-only arg which will become mandatory in a
      future release.
  """
  try:
    args = _run_init(
        sys.argv if argv is None else argv,
        flags_parser,
    )
    # Drain callbacks queued by call_after_init() before main() starts.
    while _init_callbacks:
      callback = _init_callbacks.popleft()
      callback()
    try:
      _run_main(main, args)
    except UsageError as error:
      usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)
    except:  # pylint: disable=bare-except
      exc = sys.exc_info()[1]
      # Don't try to post-mortem debug successful SystemExits, since those
      # mean there wasn't actually an error. In particular, the test framework
      # raises SystemExit(False) even if all tests passed.
      if isinstance(exc, SystemExit) and not exc.code:
        raise
      # Check the tty so that we don't hang waiting for input in an
      # non-interactive scenario.
      if FLAGS.pdb_post_mortem and sys.stdout.isatty():
        traceback.print_exc()
        print()
        print(' *** Entering post-mortem debugging ***')
        print()
        pdb.post_mortem()
      raise
  except Exception as e:
    # SystemExit and KeyboardInterrupt are not Exception subclasses, so they
    # bypass the installed exception handlers.
    _call_exception_handlers(e)
    raise
# Callbacks which have been deferred until after _run_init has been called.
_init_callbacks = collections.deque()
def call_after_init(callback):
  """Runs ``callback`` once absl initialization has completed.

  If ``app.run`` has already finished initializing, ``callback`` is invoked
  right away in the calling thread. Otherwise it is queued and executed by
  ``app.run`` just before the main function is called, in the same thread as
  ``app.run``. Queued callbacks run sequentially; their relative order is not
  guaranteed.

  Args:
    callback: a callable taking no arguments; its return value is ignored.
  """
  if not _run_init.done:
    _init_callbacks.append(callback)
    return
  callback()
def _run_init(
    argv,
    flags_parser,
):
  """Initializes absl once; every later call only re-parses flags.

  Args:
    argv: list of command line arguments, including the program name.
    flags_parser: callable used to parse ``argv``.

  Returns:
    Whatever ``flags_parser`` returned for ``argv``.
  """
  if _run_init.done:
    return flags_parser(argv)

  command_name.make_process_name_useful()
  # Install the absl logging handler.
  logging.use_absl_handler()
  parsed_args = _register_and_parse_flags_with_usage(
      argv=argv,
      flags_parser=flags_parser,
  )
  if faulthandler is not None:
    try:
      faulthandler.enable()
    except Exception:  # pylint: disable=broad-except
      # Best effort: a disabled faulthandler is low impact. Some tests compare
      # stderr exactly, so nothing is printed here.
      pass
  _run_init.done = True
  return parsed_args


_run_init.done = False
def usage(shorthelp=False, writeto_stdout=False, detailed_error=None,
          exitcode=None):
  """Prints usage text built from ``__main__``'s docstring, plus flag help.

  Args:
    shorthelp: bool, if True, list only the flags of the main module instead
      of every flag.
    writeto_stdout: bool, if True, write to stdout instead of stderr.
    detailed_error: extra detail (formatted with ``%s``) explaining why usage
      is being shown.
    exitcode: optional integer; when not None, ``sys.exit(exitcode)`` is called
      after the help text has been written.
  """
  stdfile = sys.stdout if writeto_stdout else sys.stderr

  doc = sys.modules['__main__'].__doc__
  if doc:
    # Substitute sys.argv[0] for each '%s'; '%%' collapses to '%'.
    specifier_count = doc.count('%') - 2 * doc.count('%%')
    try:
      doc %= (sys.argv[0],) * specifier_count
    except (OverflowError, TypeError, ValueError):
      pass  # Not a usable format string; show the docstring unchanged.
  else:
    doc = flags.text_wrap('\nUSAGE: %s [flags]\n' % sys.argv[0],
                          indent=' ', firstline_indent='')

  flag_str = FLAGS.main_module_help() if shorthelp else FLAGS.get_help()

  try:
    stdfile.write(doc)
    if flag_str:
      stdfile.write('\nflags:\n')
      stdfile.write(flag_str)
    stdfile.write('\n')
    if detailed_error is not None:
      stdfile.write('\n%s\n' % detailed_error)
  except OSError as e:
    # A broken pipe is routine for "foo.par --help | less" when the reader
    # quits early, so don't dump a huge traceback for it.
    if e.errno != errno.EPIPE:
      raise

  if exitcode is not None:
    sys.exit(exitcode)
class ExceptionHandler:
  """Base class for handlers registered with install_exception_handler()."""

  def wants(self, exc):
    """Reports whether this handler wants to handle ``exc``.

    The base implementation accepts every exception; override it in a
    subclass to be more selective.

    Args:
      exc: Exception, the current exception.

    Returns:
      bool, always True here.
    """
    del exc  # Unused.
    return True

  def handle(self, exc):
    """Handles ``exc``. Subclasses must override this method.

    Args:
      exc: Exception, the current exception.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError()
def install_exception_handler(handler):
  """Registers ``handler`` to be consulted when main() fails abnormally.

  Installed handlers are called when main() exits via an abnormal exception,
  i.e. not one of SystemExit, KeyboardInterrupt, FlagsError or UsageError.

  Args:
    handler: ExceptionHandler, the exception handler to install.

  Raises:
    TypeError: ``handler`` is not an ExceptionHandler instance.
  """
  if isinstance(handler, ExceptionHandler):
    EXCEPTION_HANDLERS.append(handler)
    return
  raise TypeError('handler of type %s does not inherit from ExceptionHandler'
                  % type(handler))
abseil-py-2.1.0/absl/app.pyi 0000664 0000000 0000000 00000003311 14551576331 0015633 0 ustar 00root root 0000000 0000000
from typing import Any, Callable, Collection, Iterable, List, NoReturn, Optional, Text, TypeVar, Union, overload
from absl.flags import _flag
# Type of the value returned by `flags_parser`, which `run` hands to `main`.
_MainArgs = TypeVar('_MainArgs')
# Exception type accepted by the ExceptionHandler methods.
_Exc = TypeVar('_Exc', bound=Exception)
class ExceptionHandler:

  def wants(self, exc: _Exc) -> bool: ...

  def handle(self, exc: _Exc): ...
# Handlers registered via install_exception_handler().
EXCEPTION_HANDLERS: List[ExceptionHandler] = ...
class HelpFlag(_flag.BooleanFlag):
  # An argument-less __init__ needs an explicit `-> None` to be type-checked.
  def __init__(self) -> None:
    ...
class HelpshortFlag(HelpFlag): ...
class HelpfullFlag(_flag.BooleanFlag):
  # An argument-less __init__ needs an explicit `-> None` to be type-checked.
  def __init__(self) -> None:
    ...
class HelpXMLFlag(_flag.BooleanFlag):
  # An argument-less __init__ needs an explicit `-> None` to be type-checked.
  def __init__(self) -> None:
    ...
def define_help_flags() -> None: ...
# Returns normally only when no exit code is requested.
@overload
def usage(
    shorthelp: Union[bool, int] = ...,
    writeto_stdout: Union[bool, int] = ...,
    detailed_error: Optional[Any] = ...,
    exitcode: None = ...,
) -> None: ...


# With an integer exit code, usage() ends in sys.exit and never returns.
@overload
def usage(
    shorthelp: Union[bool, int] = ...,
    writeto_stdout: Union[bool, int] = ...,
    detailed_error: Optional[Any] = ...,
    exitcode: int = ...,
) -> NoReturn: ...
def install_exception_handler(handler: ExceptionHandler) -> None: ...
class Error(Exception): ...
class UsageError(Error):
  # Exit status app.run() uses after printing usage for this error.
  exitcode: int

  # Mirrors the runtime signature so `UsageError(msg, exitcode=2)` type-checks.
  def __init__(self, message: object, exitcode: int = ...) -> None:
    ...
def parse_flags_with_usage(args: List[Text]) -> List[Text]: ...
def call_after_init(callback: Callable[[], Any]) -> None: ...
# Without the flag_parser argument, `main` should require a List[Text].
# (A bare `*` with no keyword-only parameter after it is a SyntaxError, so this
# overload simply omits it.)
@overload
def run(
    main: Callable[[List[Text]], Any],
    argv: Optional[List[Text]] = ...,
) -> NoReturn:
  ...


# With a custom `flags_parser` (keyword-only), `main` receives its result.
@overload
def run(
    main: Callable[[_MainArgs], Any],
    argv: Optional[List[Text]] = ...,
    *,
    flags_parser: Callable[[List[Text]], _MainArgs],
) -> NoReturn:
  ...
abseil-py-2.1.0/absl/command_name.py 0000664 0000000 0000000 00000004375 14551576331 0017333 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A tiny stand alone library to change the kernel process name on Linux."""
import os
import sys
# This library must be kept small and stand alone. It is used by small things
# that require no extension modules.
def make_process_name_useful():
  """Sets the kernel process name to the basename of ``sys.argv[0]``.

  Best effort: does nothing when ``sys.argv`` is missing or empty (e.g. in
  some embedded interpreters), or when the basename is empty. This avoids an
  IndexError and never installs a blank process name.
  """
  argv = getattr(sys, 'argv', None)
  if not argv:
    return
  name = os.path.basename(argv[0])
  if not name:
    return
  set_kernel_process_name(name)
def set_kernel_process_name(name):
  """Changes the Linux kernel's command name for this process (best effort).

  This is the short name stored in the kernel's process table, which appears
  in the kernel log when a process is OOM killed. It is NOT what ps or top
  show. Only the first 15 bytes are used, and non-ASCII characters become
  '?'. Nothing happens if /proc/self/comm cannot be written and prctl()
  is unavailable.

  Args:
    name: bytes|unicode, the Linux kernel's command name to set.
  """
  if isinstance(name, bytes):
    encoded = name
  else:
    encoded = name.encode('ascii', 'replace')
  try:
    # Writing /proc/self/comm is preferred over calling prctl() via ctypes.
    with open('/proc/self/comm', 'wb') as comm_file:
      comm_file.write(encoded[:15])
  except OSError:
    # Fall back to prctl(PR_SET_NAME); give up quietly if it is unavailable.
    try:
      import ctypes  # pylint: disable=g-import-not-at-top
    except ImportError:
      return  # No ctypes.
    try:
      libc = ctypes.CDLL('libc.so.6')
    except OSError:
      return  # No libc.so.6.
    pr_set_name = ctypes.c_ulong(15)  # linux/prctl.h PR_SET_NAME value.
    zero = ctypes.c_ulong(0)
    try:
      # The return value is ignored: nothing useful can be done on failure.
      libc.prctl(pr_set_name, encoded, zero, zero, zero)
    except AttributeError:
      return  # No prctl.
abseil-py-2.1.0/absl/flags/ 0000775 0000000 0000000 00000000000 14551576331 0015426 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/absl/flags/BUILD 0000664 0000000 0000000 00000014562 14551576331 0016220 0 ustar 00root root 0000000 0000000 load("@rules_python//python:py_library.bzl", "py_library")
load("@rules_python//python:py_test.bzl", "py_test")
load("@rules_python//python:py_binary.bzl", "py_binary")
# Targets are private by default; only `flags` and `argparse_flags` are public.
package(default_visibility = ["//visibility:private"])
licenses(["notice"])
# Public package absl.flags; __init__.py re-exports the private modules below.
py_library(
name = "flags",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":_argument_parser",
":_defines",
":_exceptions",
":_flag",
":_flagvalues",
":_helpers",
":_validators",
],
)
# Public integration of absl flags with argparse.
py_library(
name = "argparse_flags",
srcs = ["argparse_flags.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [":flags"],
)
# Private implementation modules. _helpers has no deps; the rest layer on it.
py_library(
name = "_argument_parser",
srcs = ["_argument_parser.py"],
srcs_version = "PY2AND3",
deps = [
":_helpers",
],
)
py_library(
name = "_defines",
srcs = ["_defines.py"],
srcs_version = "PY2AND3",
deps = [
":_argument_parser",
":_exceptions",
":_flag",
":_flagvalues",
":_helpers",
":_validators",
],
)
py_library(
name = "_exceptions",
srcs = ["_exceptions.py"],
srcs_version = "PY2AND3",
deps = [
":_helpers",
],
)
py_library(
name = "_flag",
srcs = ["_flag.py"],
srcs_version = "PY2AND3",
deps = [
":_argument_parser",
":_exceptions",
":_helpers",
],
)
# _flagvalues depends only on _validators_classes (not _validators), since
# _validators itself depends on _flagvalues.
py_library(
name = "_flagvalues",
srcs = ["_flagvalues.py"],
srcs_version = "PY2AND3",
deps = [
":_exceptions",
":_flag",
":_helpers",
":_validators_classes",
],
)
py_library(
name = "_helpers",
srcs = ["_helpers.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "_validators",
srcs = [
"_validators.py",
],
srcs_version = "PY2AND3",
deps = [
":_exceptions",
":_flagvalues",
":_validators_classes",
],
)
py_library(
name = "_validators_classes",
srcs = [
"_validators_classes.py",
],
srcs_version = "PY2AND3",
deps = [
":_exceptions",
],
)
# Tests (Python 3 only).
py_test(
name = "tests/_argument_parser_test",
srcs = ["tests/_argument_parser_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":_argument_parser",
"//absl/testing:absltest",
"//absl/testing:parameterized",
],
)
py_test(
name = "tests/_flag_test",
srcs = ["tests/_flag_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":_argument_parser",
":_exceptions",
":_flag",
"//absl/testing:absltest",
"//absl/testing:parameterized",
],
)
py_test(
name = "tests/_flagvalues_test",
size = "small",
srcs = ["tests/_flagvalues_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":_defines",
":_exceptions",
":_flagvalues",
":_helpers",
":_validators",
":tests/module_foo",
"//absl/logging",
"//absl/testing:absltest",
"//absl/testing:parameterized",
],
)
py_test(
name = "tests/_helpers_test",
size = "small",
srcs = ["tests/_helpers_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":_helpers",
":tests/module_bar",
":tests/module_foo",
"//absl/testing:absltest",
],
)
py_test(
name = "tests/_validators_test",
size = "small",
srcs = ["tests/_validators_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":_defines",
":_exceptions",
":_flagvalues",
":_validators",
"//absl/testing:absltest",
],
)
# Uses the helper binary below as runtime data.
py_test(
name = "tests/argparse_flags_test",
size = "small",
srcs = ["tests/argparse_flags_test.py"],
data = [":tests/argparse_flags_test_helper"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":argparse_flags",
":flags",
"//absl/logging",
"//absl/testing:_bazelize_command",
"//absl/testing:absltest",
"//absl/testing:parameterized",
],
)
py_binary(
name = "tests/argparse_flags_test_helper",
testonly = 1,
srcs = ["tests/argparse_flags_test_helper.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":argparse_flags",
":flags",
"//absl:app",
],
)
py_test(
name = "tests/flags_formatting_test",
size = "small",
srcs = ["tests/flags_formatting_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":_helpers",
":flags",
"//absl/testing:absltest",
],
)
py_test(
name = "tests/flags_helpxml_test",
size = "small",
srcs = ["tests/flags_helpxml_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":_helpers",
":flags",
":tests/module_bar",
"//absl/testing:absltest",
],
)
py_test(
name = "tests/flags_numeric_bounds_test",
size = "small",
srcs = ["tests/flags_numeric_bounds_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":_validators",
":flags",
"//absl/testing:absltest",
],
)
py_test(
name = "tests/flags_test",
size = "small",
srcs = ["tests/flags_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":_exceptions",
":_helpers",
":flags",
":tests/module_bar",
":tests/module_baz",
":tests/module_foo",
"//absl/testing:absltest",
],
)
py_test(
name = "tests/flags_unicode_literals_test",
size = "small",
srcs = ["tests/flags_unicode_literals_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":flags",
"//absl/testing:absltest",
],
)
# Test-only modules that define flags used by the tests above.
py_library(
name = "tests/module_bar",
testonly = 1,
srcs = ["tests/module_bar.py"],
srcs_version = "PY2AND3",
deps = [
":_helpers",
":flags",
],
)
py_library(
name = "tests/module_baz",
testonly = 1,
srcs = ["tests/module_baz.py"],
srcs_version = "PY2AND3",
deps = [":flags"],
)
py_library(
name = "tests/module_foo",
testonly = 1,
srcs = ["tests/module_foo.py"],
srcs_version = "PY2AND3",
deps = [
":_helpers",
":flags",
":tests/module_bar",
],
)
abseil-py-2.1.0/absl/flags/__init__.py 0000664 0000000 0000000 00000017060 14551576331 0017543 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This package is used to define and parse command line flags.
This package defines a *distributed* flag-definition policy: rather than
an application having to define all flags in or near main(), each Python
module defines flags that are useful to it. When one Python module
imports another, it gains access to the other's flags. (This is
implemented by having all modules share a common, global registry object
containing all the flag information.)
Flags are defined through the use of one of the DEFINE_xxx functions.
The specific function used determines how the flag is parsed, checked,
and optionally type-converted, when it's seen on the command line.
"""
import getopt
import os
import re
import sys
import types
import warnings
from absl.flags import _argument_parser
from absl.flags import _defines
from absl.flags import _exceptions
from absl.flags import _flag
from absl.flags import _flagvalues
from absl.flags import _helpers
from absl.flags import _validators
# Public API of the absl.flags package. Every name listed here is bound at
# module level later in this file, mostly as an alias of an object defined in
# one of the private _argument_parser/_defines/_exceptions/_flag/_flagvalues/
# _helpers/_validators modules.
__all__ = (
    'DEFINE',
    'DEFINE_flag',
    'DEFINE_string',
    'DEFINE_boolean',
    'DEFINE_bool',
    'DEFINE_float',
    'DEFINE_integer',
    'DEFINE_enum',
    'DEFINE_enum_class',
    'DEFINE_list',
    'DEFINE_spaceseplist',
    'DEFINE_multi',
    'DEFINE_multi_string',
    'DEFINE_multi_integer',
    'DEFINE_multi_float',
    'DEFINE_multi_enum',
    'DEFINE_multi_enum_class',
    'DEFINE_alias',
    # Flag validators.
    'register_validator',
    'validator',
    'register_multi_flags_validator',
    'multi_flags_validator',
    'mark_flag_as_required',
    'mark_flags_as_required',
    'mark_flags_as_mutual_exclusive',
    'mark_bool_flags_as_mutual_exclusive',
    # Flag modifiers.
    'set_default',
    'override_value',
    # Key flag related functions.
    'declare_key_flag',
    'adopt_module_key_flags',
    'disclaim_key_flags',
    # Module exceptions.
    'Error',
    'CantOpenFlagFileError',
    'DuplicateFlagError',
    'IllegalFlagValueError',
    'UnrecognizedFlagError',
    'UnparsedFlagAccessError',
    'ValidationError',
    'FlagNameConflictsWithMethodError',
    # Public classes.
    'Flag',
    'BooleanFlag',
    'EnumFlag',
    'EnumClassFlag',
    'MultiFlag',
    'MultiEnumClassFlag',
    'FlagHolder',
    'FlagValues',
    'ArgumentParser',
    'BooleanParser',
    'EnumParser',
    'EnumClassParser',
    'ArgumentSerializer',
    'FloatParser',
    'IntegerParser',
    'BaseListParser',
    'ListParser',
    'ListSerializer',
    'EnumClassListSerializer',
    'CsvListSerializer',
    'WhitespaceSeparatedListParser',
    'EnumClassSerializer',
    # Helper functions.
    'get_help_width',
    'text_wrap',
    'flag_dict_to_args',
    'doc_to_help',
    # The global FlagValues instance.
    'FLAGS',
)
# Initialize the FLAGS_MODULE as early as possible.
# It's only used by adopt_module_key_flags to take SPECIAL_FLAGS into account.
_helpers.FLAGS_MODULE = sys.modules[__name__]
# Add current module to disclaimed module ids.
# NOTE(review): presumably this makes the calling-module lookups in _helpers
# skip frames belonging to this package module (_defines registers itself the
# same way) -- confirm against _helpers.
_helpers.disclaim_module_ids.add(id(sys.modules[__name__]))
# DEFINE functions, re-exported from _defines under their public names.
# pylint: disable=invalid-name
DEFINE = _defines.DEFINE
DEFINE_flag = _defines.DEFINE_flag
DEFINE_string = _defines.DEFINE_string
DEFINE_boolean = _defines.DEFINE_boolean
DEFINE_bool = DEFINE_boolean  # Match C++ API.
DEFINE_float = _defines.DEFINE_float
DEFINE_integer = _defines.DEFINE_integer
DEFINE_enum = _defines.DEFINE_enum
DEFINE_enum_class = _defines.DEFINE_enum_class
DEFINE_list = _defines.DEFINE_list
DEFINE_spaceseplist = _defines.DEFINE_spaceseplist
DEFINE_multi = _defines.DEFINE_multi
DEFINE_multi_string = _defines.DEFINE_multi_string
DEFINE_multi_integer = _defines.DEFINE_multi_integer
DEFINE_multi_float = _defines.DEFINE_multi_float
DEFINE_multi_enum = _defines.DEFINE_multi_enum
DEFINE_multi_enum_class = _defines.DEFINE_multi_enum_class
DEFINE_alias = _defines.DEFINE_alias
# pylint: enable=invalid-name
# Flag validators, re-exported from _validators.
register_validator = _validators.register_validator
validator = _validators.validator
register_multi_flags_validator = _validators.register_multi_flags_validator
multi_flags_validator = _validators.multi_flags_validator
mark_flag_as_required = _validators.mark_flag_as_required
mark_flags_as_required = _validators.mark_flags_as_required
mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive
mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive
# Flag modifiers.
set_default = _defines.set_default
override_value = _defines.override_value
# Key flag related functions.
declare_key_flag = _defines.declare_key_flag
adopt_module_key_flags = _defines.adopt_module_key_flags
disclaim_key_flags = _defines.disclaim_key_flags
# Module exceptions, re-exported from _exceptions.
# pylint: disable=invalid-name
Error = _exceptions.Error
CantOpenFlagFileError = _exceptions.CantOpenFlagFileError
DuplicateFlagError = _exceptions.DuplicateFlagError
IllegalFlagValueError = _exceptions.IllegalFlagValueError
UnrecognizedFlagError = _exceptions.UnrecognizedFlagError
UnparsedFlagAccessError = _exceptions.UnparsedFlagAccessError
ValidationError = _exceptions.ValidationError
FlagNameConflictsWithMethodError = _exceptions.FlagNameConflictsWithMethodError
# Public classes (flags, flag containers, parsers and serializers).
Flag = _flag.Flag
BooleanFlag = _flag.BooleanFlag
EnumFlag = _flag.EnumFlag
EnumClassFlag = _flag.EnumClassFlag
MultiFlag = _flag.MultiFlag
MultiEnumClassFlag = _flag.MultiEnumClassFlag
FlagHolder = _flagvalues.FlagHolder
FlagValues = _flagvalues.FlagValues
ArgumentParser = _argument_parser.ArgumentParser
BooleanParser = _argument_parser.BooleanParser
EnumParser = _argument_parser.EnumParser
EnumClassParser = _argument_parser.EnumClassParser
ArgumentSerializer = _argument_parser.ArgumentSerializer
FloatParser = _argument_parser.FloatParser
IntegerParser = _argument_parser.IntegerParser
BaseListParser = _argument_parser.BaseListParser
ListParser = _argument_parser.ListParser
ListSerializer = _argument_parser.ListSerializer
EnumClassListSerializer = _argument_parser.EnumClassListSerializer
CsvListSerializer = _argument_parser.CsvListSerializer
WhitespaceSeparatedListParser = _argument_parser.WhitespaceSeparatedListParser
EnumClassSerializer = _argument_parser.EnumClassSerializer
# pylint: enable=invalid-name
# Helper functions, re-exported from _helpers.
get_help_width = _helpers.get_help_width
text_wrap = _helpers.text_wrap
flag_dict_to_args = _helpers.flag_dict_to_args
doc_to_help = _helpers.doc_to_help
# Special flags.
_helpers.SPECIAL_FLAGS = FlagValues()
DEFINE_string(
'flagfile', '',
'Insert flag definitions from the given file into the command line.',
_helpers.SPECIAL_FLAGS) # pytype: disable=wrong-arg-types
DEFINE_string('undefok', '',
'comma-separated list of flag names that it is okay to specify '
'on the command line even if the program does not define a flag '
'with that name. IMPORTANT: flags in this list that have '
'arguments MUST use the --flag=value format.',
_helpers.SPECIAL_FLAGS) # pytype: disable=wrong-arg-types
#: The global FlagValues instance.
FLAGS = _flagvalues.FLAGS
abseil-py-2.1.0/absl/flags/_argument_parser.py 0000664 0000000 0000000 00000050706 14551576331 0021345 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains base classes used to parse and convert arguments.
Do NOT import this module directly. Import the flags package and use the
aliases defined at the package level instead.
"""
import collections
import csv
import enum
import io
import string
from typing import Generic, List, Iterable, Optional, Sequence, Text, Type, TypeVar, Union
from xml.dom import minidom
from absl.flags import _helpers
# Value type produced by an ArgumentParser / consumed by an ArgumentSerializer.
_T = TypeVar('_T')
# Enum member type used by the enum-class parser and serializers.
_ET = TypeVar('_ET', bound=enum.Enum)
# Numeric type used by NumericParser; constrained to exactly int or float.
_N = TypeVar('_N', int, float)
def _is_integer_type(instance):
"""Returns True if instance is an integer, and not a bool."""
return (isinstance(instance, int) and
not isinstance(instance, bool))
class _ArgumentParserCache(type):
"""Metaclass used to cache and share argument parsers among flags."""
_instances = {}
def __call__(cls, *args, **kwargs):
"""Returns an instance of the argument parser cls.
This method overrides behavior of the __new__ methods in
all subclasses of ArgumentParser (inclusive). If an instance
for cls with the same set of arguments exists, this instance is
returned, otherwise a new instance is created.
If any keyword arguments are defined, or the values in args
are not hashable, this method always returns a new instance of
cls.
Args:
*args: Positional initializer arguments.
**kwargs: Initializer keyword arguments.
Returns:
An instance of cls, shared or new.
"""
if kwargs:
return type.__call__(cls, *args, **kwargs)
else:
instances = cls._instances
key = (cls,) + tuple(args)
try:
return instances[key]
except KeyError:
# No cache entry for key exists, create a new one.
return instances.setdefault(key, type.__call__(cls, *args))
except TypeError:
# An object in args cannot be hashed, always return
# a new instance.
return type.__call__(cls, *args)
class ArgumentParser(Generic[_T], metaclass=_ArgumentParserCache):
  """Base class for turning command-line strings into native values.

  Subclasses override :meth:`parse` to check that a string argument is a
  legal value and convert it to a native type. When the value cannot be
  converted, they should raise ``ValueError`` with a human readable
  explanation of why it is illegal.

  Subclasses should also set ``syntactic_help``, a short string that may be
  shown to the user to describe the form of legal values.

  Instances are cached and shared between flags by the metaclass, so parsers
  must be stateless: initializer arguments are allowed, but every member
  variable must be derived from initializer arguments alone.
  """

  syntactic_help: Text = ''

  def parse(self, argument: Text) -> Optional[_T]:
    """Converts the string argument into its native value.

    The base implementation only checks that it received a string and hands
    it back unchanged.

    Args:
      argument: string argument passed in the commandline.

    Raises:
      ValueError: Raised when it fails to parse the argument.
      TypeError: Raised when the argument has the wrong type.

    Returns:
      The parsed value in native type.
    """
    if isinstance(argument, str):
      return argument
    raise TypeError('flag value must be a string, found "{}"'.format(
        type(argument)))

  def flag_type(self) -> Text:
    """Returns a string representing the type of the flag."""
    return 'string'

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns extra minidom.Element nodes describing this parser.

    The base parser contributes none.

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.
    """
    del doc  # Unused.
    return []
class ArgumentSerializer(Generic[_T]):
  """Base class for generating string representations of a flag value."""

  def serialize(self, value: _T) -> Text:
    """Returns ``str(value)``; subclasses override this for custom formats."""
    rendered = str(value)
    return rendered
class NumericParser(ArgumentParser[_N]):
  """Parser of numeric values with optional inclusive bounds.

  Concrete subclasses assign ``lower_bound`` and ``upper_bound`` (``None``
  meaning "unbounded on that side") and implement :meth:`convert`.
  """

  lower_bound: Optional[_N]
  upper_bound: Optional[_N]

  def is_outside_bounds(self, val: _N) -> bool:
    """Returns whether the value is outside the bounds or not.

    NOTE(review): every comparison against a float NaN is False, so NaN is
    never reported as out of bounds by this check.
    """
    too_low = self.lower_bound is not None and val < self.lower_bound
    if too_low:
      return too_low
    return self.upper_bound is not None and val > self.upper_bound

  def parse(self, argument: Text) -> _N:
    """See base class."""
    value = self.convert(argument)
    if not self.is_outside_bounds(value):
      return value
    raise ValueError('%s is not %s' % (value, self.syntactic_help))

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns a <lower_bound>/<upper_bound> element for each bound that is set."""
    bounds = (
        ('lower_bound', self.lower_bound),
        ('upper_bound', self.upper_bound),
    )
    return [
        _helpers.create_xml_dom_element(doc, tag, bound)
        for tag, bound in bounds
        if bound is not None
    ]

  def convert(self, argument: Text) -> _N:
    """Returns the correct numeric value of argument.

    Subclass must implement this method, and raise TypeError if argument is not
    string or has the right numeric type.

    Args:
      argument: string argument passed in the commandline, or the numeric type.

    Raises:
      TypeError: Raised when argument is not a string or the right numeric type.
      ValueError: Raised when failed to convert argument to the numeric value.
    """
    raise NotImplementedError
class FloatParser(NumericParser[float]):
  """Parser of floating point values.

  The parsed value may be confined to an inclusive range through
  ``lower_bound`` and ``upper_bound``; ``syntactic_help`` is rewritten to
  describe whatever range is configured.
  """

  number_article = 'a'
  number_name = 'number'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(
      self,
      lower_bound: Optional[float] = None,
      upper_bound: Optional[float] = None,
  ) -> None:
    super(FloatParser, self).__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    has_lower = lower_bound is not None
    has_upper = upper_bound is not None
    description = self.syntactic_help
    # Branch order matters: a two-sided range wins, then the special zero
    # bounds, then the generic one-sided forms.
    if has_lower and has_upper:
      description = '%s in the range [%s, %s]' % (
          description, lower_bound, upper_bound)
    elif lower_bound == 0:
      description = 'a non-negative %s' % self.number_name
    elif upper_bound == 0:
      description = 'a non-positive %s' % self.number_name
    elif has_upper:
      description = '%s <= %s' % (self.number_name, upper_bound)
    elif has_lower:
      description = '%s >= %s' % (self.number_name, lower_bound)
    self.syntactic_help = description

  def convert(self, argument: Union[int, float, str]) -> float:
    """Returns the float value of argument.

    Raises:
      TypeError: if argument is not a str, a float, or a non-bool int.
      ValueError: if a string argument is not a valid float literal.
    """
    accepted = (_is_integer_type(argument) or
                isinstance(argument, (float, str)))
    if not accepted:
      raise TypeError(
          'Expect argument to be a string, int, or float, found {}'.format(
              type(argument)))
    return float(argument)

  def flag_type(self) -> Text:
    """See base class."""
    return 'float'
class IntegerParser(NumericParser[int]):
  """Parser of an integer value.

  Parsed value may be bounded to a given upper and lower bound.

  String arguments are decimal by default. A ``0x``/``0X`` prefix selects
  hexadecimal and ``0o``/``0O`` selects octal. The prefix may be preceded by a
  sign and surrounded by whitespace, as ``int()`` allows (e.g. ``-0x1F``).
  """

  number_article = 'an'
  number_name = 'integer'
  syntactic_help = ' '.join((number_article, number_name))

  def __init__(
      self, lower_bound: Optional[int] = None, upper_bound: Optional[int] = None
  ) -> None:
    super(IntegerParser, self).__init__()
    self.lower_bound = lower_bound
    self.upper_bound = upper_bound
    sh = self.syntactic_help
    if lower_bound is not None and upper_bound is not None:
      sh = ('%s in the range [%s, %s]' % (sh, lower_bound, upper_bound))
    elif lower_bound == 1:
      sh = 'a positive %s' % self.number_name
    elif upper_bound == -1:
      sh = 'a negative %s' % self.number_name
    elif lower_bound == 0:
      sh = 'a non-negative %s' % self.number_name
    elif upper_bound == 0:
      sh = 'a non-positive %s' % self.number_name
    elif upper_bound is not None:
      sh = '%s <= %s' % (self.number_name, upper_bound)
    elif lower_bound is not None:
      sh = '%s >= %s' % (self.number_name, lower_bound)
    self.syntactic_help = sh

  def convert(self, argument: Union[int, Text]) -> int:
    """Returns the int value of argument.

    Args:
      argument: an int (bools are rejected), or a string holding a decimal,
        ``0x``/``0X`` hexadecimal or ``0o``/``0O`` octal integer literal,
        optionally signed.

    Raises:
      TypeError: if argument is neither a string nor an int.
      ValueError: if a string argument is not a valid integer literal.
    """
    if _is_integer_type(argument):
      return argument
    elif isinstance(argument, str):
      # Look past surrounding whitespace and a leading sign when choosing the
      # base. int() accepts both, plus either prefix case, once the base is
      # right. A prefixed string can never be a valid decimal, so strings that
      # already parsed in base 10 are unaffected.
      digits = argument.strip().lstrip('+-')
      base = 10
      if len(digits) > 2 and digits[0] == '0':
        if digits[1] in ('o', 'O'):
          base = 8
        elif digits[1] in ('x', 'X'):
          base = 16
      return int(argument, base)
    else:
      raise TypeError('Expect argument to be a string or int, found {}'.format(
          type(argument)))

  def flag_type(self) -> Text:
    """See base class."""
    return 'int'
class BooleanParser(ArgumentParser[bool]):
  """Parser of boolean values."""

  def parse(self, argument: Union[Text, int]) -> bool:
    """See base class.

    Strings match true/t/1 or false/f/0, ignoring case. Ints (including bools)
    are accepted only when equal to 0 or 1. Any other type, floats included,
    raises TypeError.
    """
    if isinstance(argument, str):
      lowered = argument.lower()
      if lowered in ('true', 't', '1'):
        return True
      if lowered in ('false', 'f', '0'):
        return False
      raise ValueError('Non-boolean argument to boolean flag', argument)
    if isinstance(argument, int):
      # bool is an int subclass, so True/False also land here. Only values
      # that compare equal to their own bool() (i.e. 0 and 1) are accepted.
      as_bool = bool(argument)
      if argument == as_bool:
        return as_bool
      raise ValueError('Non-boolean argument to boolean flag', argument)
    raise TypeError('Non-boolean argument to boolean flag', argument)

  def flag_type(self) -> Text:
    """See base class."""
    return 'bool'
class EnumParser(ArgumentParser[Text]):
  """Parser of a string enum value (a string value from a given set)."""

  def __init__(
      self, enum_values: Iterable[Text], case_sensitive: bool = True
  ) -> None:
    """Initializes EnumParser.

    Args:
      enum_values: [str], a non-empty list of string values in the enum.
      case_sensitive: bool, whether or not the enum is to be case-sensitive.

    Raises:
      ValueError: When enum_values is empty, or is a single str.
    """
    if not enum_values:
      raise ValueError(
          'enum_values cannot be empty, found "{}"'.format(enum_values))
    if isinstance(enum_values, str):
      raise ValueError(
          'enum_values cannot be a str, found "{}"'.format(enum_values)
      )
    super(EnumParser, self).__init__()
    self.enum_values = list(enum_values)
    self.case_sensitive = case_sensitive

  def parse(self, argument: Text) -> Text:
    """Checks argument against the enum and returns the matching element.

    Args:
      argument: str, the supplied flag value.

    Returns:
      The argument itself when case-sensitive; otherwise the first element of
      enum_values equal to it ignoring case.

    Raises:
      ValueError: Raised when argument didn't match anything in enum.
    """
    if self.case_sensitive:
      if argument in self.enum_values:
        return argument
    else:
      wanted = argument.upper()
      folded = [candidate.upper() for candidate in self.enum_values]
      if wanted in folded:
        return self.enum_values[folded.index(wanted)]
    raise ValueError('value should be one of <%s>' %
                     '|'.join(self.enum_values))

  def flag_type(self) -> Text:
    """See base class."""
    return 'string enum'
class EnumClassParser(ArgumentParser[_ET]):
  """Parser of an Enum class member.

  Accepts either a member of ``enum_class`` or the name of one of its members.
  Names are compared ignoring case when ``case_sensitive`` is False.
  """

  def __init__(
      self, enum_class: Type[_ET], case_sensitive: bool = True
  ) -> None:
    """Initializes EnumClassParser.

    Args:
      enum_class: class, the Enum class with all possible flag values.
      case_sensitive: bool, whether or not the enum is to be case-sensitive. If
        False, all member names must be unique when case is ignored.

    Raises:
      TypeError: When enum_class is not a subclass of Enum.
      ValueError: When enum_class is empty, or when case_sensitive is False and
        two member names differ only by case.
    """
    if not issubclass(enum_class, enum.Enum):
      raise TypeError('{} is not a subclass of Enum.'.format(enum_class))
    if not enum_class.__members__:
      raise ValueError('enum_class cannot be empty, but "{}" is empty.'
                       .format(enum_class))
    if not case_sensitive:
      members = collections.Counter(
          name.lower() for name in enum_class.__members__)
      duplicate_keys = {
          member for member, count in members.items() if count > 1
      }
      if duplicate_keys:
        raise ValueError(
            'Duplicate enum values for {} using case_sensitive=False'.format(
                duplicate_keys))

    super(EnumClassParser, self).__init__()
    self.enum_class = enum_class
    self._case_sensitive = case_sensitive
    if case_sensitive:
      self._member_names = tuple(enum_class.__members__)
    else:
      self._member_names = tuple(
          name.lower() for name in enum_class.__members__)

  @property
  def member_names(self) -> Sequence[Text]:
    """The accepted enum names, in lowercase if not case sensitive."""
    return self._member_names

  def parse(self, argument: Union[_ET, Text]) -> _ET:
    """Determines validity of argument and returns the correct element of enum.

    Args:
      argument: str or Enum class member, the supplied flag value.

    Returns:
      The first matching Enum class member in Enum class.

    Raises:
      ValueError: Raised when argument is neither a member of the enum class
        nor a str, or when it didn't match any member name.
    """
    if isinstance(argument, self.enum_class):
      return argument  # pytype: disable=bad-return-type
    elif not isinstance(argument, str):
      raise ValueError(
          '{} is not an enum member or a name of a member in {}'.format(
              argument, self.enum_class))
    # Validate the name (and get its canonical spelling from _member_names)
    # with the plain string enum parser.
    key = EnumParser(
        self._member_names, case_sensitive=self._case_sensitive).parse(argument)
    if self._case_sensitive:
      return self.enum_class[key]
    else:
      # If EnumParser.parse() return a value, we're guaranteed to find it
      # as a member of the class
      return next(value for name, value in self.enum_class.__members__.items()
                  if name.lower() == key.lower())

  def flag_type(self) -> Text:
    """See base class."""
    return 'enum class'
class ListSerializer(Generic[_T], ArgumentSerializer[List[_T]]):
  """Serializes a list by joining the ``str()`` of each item with list_sep."""

  def __init__(self, list_sep: Text) -> None:
    self.list_sep = list_sep

  def serialize(self, value: List[_T]) -> Text:
    """See base class."""
    return self.list_sep.join(map(str, value))
class EnumClassListSerializer(ListSerializer[_ET]):
  """A serializer for :class:`MultiEnumClass` flags.

  Each element is rendered by an `EnumClassSerializer` and the pieces are
  joined with the provided separator. A bare (non-list) value is rendered on
  its own.
  """

  def __init__(self, list_sep: Text, **kwargs) -> None:
    """Initializes EnumClassListSerializer.

    Args:
      list_sep: String to be used as a separator when serializing
      **kwargs: Keyword arguments to the `EnumClassSerializer` used to serialize
        individual values.
    """
    super(EnumClassListSerializer, self).__init__(list_sep)
    self._element_serializer = EnumClassSerializer(**kwargs)

  def serialize(self, value: Union[_ET, List[_ET]]) -> Text:
    """See base class."""
    if not isinstance(value, list):
      return self._element_serializer.serialize(value)
    pieces = [self._element_serializer.serialize(member) for member in value]
    return self.list_sep.join(pieces)
class CsvListSerializer(ListSerializer[Text]):
  """Serializes a list of strings as one CSV row delimited by list_sep."""

  def serialize(self, value: List[Text]) -> Text:
    """Serializes a list as a CSV string or unicode."""
    buffer = io.StringIO()
    csv.writer(buffer, delimiter=self.list_sep).writerow(
        [str(item) for item in value])
    # strip() removes the writer's trailing line terminator (along with any
    # other leading/trailing whitespace); str() keeps the result a plain str
    # so the XML help output stays encodable.
    return str(buffer.getvalue().strip())
class EnumClassSerializer(ArgumentSerializer[_ET]):
  """Serializes an enum class flag value as its member name."""

  def __init__(self, lowercase: bool) -> None:
    """Initializes EnumClassSerializer.

    Args:
      lowercase: If True, enum member names are lowercased during serialization.
    """
    self._lowercase = lowercase

  def serialize(self, value: _ET) -> Text:
    """Returns ``value.name``, lowercased when configured to do so."""
    member_name = str(value.name)
    if not self._lowercase:
      return member_name
    return member_name.lower()
class BaseListParser(ArgumentParser):
  """Base class for a parser of lists of strings.

  To extend, inherit from this class; from the subclass ``__init__``, call::

    super().__init__(token, name)

  where token is a character used to tokenize, and name is a description
  of the separator.
  """

  def __init__(
      self, token: Optional[Text] = None, name: Optional[Text] = None
  ) -> None:
    assert name
    super(BaseListParser, self).__init__()
    self._token = token
    self._name = name
    self.syntactic_help = 'a %s separated list' % self._name

  def parse(self, argument: Text) -> List[Text]:
    """See base class.

    A list is returned unchanged and a falsy argument yields ``[]``. A string
    is split on the token (``None`` splits on runs of whitespace), and each
    piece is stripped.
    """
    if isinstance(argument, list):
      return argument
    if not argument:
      return []
    pieces = argument.split(self._token)
    return [piece.strip() for piece in pieces]

  def flag_type(self) -> Text:
    """See base class."""
    return '%s separated list of strings' % self._name
class ListParser(BaseListParser):
  """Parser for a comma-separated list of strings.

  The value is read as a single CSV record, so a field may be double-quoted to
  contain commas: ``a,"b,c"`` parses to ``['a', 'b,c']``. Each field is
  stripped of surrounding whitespace.
  """

  def __init__(self) -> None:
    super(ListParser, self).__init__(',', 'comma')

  def parse(self, argument: Union[Text, List[Text]]) -> List[Text]:
    """Parses argument as comma-separated list of strings.

    Args:
      argument: a list (returned unchanged) or a comma-separated string.

    Returns:
      [str], the parsed flag value; ``[]`` for an empty/falsy argument.

    Raises:
      ValueError: if the string is not a valid CSV record, e.g. it contains a
        naked newline.
    """
    if isinstance(argument, list):
      return argument
    elif not argument:
      return []
    else:
      try:
        return [s.strip() for s in list(csv.reader([argument], strict=True))[0]]
      except csv.Error as e:
        # Provide a helpful report for case like
        #   --listflag="$(printf 'hello,\nworld')"
        # IOW, list flag values containing naked newlines. This error
        # was previously "reported" by allowing csv.Error to
        # propagate; chain it explicitly so it remains visible as the cause.
        raise ValueError('Unable to parse the value %r as a %s: %s'
                         % (argument, self.flag_type(), e)) from e

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    elements = super(ListParser, self)._custom_xml_dom_elements(doc)
    elements.append(_helpers.create_xml_dom_element(
        doc, 'list_separator', repr(',')))
    return elements
class WhitespaceSeparatedListParser(BaseListParser):
  """Parser for a whitespace-separated list of strings."""

  def __init__(self, comma_compat: bool = False) -> None:
    """Initializer.

    Args:
      comma_compat: bool, whether to support comma as an additional separator.
        If False then only whitespace is supported.  This is intended only for
        backwards compatibility with flags that used to be comma-separated.
    """
    self._comma_compat = comma_compat
    name = 'whitespace or comma' if self._comma_compat else 'whitespace'
    super(WhitespaceSeparatedListParser, self).__init__(None, name)

  def parse(self, argument: Union[Text, List[Text]]) -> List[Text]:
    """Parses argument as whitespace-separated list of strings.

    Commas are treated as extra separators when comma_compat was requested.

    Args:
      argument: string argument passed in the commandline.

    Returns:
      [str], the parsed flag value.
    """
    if isinstance(argument, list):
      return argument
    if not argument:
      return []
    text = argument.replace(',', ' ') if self._comma_compat else argument
    return text.split()

  def _custom_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Appends one sorted <list_separator> element per separator character."""
    elements = super(
        WhitespaceSeparatedListParser, self)._custom_xml_dom_elements(doc)
    extra = [','] if self._comma_compat else []
    separators = sorted(list(string.whitespace) + extra)
    elements.extend(
        _helpers.create_xml_dom_element(doc, 'list_separator', repr(sep))
        for sep in separators)
    return elements
abseil-py-2.1.0/absl/flags/_defines.py 0000664 0000000 0000000 00000147057 14551576331 0017572 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This modules contains flags DEFINE functions.
Do NOT import this module directly. Import the flags package and use the
aliases defined at the package level instead.
"""
import enum
import sys
import types
import typing
from typing import Text, List, Any, TypeVar, Optional, Union, Type, Iterable, overload
from absl.flags import _argument_parser
from absl.flags import _exceptions
from absl.flags import _flag
from absl.flags import _flagvalues
from absl.flags import _helpers
from absl.flags import _validators
# Register this module's id in _helpers.disclaim_module_ids. NOTE(review):
# presumably this makes calling-module detection (used by DEFINE_flag below
# via _helpers.get_calling_module_object_and_name) skip frames from this file,
# so flags are attributed to the user's module -- confirm in _helpers.
_helpers.disclaim_module_ids.add(id(sys.modules[__name__]))
# Value type of a flag / its parser.
_T = TypeVar('_T')
# Enum member type for enum-class flags.
_ET = TypeVar('_ET', bound=enum.Enum)
def _register_bounds_validator_if_needed(parser, name, flag_values):
"""Enforces lower and upper bounds for numeric flags.
Args:
parser: NumericParser (either FloatParser or IntegerParser), provides lower
and upper bounds, and help text to display.
name: str, name of the flag
flag_values: FlagValues.
"""
if parser.lower_bound is not None or parser.upper_bound is not None:
def checker(value):
if value is not None and parser.is_outside_bounds(value):
message = '%s is not %s' % (value, parser.syntactic_help)
raise _exceptions.ValidationError(message)
return True
_validators.register_validator(name, checker, flag_values=flag_values)
@overload
def DEFINE(  # pylint: disable=invalid-name
    parser: _argument_parser.ArgumentParser[_T],
    name: Text,
    default: Any,
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ...,
    module_name: Optional[Text] = ...,
    required: 'typing.Literal[True]' = ...,
    **args: Any
) -> _flagvalues.FlagHolder[_T]:
  ...


@overload
def DEFINE(  # pylint: disable=invalid-name
    parser: _argument_parser.ArgumentParser[_T],
    name: Text,
    default: Optional[Any],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    serializer: Optional[_argument_parser.ArgumentSerializer[_T]] = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[_T]]:
  ...


def DEFINE(  # pylint: disable=invalid-name
    parser,
    name,
    default,
    help,  # pylint: disable=redefined-builtin
    flag_values=_flagvalues.FLAGS,
    serializer=None,
    module_name=None,
    required=False,
    **args):
  """Registers a generic Flag object.

  NOTE: in the docstrings of all DEFINE* functions, "registers" is short
  for "creates a new flag and registers it".

  Auxiliary function: clients should use the specialized ``DEFINE_``
  function instead.

  Args:
    parser: :class:`ArgumentParser`, used to parse the flag arguments.
    name: str, the flag name.
    default: The default value of the flag.
    help: str, the help message.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    serializer: :class:`ArgumentSerializer`, the flag serializer instance.
    module_name: str, the name of the Python module declaring this flag. If not
      provided, it will be computed using the stack trace of this call.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: dict, the extra keyword args that are passed to ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  new_flag = _flag.Flag(parser, serializer, name, default, help, **args)
  # required is normalized to a real bool before being forwarded.
  return DEFINE_flag(
      new_flag, flag_values, module_name, required=bool(required))
@overload
def DEFINE_flag(  # pylint: disable=invalid-name
    flag: _flag.Flag[_T],
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: 'typing.Literal[True]' = ...,
) -> _flagvalues.FlagHolder[_T]:
  ...


@overload
def DEFINE_flag(  # pylint: disable=invalid-name
    flag: _flag.Flag[_T],
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
) -> _flagvalues.FlagHolder[Optional[_T]]:
  ...


def DEFINE_flag(  # pylint: disable=invalid-name
    flag,
    flag_values=_flagvalues.FLAGS,
    module_name=None,
    required=False):
  """Registers a :class:`Flag` object with a :class:`FlagValues` object.

  By default, the global :const:`FLAGS` ``FlagValue`` object is used.

  Typical users will use one of the more specialized DEFINE_xxx
  functions, such as :func:`DEFINE_string` or :func:`DEFINE_integer`.  But
  developers who need to create :class:`Flag` objects themselves should use
  this function to register their flags.

  Args:
    flag: :class:`Flag`, a flag that is key to the module.
    flag_values: :class:`FlagValues`, the ``FlagValues`` instance with which the
      flag will be registered. This should almost never need to be overridden.
    module_name: str, the name of the Python module declaring this flag. If not
      provided, it will be computed using the stack trace of this call.
    required: bool, is this a required flag. This must be used as a keyword
      argument.

  Returns:
    a handle to defined flag.

  Raises:
    ValueError: if required is truthy but the flag's default is not None.
  """
  if required and flag.default is not None:
    raise ValueError('Required flag --%s cannot have a non-None default' %
                     flag.name)
  flag_values[flag.name] = flag
  # Record which module defined the flag, both by module name and by the id of
  # the module object. Without an explicit name, the caller is found from the
  # stack.
  if not module_name:
    module, module_name = _helpers.get_calling_module_object_and_name()
  else:
    module = sys.modules.get(module_name)
  flag_values.register_flag_by_module(module_name, flag)
  flag_values.register_flag_by_module_id(id(module), flag)
  if required:
    _validators.mark_flag_as_required(flag.name, flag_values)
  return _flagvalues.FlagHolder(
      flag_values, flag,
      ensure_non_none_value=(flag.default is not None) or required)
def set_default(flag_holder: _flagvalues.FlagHolder[_T], value: _T) -> None:
  """Changes the default value of the provided flag object.

  If the flag is still using its default (not given on the command line and
  not assigned via FLAGS.name = value), its current value changes as well.

  Args:
    flag_holder: FlagHolder, the flag to modify.
    value: The new default value.

  Raises:
    IllegalFlagValueError: Raised when value is not valid.
  """
  registry = flag_holder._flagvalues  # pylint: disable=protected-access
  registry.set_default(flag_holder.name, value)
def override_value(flag_holder: _flagvalues.FlagHolder[_T], value: _T) -> None:
  """Overrides the value of the provided flag.

  The new value takes precedence over the default and, when this is called
  after flag parsing, over any value given on the command line.

  Args:
    flag_holder: FlagHolder, the flag to modify.
    value: The new value.

  Raises:
    IllegalFlagValueError: The value did not pass the flag parser or validators.
  """
  registry = flag_holder._flagvalues  # pylint: disable=protected-access
  flag_name = flag_holder.name
  # Run the value through the flag's parser via _parse() rather than parse(),
  # which avoids parse()'s side effects; it must come back unchanged.
  parsed = registry[flag_name]._parse(value)  # pylint: disable=protected-access
  if parsed != value:
    raise _exceptions.IllegalFlagValueError(
        'flag %s: parsed value %r not equal to original %r'
        % (flag_name, parsed, value)
    )
  setattr(registry, flag_name, value)
def _internal_declare_key_flags(
    flag_names: List[str],
    flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS,
    key_flag_values: Optional[_flagvalues.FlagValues] = None,
) -> None:
  """Declares the given flags as key flags for the calling module.

  Internal function. User code should call declare_key_flag or
  adopt_module_key_flags instead.

  Args:
    flag_names: [str], a list of names of already-registered Flag objects.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flags listed in flag_names have registered (the value of the flag_values
      argument from the ``DEFINE_*`` calls that defined those flags). This
      should almost never need to be overridden.
    key_flag_values: :class:`FlagValues`, the FlagValues instance that (among
      possibly many other things) keeps track of the key flags for each module.
      Default ``None`` means "same as flag_values". This should almost never
      need to be overridden.

  Raises:
    KeyError: (or whatever ``flag_values[name]`` raises for an unknown name —
      NOTE(review): confirm against FlagValues.__getitem__) Raised when a flag
      is not defined in flag_values; declare_key_flag relies on this being a
      KeyError.
  """
  # Compare against None explicitly rather than relying on truthiness:
  # FlagValues supports container operations (`in`, iteration), so an instance
  # with no flags registered may evaluate as falsy. It must still be honored
  # and never silently swapped for flag_values.
  if key_flag_values is None:
    key_flag_values = flag_values
  # Resolve the caller from this frame, so this must be invoked directly by
  # the public API function (declare_key_flag / adopt_module_key_flags).
  module = _helpers.get_calling_module()
  for flag_name in flag_names:
    key_flag_values.register_key_flag_for_module(module, flag_values[flag_name])
def declare_key_flag(
    flag_name: Union[Text, _flagvalues.FlagHolder],
    flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS,
) -> None:
  """Marks one already-defined flag as a key flag of the current module.

  Key flags are the flags considered most important for a module. They matter
  for help output: with --helpshort only the key flags of the main module are
  listed, whereas --helpfull lists every flag.

  Sample usage::

    flags.declare_key_flag('flag_1')

  Args:
    flag_name: str | :class:`FlagHolder`, the name or holder of an already
      declared flag. (Redeclaring flags as key, including flags implicitly key
      because they were declared in this module, is a no-op.)
      Positional-only parameter.
    flag_values: :class:`FlagValues`, the FlagValues instance in which the
      flag will be declared as a key flag. This should almost never need to be
      overridden.

  Raises:
    ValueError: Raised if flag_name not defined as a Python flag.
  """
  flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values)
  special_flags = _helpers.SPECIAL_FLAGS
  if flag_name not in special_flags:
    # Ordinary user-defined flag: look it up in flag_values itself.
    try:
      _internal_declare_key_flags([flag_name], flag_values=flag_values)
    except KeyError:
      raise ValueError('Flag --%s is undefined. To set a flag as a key flag '
                       'first define it in Python.' % flag_name)
    return
  # Special flags (e.g. --flagfile, --undefok) live in the private
  # SPECIAL_FLAGS registry and take precedence over user-defined flags during
  # parsing; fetch them from there but record them as key flags in
  # flag_values.
  _internal_declare_key_flags(
      [flag_name], flag_values=special_flags, key_flag_values=flag_values)
def adopt_module_key_flags(
    module: Any, flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS
) -> None:
  """Makes every key flag of ``module`` a key flag of the current module too.

  Args:
    module: module, the module object from which all key flags will be declared
      as key flags to the current module.
    flag_values: :class:`FlagValues`, the FlagValues instance in which the
      flags will be declared as key flags. This should almost never need to be
      overridden.

  Raises:
    Error: Raised when given an argument that is a module name (a string),
      instead of a module object.
  """
  if not isinstance(module, types.ModuleType):
    raise _exceptions.Error('Expected a module object, not %r.' % (module,))

  adopted_names = [
      key_flag.name
      for key_flag in flag_values.get_key_flags_for_module(module.__name__)
  ]
  _internal_declare_key_flags(adopted_names, flag_values=flag_values)

  if module == _helpers.FLAGS_MODULE:
    # Flags are attributed via get_calling_module_object_and_name(), so the
    # special flags defined by this flags module end up registered under a
    # different module and get_key_flags_for_module() cannot find them.
    # Take them all from the private SPECIAL_FLAGS FlagValues instead, where
    # no other module should register flags.
    special_flags = _helpers.SPECIAL_FLAGS
    special_names = [special_flags[key].name for key in special_flags]
    _internal_declare_key_flags(
        special_names,
        flag_values=special_flags,
        key_flag_values=flag_values)
def disclaim_key_flags() -> None:
  """Declares that the current module will not define any more key flags.

  Normally, the module that calls the DEFINE_xxx functions claims the
  flag to be its key flag. This is undesirable for modules that
  define additional DEFINE_yyy functions with their own flag parsers and
  serializers, since that module will accidentally claim flags defined
  by DEFINE_yyy as its key flags. After calling this function, the
  module disclaims flag definitions thereafter, so the key flags will
  be correctly attributed to the caller of DEFINE_yyy.

  After calling this function, the module will not be able to define
  any more flags. This function will affect all FlagValues objects.

  The calling module is identified by a fixed stack offset, so this must be
  called directly from the module being disclaimed, not through a wrapper.
  """
  # Frame 1 is whoever called disclaim_key_flags(); its globals identify the
  # module doing the disclaiming.
  globals_for_caller = sys._getframe(1).f_globals  # pylint: disable=protected-access
  module, _ = _helpers.get_module_object_and_name(globals_for_caller)
  # Record the module by id() in the shared set consulted by _helpers, which
  # is module-global state and therefore not tied to any one FlagValues.
  _helpers.disclaim_module_ids.add(id(module))
@overload
def DEFINE_string(  # pylint: disable=invalid-name
    name: Text,
    default: Optional[Text],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[Text]:
  ...


@overload
def DEFINE_string(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[Text]]:
  ...


@overload
def DEFINE_string(  # pylint: disable=invalid-name
    name: Text,
    default: Text,
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Text]:
  ...


def DEFINE_string(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    help,
    flag_values=_flagvalues.FLAGS,
    required=False,
    **args):
  """Registers a flag whose value can be any string.

  Args:
    name: str, the flag name.
    default: str|None, the default value of the flag.
    help: str, the help message.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: dict, the extra keyword args that are passed to :func:`DEFINE`.

  Returns:
    a handle to defined flag.
  """
  # A plain string flag uses the generic parser/serializer, specialised to str.
  return DEFINE(
      _argument_parser.ArgumentParser[str](),
      name,
      default,
      help,
      flag_values,
      _argument_parser.ArgumentSerializer[str](),
      required=bool(required),
      **args,
  )
@overload
def DEFINE_boolean(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, Text, bool, int],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[bool]:
  ...


@overload
def DEFINE_boolean(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[bool]]:
  ...


@overload
def DEFINE_boolean(  # pylint: disable=invalid-name
    name: Text,
    default: Union[Text, bool, int],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[bool]:
  ...


def DEFINE_boolean(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    help,
    flag_values=_flagvalues.FLAGS,
    module_name=None,
    required=False,
    **args):
  """Registers a boolean flag.

  A boolean flag takes no argument on the command line. To set it to false
  explicitly, use the long option prefixed with 'no', i.e. --noflag.

  The value is True, False, or None; None only occurs when default=None and
  the flag is not given on the command line.

  Args:
    name: str, the flag name.
    default: bool|str|int|None, the default value of the flag.
    help: str, the help message.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    module_name: str, the name of the Python module declaring this flag. If not
      provided, it will be computed using the stack trace of this call.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: dict, the extra keyword args that are passed to ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  bool_flag = _flag.BooleanFlag(name, default, help, **args)
  return DEFINE_flag(bool_flag, flag_values, module_name,
                     required=bool(required))
@overload
def DEFINE_float(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, float, Text],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    lower_bound: Optional[float] = ...,
    upper_bound: Optional[float] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[float]:
  ...


@overload
def DEFINE_float(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    help: Optional[Text],  # pylint: disable=redefined-builtin
    lower_bound: Optional[float] = ...,
    upper_bound: Optional[float] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[float]]:
  ...


@overload
def DEFINE_float(  # pylint: disable=invalid-name
    name: Text,
    default: Union[float, Text],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    lower_bound: Optional[float] = ...,
    upper_bound: Optional[float] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[float]:
  ...


def DEFINE_float(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    help,
    lower_bound=None,
    upper_bound=None,
    flag_values=_flagvalues.FLAGS,
    required=False,
    **args):
  """Registers a flag whose value must be a float.

  When ``lower_bound`` and/or ``upper_bound`` is given, the value must also lie
  within that range.

  Args:
    name: str, the flag name.
    default: float|str|None, the default value of the flag.
    help: str, the help message.
    lower_bound: float, min value of the flag.
    upper_bound: float, max value of the flag.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: dict, the extra keyword args that are passed to :func:`DEFINE`.

  Returns:
    a handle to defined flag.
  """
  float_parser = _argument_parser.FloatParser(lower_bound, upper_bound)
  holder = DEFINE(
      float_parser,
      name,
      default,
      help,
      flag_values,
      _argument_parser.ArgumentSerializer(),
      required=bool(required),
      **args,
  )
  # The bounds also need to be enforced as a validator on the registered flag.
  _register_bounds_validator_if_needed(
      float_parser, name, flag_values=flag_values)
  return holder
@overload
def DEFINE_integer(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, int, Text],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    lower_bound: Optional[int] = ...,
    upper_bound: Optional[int] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[int]:
  ...


@overload
def DEFINE_integer(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    help: Optional[Text],  # pylint: disable=redefined-builtin
    lower_bound: Optional[int] = ...,
    upper_bound: Optional[int] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[int]]:
  ...


@overload
def DEFINE_integer(  # pylint: disable=invalid-name
    name: Text,
    default: Union[int, Text],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    lower_bound: Optional[int] = ...,
    upper_bound: Optional[int] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[int]:
  ...


def DEFINE_integer(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    help,
    lower_bound=None,
    upper_bound=None,
    flag_values=_flagvalues.FLAGS,
    required=False,
    **args):
  """Registers a flag whose value must be an integer.

  When ``lower_bound`` and/or ``upper_bound`` is given, the value must also lie
  within that range.

  Args:
    name: str, the flag name.
    default: int|str|None, the default value of the flag.
    help: str, the help message.
    lower_bound: int, min value of the flag.
    upper_bound: int, max value of the flag.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: dict, the extra keyword args that are passed to :func:`DEFINE`.

  Returns:
    a handle to defined flag.
  """
  int_parser = _argument_parser.IntegerParser(lower_bound, upper_bound)
  holder = DEFINE(
      int_parser,
      name,
      default,
      help,
      flag_values,
      _argument_parser.ArgumentSerializer(),
      required=bool(required),
      **args,
  )
  # The bounds also need to be enforced as a validator on the registered flag.
  _register_bounds_validator_if_needed(
      int_parser, name, flag_values=flag_values)
  return holder
@overload
def DEFINE_enum(  # pylint: disable=invalid-name
    name: Text,
    default: Optional[Text],
    enum_values: Iterable[Text],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[Text]:
  ...


@overload
def DEFINE_enum(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    enum_values: Iterable[Text],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[Text]]:
  ...


@overload
def DEFINE_enum(  # pylint: disable=invalid-name
    name: Text,
    default: Text,
    enum_values: Iterable[Text],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Text]:
  ...


def DEFINE_enum(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    enum_values,
    help,
    flag_values=_flagvalues.FLAGS,
    module_name=None,
    required=False,
    **args):
  """Registers a flag whose value must be one of the strings in enum_values.

  For new code, prefer `DEFINE_enum_class`, which builds the flag from an
  `enum.Enum` class instead of a list of strings.

  Args:
    name: str, the flag name.
    default: str|None, the default value of the flag.
    enum_values: [str], a non-empty list of strings with the possible values for
      the flag.
    help: str, the help message.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    module_name: str, the name of the Python module declaring this flag. If not
      provided, it will be computed using the stack trace of this call.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: dict, the extra keyword args that are passed to ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  enum_flag = _flag.EnumFlag(name, default, help, enum_values, **args)
  # Bind to a name before returning, mirroring DEFINE_enum_class (whose NOTE
  # says pytype fails on a direct return).
  holder = DEFINE_flag(enum_flag, flag_values, module_name,
                       required=bool(required))
  return holder
@overload
def DEFINE_enum_class(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, _ET, Text],
    enum_class: Type[_ET],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    case_sensitive: bool = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[_ET]:
  ...


@overload
def DEFINE_enum_class(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    enum_class: Type[_ET],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    case_sensitive: bool = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[_ET]]:
  ...


@overload
def DEFINE_enum_class(  # pylint: disable=invalid-name
    name: Text,
    default: Union[_ET, Text],
    enum_class: Type[_ET],
    help: Optional[Text],  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    case_sensitive: bool = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[_ET]:
  ...


def DEFINE_enum_class(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    enum_class,
    help,
    flag_values=_flagvalues.FLAGS,
    module_name=None,
    case_sensitive=False,
    required=False,
    **args):
  """Registers a flag whose value can be the name of enum members.

  Args:
    name: str, the flag name.
    default: Enum|str|None, the default value of the flag.
    enum_class: class, the Enum class with all the possible values for the flag.
    help: str, the help message.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    module_name: str, the name of the Python module declaring this flag. If not
      provided, it will be computed using the stack trace of this call.
    case_sensitive: bool, whether strings are matched to the member names of
      enum_class case-sensitively. Defaults to False, i.e. case is not
      considered when matching.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: dict, the extra keyword args that are passed to ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  # NOTE: pytype fails if this is a direct return.
  result = DEFINE_flag(
      _flag.EnumClassFlag(
          name, default, help, enum_class, case_sensitive=case_sensitive, **args
      ),
      flag_values,
      module_name,
      required=True if required else False,
  )
  return result
@overload
def DEFINE_list(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, Iterable[Text], Text],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[Text]]:
  ...


@overload
def DEFINE_list(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
  ...


@overload
def DEFINE_list(  # pylint: disable=invalid-name
    name: Text,
    default: Union[Iterable[Text], Text],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[Text]]:
  ...


def DEFINE_list(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    help,
    flag_values=_flagvalues.FLAGS,
    required=False,
    **args):
  """Registers a flag whose value is a comma-separated list of strings.

  Values are parsed with a CSV parser (ListParser) and serialized back with a
  comma-delimited CSV serializer.

  Args:
    name: str, the flag name.
    default: list|str|None, the default value of the flag.
    help: str, the help message.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: Dictionary with extra keyword args that are passed to the
      ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  return DEFINE(
      _argument_parser.ListParser(),
      name,
      default,
      help,
      flag_values,
      _argument_parser.CsvListSerializer(','),
      required=bool(required),
      **args,
  )
@overload
def DEFINE_spaceseplist(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, Iterable[Text], Text],
    help: Text,  # pylint: disable=redefined-builtin
    comma_compat: bool = ...,
    flag_values: _flagvalues.FlagValues = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[Text]]:
  ...


@overload
def DEFINE_spaceseplist(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    help: Text,  # pylint: disable=redefined-builtin
    comma_compat: bool = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
  ...


@overload
def DEFINE_spaceseplist(  # pylint: disable=invalid-name
    name: Text,
    default: Union[Iterable[Text], Text],
    help: Text,  # pylint: disable=redefined-builtin
    comma_compat: bool = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[Text]]:
  ...


def DEFINE_spaceseplist(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    help,
    comma_compat=False,
    flag_values=_flagvalues.FLAGS,
    required=False,
    **args):
  """Registers a flag whose value is a whitespace-separated list of strings.

  Any whitespace may separate items; the value is serialized back joined by a
  single space.

  Args:
    name: str, the flag name.
    default: list|str|None, the default value of the flag.
    help: str, the help message.
    comma_compat: bool - Whether to support comma as an additional separator. If
      false then only whitespace is supported. This is intended only for
      backwards compatibility with flags that used to be comma-separated.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: Dictionary with extra keyword args that are passed to the
      ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  list_parser = _argument_parser.WhitespaceSeparatedListParser(
      comma_compat=comma_compat)
  return DEFINE(
      list_parser,
      name,
      default,
      help,
      flag_values,
      _argument_parser.ListSerializer(' '),
      required=bool(required),
      **args,
  )
@overload
def DEFINE_multi(  # pylint: disable=invalid-name
    parser: _argument_parser.ArgumentParser[_T],
    serializer: _argument_parser.ArgumentSerializer[_T],
    name: Text,
    default: Iterable[_T],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[_T]]:
  ...


@overload
def DEFINE_multi(  # pylint: disable=invalid-name
    parser: _argument_parser.ArgumentParser[_T],
    serializer: _argument_parser.ArgumentSerializer[_T],
    name: Text,
    default: Union[None, _T],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[_T]]:
  ...


@overload
def DEFINE_multi(  # pylint: disable=invalid-name
    parser: _argument_parser.ArgumentParser[_T],
    serializer: _argument_parser.ArgumentSerializer[_T],
    name: Text,
    default: None,
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[List[_T]]]:
  ...


@overload
def DEFINE_multi(  # pylint: disable=invalid-name
    parser: _argument_parser.ArgumentParser[_T],
    serializer: _argument_parser.ArgumentSerializer[_T],
    name: Text,
    default: Iterable[_T],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[_T]]:
  ...


@overload
def DEFINE_multi(  # pylint: disable=invalid-name
    parser: _argument_parser.ArgumentParser[_T],
    serializer: _argument_parser.ArgumentSerializer[_T],
    name: Text,
    default: _T,
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[_T]]:
  ...


def DEFINE_multi(  # pylint: disable=invalid-name,redefined-builtin
    parser,
    serializer,
    name,
    default,
    help,
    flag_values=_flagvalues.FLAGS,
    module_name=None,
    required=False,
    **args):
  """Registers a generic MultiFlag whose values are parsed by ``parser``.

  Auxiliary function; normal users should NOT call it directly. It exists for
  developers who write their own parser classes for options that may appear
  several times on the command line and need to register such flags.

  Args:
    parser: ArgumentParser, used to parse the flag arguments.
    serializer: ArgumentSerializer, the flag serializer instance.
    name: str, the flag name.
    default: Union[Iterable[T], Text, None], the default value of the flag. If
      the value is text, it will be parsed as if it was provided from the
      command line. If the value is a non-string iterable, it will be iterated
      over to create a shallow copy of the values. If it is None, it is left
      as-is.
    help: str, the help message.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    module_name: A string, the name of the Python module declaring this flag. If
      not provided, it will be computed using the stack trace of this call.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: Dictionary with extra keyword args that are passed to the
      ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  multi_flag = _flag.MultiFlag(parser, serializer, name, default, help, **args)
  # Bind to a name before returning, as the original did (DEFINE_enum_class
  # notes pytype fails on a direct return).
  holder = DEFINE_flag(multi_flag, flag_values, module_name,
                       required=bool(required))
  return holder
@overload
def DEFINE_multi_string(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, Iterable[Text], Text],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[Text]]:
  ...


@overload
def DEFINE_multi_string(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
  ...


@overload
def DEFINE_multi_string(  # pylint: disable=invalid-name
    name: Text,
    default: Union[Iterable[Text], Text],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[Text]]:
  ...


def DEFINE_multi_string(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    help,
    flag_values=_flagvalues.FLAGS,
    required=False,
    **args):
  """Registers a flag whose value is a list of arbitrary strings.

  Each occurrence of the flag on the command line appends one string to the
  list. ``default`` may be one string (turned into a one-element list) or a
  list of strings.

  Args:
    name: str, the flag name.
    default: Union[Iterable[Text], Text, None], the default value of the flag;
      see :func:`DEFINE_multi`.
    help: str, the help message.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: Dictionary with extra keyword args that are passed to the
      ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  return DEFINE_multi(
      _argument_parser.ArgumentParser(),
      _argument_parser.ArgumentSerializer(),
      name,
      default,
      help,
      flag_values,
      required=bool(required),
      **args,
  )
@overload
def DEFINE_multi_integer(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, Iterable[int], int, Text],
    help: Text,  # pylint: disable=redefined-builtin
    lower_bound: Optional[int] = ...,
    upper_bound: Optional[int] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[int]]:
  ...


@overload
def DEFINE_multi_integer(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    help: Text,  # pylint: disable=redefined-builtin
    lower_bound: Optional[int] = ...,
    upper_bound: Optional[int] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[List[int]]]:
  ...


@overload
def DEFINE_multi_integer(  # pylint: disable=invalid-name
    name: Text,
    default: Union[Iterable[int], int, Text],
    help: Text,  # pylint: disable=redefined-builtin
    lower_bound: Optional[int] = ...,
    upper_bound: Optional[int] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[int]]:
  ...


def DEFINE_multi_integer(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    help,
    lower_bound=None,
    upper_bound=None,
    flag_values=_flagvalues.FLAGS,
    required=False,
    **args):
  """Registers a flag whose value is a list of arbitrary integers.

  Each occurrence of the flag on the command line appends one integer to the
  list. ``default`` may be one integer (turned into a one-element list) or a
  list of integers.

  Args:
    name: str, the flag name.
    default: Union[Iterable[int], int, Text, None], the default value of the
      flag; see `DEFINE_multi`.
    help: str, the help message.
    lower_bound: int, min values of the flag.
    upper_bound: int, max values of the flag.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: Dictionary with extra keyword args that are passed to the
      ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  int_parser = _argument_parser.IntegerParser(lower_bound, upper_bound)
  return DEFINE_multi(
      int_parser,
      _argument_parser.ArgumentSerializer(),
      name,
      default,
      help,
      flag_values,
      required=bool(required),
      **args,
  )
@overload
def DEFINE_multi_float(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, Iterable[float], float, Text],
    help: Text,  # pylint: disable=redefined-builtin
    lower_bound: Optional[float] = ...,
    upper_bound: Optional[float] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[float]]:
  ...


@overload
def DEFINE_multi_float(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    help: Text,  # pylint: disable=redefined-builtin
    lower_bound: Optional[float] = ...,
    upper_bound: Optional[float] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[List[float]]]:
  ...


@overload
def DEFINE_multi_float(  # pylint: disable=invalid-name
    name: Text,
    default: Union[Iterable[float], float, Text],
    help: Text,  # pylint: disable=redefined-builtin
    lower_bound: Optional[float] = ...,
    upper_bound: Optional[float] = ...,
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[float]]:
  ...


def DEFINE_multi_float(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    help,
    lower_bound=None,
    upper_bound=None,
    flag_values=_flagvalues.FLAGS,
    required=False,
    **args):
  """Registers a flag whose value is a list of arbitrary floats.

  Each occurrence of the flag on the command line appends one float to the
  list. ``default`` may be one float (turned into a one-element list) or a
  list of floats.

  Args:
    name: str, the flag name.
    default: Union[Iterable[float], float, Text, None], the default value of
      the flag; see `DEFINE_multi`.
    help: str, the help message.
    lower_bound: float, min values of the flag.
    upper_bound: float, max values of the flag.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: Dictionary with extra keyword args that are passed to the
      ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  float_parser = _argument_parser.FloatParser(lower_bound, upper_bound)
  return DEFINE_multi(
      float_parser,
      _argument_parser.ArgumentSerializer(),
      name,
      default,
      help,
      flag_values,
      required=bool(required),
      **args,
  )
# Overload: passing ``required=True`` (keyword-only) types the flag value as a
# non-Optional List[Text], whatever the default is.
@overload
def DEFINE_multi_enum(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, Iterable[Text], Text],
    enum_values: Iterable[Text],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[Text]]:
  ...
# Overload: a None default (without ``required=True``) types the flag value as
# Optional[List[Text]].
@overload
def DEFINE_multi_enum(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    enum_values: Iterable[Text],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[List[Text]]]:
  ...
# Overload: a non-None default (a string or an iterable of strings) types the
# flag value as a non-Optional List[Text].
@overload
def DEFINE_multi_enum(  # pylint: disable=invalid-name
    name: Text,
    default: Union[Iterable[Text], Text],
    enum_values: Iterable[Text],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[Text]]:
  ...
def DEFINE_multi_enum(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    enum_values,
    help,
    flag_values=_flagvalues.FLAGS,
    case_sensitive=True,
    required=False,
    **args):
  """Registers a flag whose value can be a list of strings from enum_values.

  Use the flag on the command line multiple times to place multiple
  enum values into the list. The 'default' may be a single string
  (which will be converted into a single-element list) or a list of
  strings.

  Args:
    name: str, the flag name.
    default: Union[Iterable[Text], Text, None], the default value of the flag;
      see `DEFINE_multi`.
    enum_values: [str], a non-empty list of strings with the possible values for
      the flag.
    help: str, the help message. It is prefixed with ``<v1|v2|...>: `` listing
      the allowed values.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    case_sensitive: Whether or not the enum is to be case-sensitive.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: Dictionary with extra keyword args that are passed to the
      ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  parser = _argument_parser.EnumParser(enum_values, case_sensitive)
  serializer = _argument_parser.ArgumentSerializer()
  # Build the help prefix from the parser's stored values, the same way
  # EnumFlag does, rather than re-iterating ``enum_values``. If the caller
  # passed a one-shot iterable (e.g. a generator) that the parser already
  # consumed, joining ``enum_values`` again would yield an empty '<>' prefix.
  allowed_values = '|'.join(parser.enum_values)
  return DEFINE_multi(
      parser,
      serializer,
      name,
      default,
      '<%s>: %s' % (allowed_values, help),
      flag_values,
      required=bool(required),
      **args,
  )
# Overload: ``required=True`` (keyword-only) with an iterable of enum members
# as the default types the flag value as a non-Optional List[_ET].
@overload
def DEFINE_multi_enum_class(  # pylint: disable=invalid-name
    name: Text,
    # This is separate from `Union[None, _ET, Iterable[Text], Text]` to avoid a
    # Pytype issue inferring the return value to
    # FlagHolder[List[Union[_ET, enum.Enum]]] when an iterable of concrete enum
    # subclasses are used.
    default: Iterable[_ET],
    enum_class: Type[_ET],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[_ET]]:
  ...
# Overload: ``required=True`` (keyword-only) with any other default (None, a
# single member, or string(s)) types the flag value as a non-Optional List[_ET].
@overload
def DEFINE_multi_enum_class(  # pylint: disable=invalid-name
    name: Text,
    default: Union[None, _ET, Iterable[Text], Text],
    enum_class: Type[_ET],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    *,
    required: 'typing.Literal[True]',
    **args: Any
) -> _flagvalues.FlagHolder[List[_ET]]:
  ...
# Overload: a None default (without ``required=True``) types the flag value as
# Optional[List[_ET]].
@overload
def DEFINE_multi_enum_class(  # pylint: disable=invalid-name
    name: Text,
    default: None,
    enum_class: Type[_ET],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[Optional[List[_ET]]]:
  ...
# Overload: an iterable of enum members as the default types the flag value as
# a non-Optional List[_ET].
@overload
def DEFINE_multi_enum_class(  # pylint: disable=invalid-name
    name: Text,
    # This is separate from `Union[None, _ET, Iterable[Text], Text]` to avoid a
    # Pytype issue inferring the return value to
    # FlagHolder[List[Union[_ET, enum.Enum]]] when an iterable of concrete enum
    # subclasses are used.
    default: Iterable[_ET],
    enum_class: Type[_ET],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[_ET]]:
  ...
# Overload: a non-None default given as a single member or as string(s) types
# the flag value as a non-Optional List[_ET].
@overload
def DEFINE_multi_enum_class(  # pylint: disable=invalid-name
    name: Text,
    default: Union[_ET, Iterable[Text], Text],
    enum_class: Type[_ET],
    help: Text,  # pylint: disable=redefined-builtin
    flag_values: _flagvalues.FlagValues = ...,
    module_name: Optional[Text] = ...,
    required: bool = ...,
    **args: Any
) -> _flagvalues.FlagHolder[List[_ET]]:
  ...
def DEFINE_multi_enum_class(  # pylint: disable=invalid-name,redefined-builtin
    name,
    default,
    enum_class,
    help,
    flag_values=_flagvalues.FLAGS,
    module_name=None,
    case_sensitive=False,
    required=False,
    **args):
  """Registers a flag whose value can be a list of enum members.

  Use the flag on the command line multiple times to place multiple
  enum values into the list.

  Args:
    name: str, the flag name.
    default: Union[Iterable[Enum], Iterable[Text], Enum, Text, None], the
      default value of the flag; see `DEFINE_multi`; only differences are
      documented here. If the value is a single Enum, it is treated as a
      single-item list of that Enum value. If it is an iterable, text values
      within the iterable will be converted to the equivalent Enum objects.
    enum_class: class, the Enum class with all the possible values for the flag.
    help: str, the help message.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    module_name: A string, the name of the Python module declaring this flag. If
      not provided, it will be computed using the stack trace of this call.
    case_sensitive: bool, whether matching strings to members of the
      enum_class takes case into account (True) or ignores it (False, the
      default). When False, member names are also serialized in lowercase.
    required: bool, is this a required flag. This must be used as a keyword
      argument.
    **args: Dictionary with extra keyword args that are passed to the
      ``Flag.__init__``.

  Returns:
    a handle to defined flag.
  """
  # NOTE: pytype fails if this is a direct return.
  result = DEFINE_flag(
      _flag.MultiEnumClassFlag(
          name,
          default,
          help,
          enum_class,
          case_sensitive=case_sensitive,
          **args,
      ),
      flag_values,
      module_name,
      required=bool(required),
  )
  return result
def DEFINE_alias(  # pylint: disable=invalid-name
    name: Text,
    original_name: Text,
    flag_values: _flagvalues.FlagValues = _flagvalues.FLAGS,
    module_name: Optional[Text] = None,
) -> _flagvalues.FlagHolder[Any]:
  """Defines an alias flag for an existing one.

  The alias does not hold a value of its own. Parsing the alias parses the
  original flag, and reading or assigning the alias's ``value`` reads or
  assigns the original flag's ``value``.

  Args:
    name: str, the flag name.
    original_name: str, the original flag name.
    flag_values: :class:`FlagValues`, the FlagValues instance with which the
      flag will be registered. This should almost never need to be overridden.
    module_name: A string, the name of the module that defines this flag.

  Returns:
    a handle to defined flag.

  Raises:
    flags.Error:
      UnrecognizedFlagError: if the referenced flag doesn't exist.
      DuplicateFlagError: if the alias name has been used by some existing flag.
  """
  if original_name not in flag_values:
    raise _exceptions.UnrecognizedFlagError(original_name)
  flag = flag_values[original_name]

  class _FlagAlias(_flag.Flag):
    """Flag whose parsing and value are delegated to the original flag."""

    def parse(self, argument):
      # The original flag does the parsing; the alias only counts occurrences.
      flag.parse(argument)
      self.present += 1

    def _parse_from_default(self, value):
      # The value was already parsed by the aliased flag, so there is no
      # need to call the parser on it a second time.
      # Additionally, because of how MultiFlag parses and merges values,
      # it isn't possible to delegate to the aliased flag and still get
      # the correct values.
      return value

    @property
    def value(self):
      return flag.value

    @value.setter
    def value(self, value):
      flag.value = value

  help_msg = 'Alias for --%s.' % flag.name
  # NOTE: constructing the alias runs Flag._set_default, which assigns
  # ``self.value = self.default``. Through the ``value`` setter above, this
  # writes ``flag.default`` into the original flag's current value.
  # If the alias name has already been used, flags.DuplicateFlagError will be
  # raised.
  return DEFINE_flag(
      _FlagAlias(
          flag.parser,
          flag.serializer,
          name,
          flag.default,
          help_msg,
          boolean=flag.boolean), flag_values, module_name)
abseil-py-2.1.0/absl/flags/_exceptions.py 0000664 0000000 0000000 00000007111 14551576331 0020320 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exception classes in ABSL flags library.
Do NOT import this module directly. Import the flags package and use the
aliases defined at the package level instead.
"""
import sys
from absl.flags import _helpers
# Add this exceptions module to the disclaimed module ids kept by _helpers.
# Presumably this makes calling-module lookups (e.g.
# _helpers.get_calling_module in DuplicateFlagError.from_flag) skip this
# module; verify against _helpers.
_helpers.disclaim_module_ids.add(id(sys.modules[__name__]))
class Error(Exception):
  """Root of the flags exception hierarchy; every flags error derives from it."""
class CantOpenFlagFileError(Error):
  """Signals that a flagfile could not be opened.

  Typical causes: the file is missing, or its permissions do not allow it to
  be read.
  """
class DuplicateFlagError(Error):
  """Signals that two flag definitions use the same name."""

  @classmethod
  def from_flag(cls, flagname, flag_values, other_flag_values=None):
    """Builds a DuplicateFlagError describing both definitions of a flag.

    Args:
      flagname: str, the name of the flag being redefined.
      flag_values: :class:`FlagValues`, the FlagValues instance holding the
        first definition of flagname.
      other_flag_values: :class:`FlagValues`, if not None, the FlagValues
        object holding the second definition of flagname. If None, the caller
        is assumed to be defining the flag a second time, and the calling
        module is reported as the source of the second definition.

    Returns:
      An instance of DuplicateFlagError.
    """
    first_module = flag_values.find_module_defining_flag(flagname, default='')
    if other_flag_values is not None:
      second_module = other_flag_values.find_module_defining_flag(
          flagname, default='')
    else:
      second_module = _helpers.get_calling_module()
    first_help = flag_values[flagname].help
    template = ("The flag '%s' is defined twice. First from %s, Second from %s. "
                "Description from first occurrence: %s")
    return cls(template % (flagname, first_module, second_module, first_help))
class IllegalFlagValueError(Error):
  """Signals that a command line argument is not a legal value for its flag."""
class UnrecognizedFlagError(Error):
  """Signals that a flag name is not known.

  Attributes:
    flagname: str, the name of the unrecognized flag.
    flagvalue: The value of the flag, empty if the flag is not defined.
  """

  def __init__(self, flagname, flagvalue='', suggestions=None):
    """Stores the flag details and builds the error message.

    Args:
      flagname: str, the name of the unrecognized flag.
      flagvalue: the value given for the flag ('' by default).
      suggestions: optional iterable of str; when non-empty, the names are
        appended to the message as "Did you mean" hints.
    """
    self.flagname = flagname
    self.flagvalue = flagvalue
    tip = ''
    if suggestions:
      # The space before the question mark keeps it out of the selection when
      # a suggestion is copy-pasted from (some) terminals.
      tip = '. Did you mean: %s ?' % ', '.join(suggestions)
    super().__init__('Unknown command line flag \'%s\'%s' % (flagname, tip))
class UnparsedFlagAccessError(Error):
  """Signals a read of a flag value from a :class:`FlagValues` not yet parsed."""
class ValidationError(Error):
  """Signals that a flag value failed a validator constraint."""
class FlagNameConflictsWithMethodError(Error):
  """Signals that a flag name collides with a :class:`FlagValues` method name."""
abseil-py-2.1.0/absl/flags/_flag.py 0000664 0000000 0000000 00000046667 14551576331 0017073 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains Flag class - information about single command-line flag.
Do NOT import this module directly. Import the flags package and use the
aliases defined at the package level instead.
"""
from collections import abc
import copy
import enum
import functools
from typing import Any, Dict, Generic, Iterable, List, Optional, Text, Type, TypeVar, Union
from xml.dom import minidom
from absl.flags import _argument_parser
from absl.flags import _exceptions
from absl.flags import _helpers
# Type of the (parsed) value held by a Flag; Flag is Generic[_T].
_T = TypeVar('_T')
# Type of an enum member held by the enum-class flags; bound to enum.Enum.
_ET = TypeVar('_ET', bound=enum.Enum)
@functools.total_ordering
class Flag(Generic[_T]):
  """Information about a command-line flag.

  Attributes:
    name: the name for this flag
    default: the default value for this flag (the parsed form of
      ``default_unparsed``, or None)
    default_unparsed: the unparsed default value for this flag.
    default_as_str: default value as repr'd string, e.g., "'true'"
      (or None)
    value: the current value of this flag; initialized from the parsed
      default and updated by :meth:`parse`
    help: the help string; '(no help available)' is used when no help
      string is given
    short_name: the single letter alias for this flag (or None)
    boolean: if True, this flag does not accept arguments
    present: int, the number of times :meth:`parse` has been called since
      construction or the last :meth:`unparse`; 0 if the flag was not parsed
    parser: an :class:`~absl.flags.ArgumentParser` object
    serializer: an ArgumentSerializer object
    allow_override: the flag may be redefined without raising an error,
      and newly defined flag overrides the old one.
    allow_override_cpp: use the flag from C++ if available; the flag
      definition is replaced by the C++ flag after init
    allow_hide_cpp: use the Python flag despite having a C++ flag with
      the same name (ignore the C++ flag)
    using_default_value: the flag value has not been set by user
    allow_overwrite: the flag may be parsed more than once without
      raising an error, the last set value will be used
    allow_using_method_names: whether this flag can be defined even if
      it has a name that conflicts with a FlagValues method.
    validators: list of the flag validators.

  The main public method of a ``Flag`` object is :meth:`parse`, but it is
  typically only called by a :class:`~absl.flags.FlagValues` object. The
  :meth:`parse` method is a thin wrapper around the
  :meth:`ArgumentParser.parse()` method. The parsed value is saved in
  ``.value`` and the ``.present`` counter is incremented. If this flag was
  already present and ``allow_overwrite`` is False, an
  :class:`~absl.flags.IllegalFlagValueError` is raised.

  The default value is parsed during ``__init__`` (via ``_set_default``, which
  does not go through :meth:`parse` and leaves ``.present`` at 0) and used to
  initialize the ``.value`` attribute. This enables other python modules to
  safely use flags even if the ``__main__`` module neglects to parse the
  command line arguments. If the default value is set to ``None``, the
  parsing step is skipped and the ``.value`` attribute is initialized to
  None.

  Note: The default value is also presented to the user in the help
  string, so it is important that it be a legal value for this flag.
  """

  # NOTE: pytype doesn't find defaults without this.
  default: Optional[_T]
  default_as_str: Optional[Text]
  default_unparsed: Union[Optional[_T], Text]

  def __init__(
      self,
      parser: _argument_parser.ArgumentParser[_T],
      serializer: Optional[_argument_parser.ArgumentSerializer[_T]],
      name: Text,
      default: Union[Optional[_T], Text],
      help_string: Optional[Text],
      short_name: Optional[Text] = None,
      boolean: bool = False,
      allow_override: bool = False,
      allow_override_cpp: bool = False,
      allow_hide_cpp: bool = False,
      allow_overwrite: bool = True,
      allow_using_method_names: bool = False,
  ) -> None:
    self.name = name

    # Substitute a placeholder so ``self.help`` is never empty or None.
    if not help_string:
      help_string = '(no help available)'

    self.help = help_string
    self.short_name = short_name
    self.boolean = boolean
    self.present = 0
    self.parser = parser
    self.serializer = serializer
    self.allow_override = allow_override
    self.allow_override_cpp = allow_override_cpp
    self.allow_hide_cpp = allow_hide_cpp
    self.allow_overwrite = allow_overwrite
    self.allow_using_method_names = allow_using_method_names

    self.using_default_value = True
    self._value = None
    self.validators = []
    if self.allow_hide_cpp and self.allow_override_cpp:
      raise _exceptions.Error(
          "Can't have both allow_hide_cpp (means use Python flag) and "
          'allow_override_cpp (means use C++ flag after InitGoogle)')

    # Must run last: it parses the default and assigns ``self.value``, which
    # subclasses may route through an overridden ``value`` property.
    self._set_default(default)

  @property
  def value(self) -> Optional[_T]:
    """The current value of the flag (stored in ``self._value``)."""
    return self._value

  @value.setter
  def value(self, value: Optional[_T]):
    self._value = value

  # Flags hash and compare by identity: two distinct Flag objects are never
  # equal, even if they have the same name and value.
  def __hash__(self):
    return hash(id(self))

  def __eq__(self, other):
    return self is other

  # Orders flags by id(); together with functools.total_ordering this gives an
  # arbitrary but consistent ordering among Flag objects.
  def __lt__(self, other):
    if isinstance(other, Flag):
      return id(self) < id(other)
    return NotImplemented

  # Guards against accidentally testing a Flag object instead of its value.
  def __bool__(self):
    raise TypeError('A Flag instance would always be True. '
                    'Did you mean to test the `.value` attribute?')

  # Flag objects cannot be pickled.
  def __getstate__(self):
    raise TypeError("can't pickle Flag objects")

  # Only deep copies are supported.
  def __copy__(self):
    raise TypeError('%s does not support shallow copies. '
                    'Use copy.deepcopy instead.' % type(self).__name__)

  def __deepcopy__(self, memo: Dict[int, Any]) -> 'Flag[_T]':
    # Builds the copy without calling __init__ (so the default is not
    # re-parsed) and deep-copies every instance attribute.
    result = object.__new__(type(self))
    result.__dict__ = copy.deepcopy(self.__dict__, memo)
    return result

  def _get_parsed_value_as_string(self, value: Optional[_T]) -> Optional[Text]:
    """Returns parsed flag value as string.

    Uses the serializer when there is one; otherwise booleans become
    ``'true'``/``'false'`` and other values ``str(value)``. The result is
    repr()'d. Returns None when ``value`` is None.
    """
    if value is None:
      return None
    if self.serializer:
      return repr(self.serializer.serialize(value))
    if self.boolean:
      if value:
        return repr('true')
      else:
        return repr('false')
    return repr(str(value))

  def parse(self, argument: Union[Text, Optional[_T]]) -> None:
    """Parses string and sets flag value.

    Args:
      argument: str or the correct flag value type, argument to be parsed.

    Raises:
      IllegalFlagValueError: if the flag is already present and
        ``allow_overwrite`` is False, or if the parser rejects ``argument``.
    """
    if self.present and not self.allow_overwrite:
      raise _exceptions.IllegalFlagValueError(
          'flag --%s=%s: already defined as %s' % (
              self.name, argument, self.value))
    self.value = self._parse(argument)
    self.present += 1

  def _parse(self, argument: Union[Text, _T]) -> Optional[_T]:
    """Internal parse function.

    It returns the parsed value, and does not modify class states.

    Args:
      argument: str or the correct flag value type, argument to be parsed.

    Returns:
      The parsed value.

    Raises:
      IllegalFlagValueError: if the parser raises TypeError or ValueError.
    """
    try:
      return self.parser.parse(argument)
    except (TypeError, ValueError) as e:  # Recast as IllegalFlagValueError.
      raise _exceptions.IllegalFlagValueError(
          'flag --%s=%s: %s' % (self.name, argument, e))

  def unparse(self) -> None:
    """Resets the value to the default and marks the flag as not present."""
    self.value = self.default
    self.using_default_value = True
    self.present = 0

  def serialize(self) -> Text:
    """Serializes the flag."""
    return self._serialize(self.value)

  def _serialize(self, value: Optional[_T]) -> Text:
    """Internal serialize function.

    Returns '' for None, ``--name``/``--noname`` for boolean flags, and
    ``--name=<serialized value>`` otherwise.

    Raises:
      Error: if a non-boolean flag has no serializer.
    """
    if value is None:
      return ''
    if self.boolean:
      if value:
        return '--%s' % self.name
      else:
        return '--no%s' % self.name
    else:
      if not self.serializer:
        raise _exceptions.Error(
            'Serializer not present for flag %s' % self.name)
      return '--%s=%s' % (self.name, self.serializer.serialize(value))

  def _set_default(self, value: Union[Optional[_T], Text]) -> None:
    """Changes the default value (and current value too) for this Flag.

    The current value is only replaced while ``using_default_value`` is True.
    """
    self.default_unparsed = value
    if value is None:
      self.default = None
    else:
      self.default = self._parse_from_default(value)
    self.default_as_str = self._get_parsed_value_as_string(self.default)
    if self.using_default_value:
      self.value = self.default

  # This is split out so that aliases can skip regular parsing of the default
  # value.
  def _parse_from_default(self, value: Union[Text, _T]) -> Optional[_T]:
    return self._parse(value)

  def flag_type(self) -> Text:
    """Returns a str that describes the type of the flag.

    NOTE: we use strings, and not the types.*Type constants because
    our flags can have more exotic types, e.g., 'comma separated list
    of strings', 'whitespace separated list of strings', etc.
    """
    return self.parser.flag_type()

  def _create_xml_dom_element(
      self, doc: minidom.Document, module_name: str, is_key: bool = False
  ) -> minidom.Element:
    """Returns an XML element that contains this flag's information.

    This is information that is relevant to all flags (e.g., name,
    meaning, etc.). If you defined a flag that has some other pieces of
    info, then please override _extra_xml_dom_elements (and, to change how
    the current value is rendered, _serialize_value_for_xml).

    Please do NOT override this method.

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.
      module_name: str, the name of the module that defines this flag.
      is_key: boolean, True iff this flag is key for main module.

    Returns:
      A minidom.Element instance.
    """
    element = doc.createElement('flag')
    if is_key:
      element.appendChild(_helpers.create_xml_dom_element(doc, 'key', 'yes'))
    element.appendChild(_helpers.create_xml_dom_element(
        doc, 'file', module_name))
    # Adds flag features that are relevant for all flags.
    element.appendChild(_helpers.create_xml_dom_element(doc, 'name', self.name))
    if self.short_name:
      element.appendChild(_helpers.create_xml_dom_element(
          doc, 'short_name', self.short_name))
    if self.help:
      element.appendChild(_helpers.create_xml_dom_element(
          doc, 'meaning', self.help))
    # The default flag value can either be represented as a string like on the
    # command line, or as a Python object.  We serialize this value in the
    # latter case in order to remain consistent.
    if self.serializer and not isinstance(self.default, str):
      if self.default is not None:
        default_serialized = self.serializer.serialize(self.default)
      else:
        default_serialized = ''
    else:
      default_serialized = self.default
    element.appendChild(_helpers.create_xml_dom_element(
        doc, 'default', default_serialized))
    value_serialized = self._serialize_value_for_xml(self.value)
    element.appendChild(_helpers.create_xml_dom_element(
        doc, 'current', value_serialized))
    element.appendChild(_helpers.create_xml_dom_element(
        doc, 'type', self.flag_type()))
    # Adds extra flag features this flag may have.
    for e in self._extra_xml_dom_elements(doc):
      element.appendChild(e)
    return element

  def _serialize_value_for_xml(self, value: Optional[_T]) -> Any:
    """Returns the serialized value, for use in an XML help text.

    The base implementation returns ``value`` unchanged.
    """
    return value

  def _extra_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    """Returns extra info about this flag in XML.

    "Extra" means "not already included by _create_xml_dom_element above."

    Args:
      doc: minidom.Document, the DOM document it should create nodes from.

    Returns:
      A list of minidom.Element.
    """
    # Usually, the parser knows the extra details about the flag, so
    # we just forward the call to it.
    return self.parser._custom_xml_dom_elements(doc)  # pylint: disable=protected-access
class BooleanFlag(Flag[bool]):
  """A flag holding a bool that takes no argument on the command line.

  The value is ``True`` (1) or ``False`` (0). To set it to false on the
  command line, prefix either the long or the short flag name with ``'no'``.

  For instance, a Boolean flag with long name ``'update'`` and short name
  ``'x'`` can be turned off with either ``--noupdate`` or ``--nox``.
  """

  def __init__(
      self,
      name: Text,
      default: Union[Optional[bool], Text],
      help: Optional[Text],  # pylint: disable=redefined-builtin
      short_name: Optional[Text] = None,
      **args
  ) -> None:
    super().__init__(
        _argument_parser.BooleanParser(),
        None,  # No serializer: Flag handles boolean serialization itself.
        name,
        default,
        help,
        short_name,
        True,  # boolean: the flag takes no argument.
        **args
    )
class EnumFlag(Flag[Text]):
  """A string flag whose value must be one of a fixed list (``enum_values``)."""

  def __init__(
      self,
      name: Text,
      default: Optional[Text],
      help: Optional[Text],  # pylint: disable=redefined-builtin
      enum_values: Iterable[Text],
      short_name: Optional[Text] = None,
      case_sensitive: bool = True,
      **args
  ):
    enum_parser = _argument_parser.EnumParser(enum_values, case_sensitive)
    super().__init__(
        enum_parser,
        _argument_parser.ArgumentSerializer(),
        name,
        default,
        help,
        short_name,
        **args)
    # NOTE: the constructor narrows ``parser`` to ArgumentParser[str]; keep a
    # direct reference so the EnumParser-specific attributes stay visible.
    self.parser = enum_parser
    choices = '|'.join(enum_parser.enum_values)
    self.help = '<%s>: %s' % (choices, self.help)

  def _extra_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    # One <enum_value> element per allowed value.
    return [
        _helpers.create_xml_dom_element(doc, 'enum_value', allowed)
        for allowed in self.parser.enum_values
    ]
class EnumClassFlag(Flag[_ET]):
  """A flag whose value is a member of a given enum class."""

  def __init__(
      self,
      name: Text,
      default: Union[Optional[_ET], Text],
      help: Optional[Text],  # pylint: disable=redefined-builtin
      enum_class: Type[_ET],
      short_name: Optional[Text] = None,
      case_sensitive: bool = False,
      **args
  ):
    member_parser = _argument_parser.EnumClassParser(
        enum_class, case_sensitive=case_sensitive)
    member_serializer = _argument_parser.EnumClassSerializer(
        lowercase=not case_sensitive)
    super().__init__(
        member_parser, member_serializer, name, default, help, short_name,
        **args)
    # NOTE: the constructor narrows ``parser`` to ArgumentParser[_ET]; keep a
    # direct reference so EnumClassParser attributes stay visible.
    self.parser = member_parser
    choices = '|'.join(member_parser.member_names)
    self.help = '<%s>: %s' % (choices, self.help)

  def _extra_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    # One <enum_value> element per name in the enum's __members__ mapping.
    members = self.parser.enum_class.__members__
    return [
        _helpers.create_xml_dom_element(doc, 'enum_value', member_name)
        for member_name in members
    ]
class MultiFlag(Generic[_T], Flag[List[_T]]):
  """A flag that may be given several times on the command line.

  Its value is a list that gathers the individual values from every
  appearance of the flag on the command line.

  Behavior follows Flag except for these differences:

  * The default may be a single value or an iterable of values; a single
    value becomes a one-item list.
  * The value is always a list, even when the flag appears only once and
    even when the default is a single value.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.help += ';\n repeat this option to specify a list of values'

  def parse(self, arguments: Union[Text, _T, Iterable[_T]]):  # pylint: disable=arguments-renamed
    """Parses one or more arguments and accumulates them into the value.

    When the flag is already present, the parsed values are appended to the
    current list; otherwise they replace it. ``present`` grows by the number
    of parsed values.

    Args:
      arguments: a single argument or an iterable of arguments (typically a
        list of default values); a single argument (including a single
        string) is treated as a one-item list.
    """
    parsed = self._parse(arguments)
    if not self.present:
      self.value = parsed
    else:
      self.value.extend(parsed)
    self.present += len(parsed)

  def _parse(self, arguments: Union[Text, Optional[Iterable[_T]]]) -> List[_T]:  # pylint: disable=arguments-renamed
    # Strings are iterable but count as a single argument; any other iterable
    # is expanded into a list of arguments.
    if isinstance(arguments, abc.Iterable) and not isinstance(arguments, str):
      arguments = list(arguments)
    items = arguments if isinstance(arguments, list) else [arguments]
    # Bound outside the comprehension: zero-argument super() does not work
    # inside comprehension scopes on older Python versions.
    parse_item = super()._parse
    return [parse_item(item) for item in items]

  def _serialize(self, value: Optional[List[_T]]) -> Text:
    """See base class."""
    if not self.serializer:
      raise _exceptions.Error(
          'Serializer not present for flag %s' % self.name)
    if value is None:
      return ''
    serialize_item = super()._serialize
    return '\n'.join([serialize_item(item) for item in value])

  def flag_type(self):
    """See base class."""
    return 'multi ' + self.parser.flag_type()

  def _extra_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    # Only parsers that expose ``enum_values`` contribute extra elements.
    if not hasattr(self.parser, 'enum_values'):
      return []
    return [
        _helpers.create_xml_dom_element(doc, 'enum_value', allowed)
        for allowed in self.parser.enum_values  # pytype: disable=attribute-error
    ]
class MultiEnumClassFlag(MultiFlag[_ET]):  # pytype: disable=not-indexable
  """A repeatable flag whose values are members of an enum class.

  Behaves like MultiFlag; in addition, enum.Enum instances are accepted as
  values for this flag type.
  """

  def __init__(
      self,
      name: str,
      default: Union[None, Iterable[_ET], _ET, Iterable[Text], Text],
      help_string: str,
      enum_class: Type[_ET],
      case_sensitive: bool = False,
      **args
  ):
    member_parser = _argument_parser.EnumClassParser(
        enum_class, case_sensitive=case_sensitive)
    list_serializer = _argument_parser.EnumClassListSerializer(
        list_sep=',', lowercase=not case_sensitive)
    super().__init__(
        member_parser, list_serializer, name, default, help_string, **args)
    # NOTE: the constructor narrows ``parser`` to ArgumentParser[str] and leaves
    # ``serializer`` Optional; keep direct, precisely-typed references.
    self.parser = member_parser
    self.serializer = list_serializer
    # Rebuilt from ``help_string`` so the suffix added by MultiFlag.__init__
    # appears after the member list rather than twice.
    choices = '|'.join(member_parser.member_names)
    self.help = (
        '<%s>: %s;\n repeat this option to specify a list of values' %
        (choices, help_string or '(no help available)'))

  def _extra_xml_dom_elements(
      self, doc: minidom.Document
  ) -> List[minidom.Element]:
    # One <enum_value> element per name in the enum's __members__ mapping.
    members = self.parser.enum_class.__members__  # pytype: disable=attribute-error
    return [
        _helpers.create_xml_dom_element(doc, 'enum_value', member_name)
        for member_name in members
    ]

  def _serialize_value_for_xml(self, value):
    """See base class."""
    if value is None:
      return ''
    if not self.serializer:
      raise _exceptions.Error(
          'Serializer not present for flag %s' % self.name
      )
    return self.serializer.serialize(value)
abseil-py-2.1.0/absl/flags/_flagvalues.py 0000664 0000000 0000000 00000151557 14551576331 0020306 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the FlagValues class - registry of 'Flag' objects.
Do NOT import this module directly. Import the flags package and use the
aliases defined at the package level instead.
"""
import copy
import itertools
import logging
import os
import sys
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Sequence, Text, TextIO, Generic, TypeVar, Union, Tuple
from xml.dom import minidom
from absl.flags import _exceptions
from absl.flags import _flag
from absl.flags import _helpers
from absl.flags import _validators_classes
from absl.flags._flag import Flag
# Add flagvalues module to disclaimed module ids.
# Presumably this makes _helpers' calling-module lookups skip frames from this
# module; verify against _helpers.
_helpers.disclaim_module_ids.add(id(sys.modules[__name__]))

# Generic type variable for the generic classes/functions defined later in
# this module (presumably FlagHolder's value type; confirm below).
_T = TypeVar('_T')
class FlagValues:
"""Registry of :class:`~absl.flags.Flag` objects.
A :class:`FlagValues` can then scan command line arguments, passing flag
arguments through to the 'Flag' objects that it owns. It also
provides easy access to the flag values. Typically only one
:class:`FlagValues` object is needed by an application:
:const:`FLAGS`.
This class is heavily overloaded:
:class:`Flag` objects are registered via ``__setitem__``::
FLAGS['longname'] = x # register a new flag
The ``.value`` attribute of the registered :class:`~absl.flags.Flag` objects
can be accessed as attributes of this :class:`FlagValues` object, through
``__getattr__``. Both the long and short name of the original
:class:`~absl.flags.Flag` objects can be used to access its value::
FLAGS.longname # parsed flag value
FLAGS.x # parsed flag value (short name)
Command line arguments are scanned and passed to the registered
:class:`~absl.flags.Flag` objects through the ``__call__`` method. Unparsed
arguments, including ``argv[0]`` (e.g. the program name) are returned::
argv = FLAGS(sys.argv) # scan command line arguments
The original registered :class:`~absl.flags.Flag` objects can be retrieved
through the use of the dictionary-like operator, ``__getitem__``::
x = FLAGS['longname'] # access the registered Flag object
The ``str()`` operator of a :class:`absl.flags.FlagValues` object provides
help for all of the registered :class:`~absl.flags.Flag` objects.
"""
_HAS_DYNAMIC_ATTRIBUTES = True
# A note on collections.abc.Mapping:
# FlagValues defines __getitem__, __iter__, and __len__. It makes perfect
# sense to let it be a collections.abc.Mapping class. However, we are not
# able to do so. The mixin methods, e.g. keys, values, are not uncommon flag
# names. Those flag values would not be accessible via the FLAGS.xxx form.
__dict__: Dict[str, Any]
def __init__(self):
# Since everything in this class is so heavily overloaded, the only
# way of defining and using fields is to access __dict__ directly.
# Dictionary: flag name (string) -> Flag object.
self.__dict__['__flags'] = {}
# Set: name of hidden flag (string).
# Holds flags that should not be directly accessible from Python.
self.__dict__['__hiddenflags'] = set()
# Dictionary: module name (string) -> list of Flag objects that are defined
# by that module.
self.__dict__['__flags_by_module'] = {}
# Dictionary: module id (int) -> list of Flag objects that are defined by
# that module.
self.__dict__['__flags_by_module_id'] = {}
# Dictionary: module name (string) -> list of Flag objects that are
# key for that module.
self.__dict__['__key_flags_by_module'] = {}
# Bool: True if flags were parsed.
self.__dict__['__flags_parsed'] = False
# Bool: True if unparse_flags() was called.
self.__dict__['__unparse_flags_called'] = False
# None or Method(name, value) to call from __setattr__ for an unknown flag.
self.__dict__['__set_unknown'] = None
# A set of banned flag names. This is to prevent users from accidentally
# defining a flag that has the same name as a method on this class.
# Users can still allow defining the flag by passing
# allow_using_method_names=True in DEFINE_xxx functions.
self.__dict__['__banned_flag_names'] = frozenset(dir(FlagValues))
# Bool: Whether to use GNU style scanning.
self.__dict__['__use_gnu_getopt'] = True
# Bool: Whether use_gnu_getopt has been explicitly set by the user.
self.__dict__['__use_gnu_getopt_explicitly_set'] = False
# Function: Takes a flag name as parameter, returns a tuple
# (is_retired, type_is_bool).
self.__dict__['__is_retired_flag_func'] = None
def set_gnu_getopt(self, gnu_getopt: bool = True) -> None:
"""Sets whether or not to use GNU style scanning.
GNU style allows mixing of flag and non-flag arguments. See
http://docs.python.org/library/getopt.html#getopt.gnu_getopt
Args:
gnu_getopt: bool, whether or not to use GNU style scanning.
"""
self.__dict__['__use_gnu_getopt'] = gnu_getopt
self.__dict__['__use_gnu_getopt_explicitly_set'] = True
def is_gnu_getopt(self) -> bool:
return self.__dict__['__use_gnu_getopt']
def _flags(self) -> Dict[Text, Flag]:
return self.__dict__['__flags']
def flags_by_module_dict(self) -> Dict[Text, List[Flag]]:
"""Returns the dictionary of module_name -> list of defined flags.
Returns:
A dictionary. Its keys are module names (strings). Its values
are lists of Flag objects.
"""
return self.__dict__['__flags_by_module']
def flags_by_module_id_dict(self) -> Dict[int, List[Flag]]:
"""Returns the dictionary of module_id -> list of defined flags.
Returns:
A dictionary. Its keys are module IDs (ints). Its values
are lists of Flag objects.
"""
return self.__dict__['__flags_by_module_id']
def key_flags_by_module_dict(self) -> Dict[Text, List[Flag]]:
"""Returns the dictionary of module_name -> list of key flags.
Returns:
A dictionary. Its keys are module names (strings). Its values
are lists of Flag objects.
"""
return self.__dict__['__key_flags_by_module']
def register_flag_by_module(self, module_name: Text, flag: Flag) -> None:
"""Records the module that defines a specific flag.
We keep track of which flag is defined by which module so that we
can later sort the flags by module.
Args:
module_name: str, the name of a Python module.
flag: Flag, the Flag instance that is key to the module.
"""
flags_by_module = self.flags_by_module_dict()
flags_by_module.setdefault(module_name, []).append(flag)
def register_flag_by_module_id(self, module_id: int, flag: Flag) -> None:
"""Records the module that defines a specific flag.
Args:
module_id: int, the ID of the Python module.
flag: Flag, the Flag instance that is key to the module.
"""
flags_by_module_id = self.flags_by_module_id_dict()
flags_by_module_id.setdefault(module_id, []).append(flag)
def register_key_flag_for_module(self, module_name: Text, flag: Flag) -> None:
"""Specifies that a flag is a key flag for a module.
Args:
module_name: str, the name of a Python module.
flag: Flag, the Flag instance that is key to the module.
"""
key_flags_by_module = self.key_flags_by_module_dict()
# The list of key flags for the module named module_name.
key_flags = key_flags_by_module.setdefault(module_name, [])
# Add flag, but avoid duplicates.
if flag not in key_flags:
key_flags.append(flag)
def _flag_is_registered(self, flag_obj: Flag) -> bool:
"""Checks whether a Flag object is registered under long name or short name.
Args:
flag_obj: Flag, the Flag instance to check for.
Returns:
bool, True iff flag_obj is registered under long name or short name.
"""
flag_dict = self._flags()
# Check whether flag_obj is registered under its long name.
name = flag_obj.name
if flag_dict.get(name, None) == flag_obj:
return True
# Check whether flag_obj is registered under its short name.
short_name = flag_obj.short_name
if (short_name is not None and flag_dict.get(short_name, None) == flag_obj):
return True
return False
def _cleanup_unregistered_flag_from_module_dicts(
self, flag_obj: Flag
) -> None:
"""Cleans up unregistered flags from all module -> [flags] dictionaries.
If flag_obj is registered under either its long name or short name, it
won't be removed from the dictionaries.
Args:
flag_obj: Flag, the Flag instance to clean up for.
"""
if self._flag_is_registered(flag_obj):
return
for flags_by_module_dict in (self.flags_by_module_dict(),
self.flags_by_module_id_dict(),
self.key_flags_by_module_dict()):
for flags_in_module in flags_by_module_dict.values():
# While (as opposed to if) takes care of multiple occurrences of a
# flag in the list for the same module.
while flag_obj in flags_in_module:
flags_in_module.remove(flag_obj)
def get_flags_for_module(self, module: Union[Text, Any]) -> List[Flag]:
"""Returns the list of flags defined by a module.
Args:
module: module|str, the module to get flags from.
Returns:
[Flag], a new list of Flag instances. Caller may update this list as
desired: none of those changes will affect the internals of this
FlagValue instance.
"""
if not isinstance(module, str):
module = module.__name__
if module == '__main__':
module = sys.argv[0]
return list(self.flags_by_module_dict().get(module, []))
def get_key_flags_for_module(self, module: Union[Text, Any]) -> List[Flag]:
"""Returns the list of key flags for a module.
Args:
module: module|str, the module to get key flags from.
Returns:
[Flag], a new list of Flag instances. Caller may update this list as
desired: none of those changes will affect the internals of this
FlagValue instance.
"""
if not isinstance(module, str):
module = module.__name__
if module == '__main__':
module = sys.argv[0]
# Any flag is a key flag for the module that defined it. NOTE:
# key_flags is a fresh list: we can update it without affecting the
# internals of this FlagValues object.
key_flags = self.get_flags_for_module(module)
# Take into account flags explicitly declared as key for a module.
for flag in self.key_flags_by_module_dict().get(module, []):
if flag not in key_flags:
key_flags.append(flag)
return key_flags
# TODO(yileiyang): Restrict default to Optional[Text].
def find_module_defining_flag(
self, flagname: Text, default: Optional[_T] = None
) -> Union[str, Optional[_T]]:
"""Return the name of the module defining this flag, or default.
Args:
flagname: str, name of the flag to lookup.
default: Value to return if flagname is not defined. Defaults to None.
Returns:
The name of the module which registered the flag with this name.
If no such module exists (i.e. no flag with this name exists),
we return default.
"""
registered_flag = self._flags().get(flagname)
if registered_flag is None:
return default
for module, flags in self.flags_by_module_dict().items():
for flag in flags:
# It must compare the flag with the one in _flags. This is because a
# flag might be overridden only for its long name (or short name),
# and only its short name (or long name) is considered registered.
if (flag.name == registered_flag.name and
flag.short_name == registered_flag.short_name):
return module
return default
# TODO(yileiyang): Restrict default to Optional[Text].
def find_module_id_defining_flag(
self, flagname: Text, default: Optional[_T] = None
) -> Union[int, Optional[_T]]:
"""Return the ID of the module defining this flag, or default.
Args:
flagname: str, name of the flag to lookup.
default: Value to return if flagname is not defined. Defaults to None.
Returns:
The ID of the module which registered the flag with this name.
If no such module exists (i.e. no flag with this name exists),
we return default.
"""
registered_flag = self._flags().get(flagname)
if registered_flag is None:
return default
for module_id, flags in self.flags_by_module_id_dict().items():
for flag in flags:
# It must compare the flag with the one in _flags. This is because a
# flag might be overridden only for its long name (or short name),
# and only its short name (or long name) is considered registered.
if (flag.name == registered_flag.name and
flag.short_name == registered_flag.short_name):
return module_id
return default
def _register_unknown_flag_setter(
self, setter: Callable[[str, Any], None]
) -> None:
"""Allow set default values for undefined flags.
Args:
setter: Method(name, value) to call to __setattr__ an unknown flag. Must
raise NameError or ValueError for invalid name/value.
"""
self.__dict__['__set_unknown'] = setter
def _set_unknown_flag(self, name: str, value: _T) -> _T:
"""Returns value if setting flag |name| to |value| returned True.
Args:
name: str, name of the flag to set.
value: Value to set.
Returns:
Flag value on successful call.
Raises:
UnrecognizedFlagError
IllegalFlagValueError
"""
setter = self.__dict__['__set_unknown']
if setter:
try:
setter(name, value)
return value
except (TypeError, ValueError): # Flag value is not valid.
raise _exceptions.IllegalFlagValueError(
'"{1}" is not valid for --{0}'.format(name, value))
except NameError: # Flag name is not valid.
pass
raise _exceptions.UnrecognizedFlagError(name, value)
def append_flag_values(self, flag_values: 'FlagValues') -> None:
"""Appends flags registered in another FlagValues instance.
Args:
flag_values: FlagValues, the FlagValues instance from which to copy flags.
"""
for flag_name, flag in flag_values._flags().items(): # pylint: disable=protected-access
# Each flags with short_name appears here twice (once under its
# normal name, and again with its short name). To prevent
# problems (DuplicateFlagError) with double flag registration, we
# perform a check to make sure that the entry we're looking at is
# for its normal name.
if flag_name == flag.name:
try:
self[flag_name] = flag
except _exceptions.DuplicateFlagError:
raise _exceptions.DuplicateFlagError.from_flag(
flag_name, self, other_flag_values=flag_values)
def remove_flag_values(
self, flag_values: 'Union[FlagValues, Iterable[Text]]'
) -> None:
"""Remove flags that were previously appended from another FlagValues.
Args:
flag_values: FlagValues, the FlagValues instance containing flags to
remove.
"""
for flag_name in flag_values:
self.__delattr__(flag_name)
  def __setitem__(self, name: Text, flag: Flag) -> None:
    """Registers a new flag variable.
    The flag is stored under ``name`` and, when it has one, under its short
    name as well. Usually called as ``FLAGS[name] = flag``.
    Args:
      name: str, the (long) name to register the flag under.
      flag: Flag, the Flag instance to register.
    Raises:
      IllegalFlagValueError: Raised when flag is not a Flag instance.
      Error: Raised when name is not a str, is empty, or contains a space.
      DuplicateFlagError: Raised when name or the short name is already taken
        and neither the existing nor the new flag has allow_override set.
    """
    fl = self._flags()
    if not isinstance(flag, _flag.Flag):
      raise _exceptions.IllegalFlagValueError(
          f'Expect Flag instances, found type {type(flag)}. '
          "Maybe you didn't mean to use FlagValue.__setitem__?")
    if not isinstance(name, str):
      raise _exceptions.Error('Flag name must be a string')
    if not name:
      raise _exceptions.Error('Flag name cannot be empty')
    if ' ' in name:
      raise _exceptions.Error('Flag name cannot contain a space')
    self._check_method_name_conflicts(name, flag)
    if name in fl and not flag.allow_override and not fl[name].allow_override:
      module, module_name = _helpers.get_calling_module_object_and_name()
      if (self.find_module_defining_flag(name) == module_name and
          id(module) != self.find_module_id_defining_flag(name)):
        # If the flag has already been defined by a module with the same name,
        # but a different ID, we can stop here because it indicates that the
        # module is simply being imported a subsequent time.
        return
      raise _exceptions.DuplicateFlagError.from_flag(name, self)
    # If a new flag overrides an old one, we need to cleanup the old flag's
    # modules if it's not registered.
    flags_to_cleanup = set()
    short_name: str = flag.short_name  # pytype: disable=annotation-type-mismatch
    if short_name is not None:
      if (short_name in fl and not flag.allow_override and
          not fl[short_name].allow_override):
        raise _exceptions.DuplicateFlagError.from_flag(short_name, self)
      if short_name in fl and fl[short_name] != flag:
        flags_to_cleanup.add(fl[short_name])
      # The short name always points at the new flag once we get here.
      fl[short_name] = flag
    # The long name is only re-pointed when it is new, when the existing flag
    # still uses its default value, or when the new flag does not; an already
    # set existing flag is therefore not replaced by an unset one.
    if (name not in fl  # new flag
        or fl[name].using_default_value or not flag.using_default_value):
      if name in fl and fl[name] != flag:
        flags_to_cleanup.add(fl[name])
      fl[name] = flag
    # Replaced flags that are no longer registered under any name are removed
    # from the per-module bookkeeping.
    for f in flags_to_cleanup:
      self._cleanup_unregistered_flag_from_module_dicts(f)
def __dir__(self) -> List[Text]:
"""Returns list of names of all defined flags.
Useful for TAB-completion in ipython.
Returns:
[str], a list of names of all defined flags.
"""
return sorted(self.__dict__['__flags'])
def __getitem__(self, name: Text) -> Flag:
"""Returns the Flag object for the flag --name."""
return self._flags()[name]
def _hide_flag(self, name):
"""Marks the flag --name as hidden."""
self.__dict__['__hiddenflags'].add(name)
def __getattr__(self, name: Text) -> Any:
"""Retrieves the 'value' attribute of the flag --name."""
fl = self._flags()
if name not in fl:
raise AttributeError(name)
if name in self.__dict__['__hiddenflags']:
raise AttributeError(name)
if self.__dict__['__flags_parsed'] or fl[name].present:
return fl[name].value
else:
raise _exceptions.UnparsedFlagAccessError(
'Trying to access flag --%s before flags were parsed.' % name)
def __setattr__(self, name: Text, value: _T) -> _T:
"""Sets the 'value' attribute of the flag --name."""
self._set_attributes(**{name: value})
return value
  def _set_attributes(self, **attributes: Any) -> None:
    """Sets multiple flag values together, triggers validators afterwards.
    All values are assigned first; only then are the validators of the
    assigned known flags run, so they see the combined new state. If anything
    raises, the recorded previous values (and using_default_value states) of
    the known flags are restored before the exception is re-raised.
    Args:
      **attributes: flag name -> new value. Names that are not registered are
        passed to _set_unknown_flag.
    Raises:
      AttributeError: Raised when a name refers to a hidden flag.
      Any exception from a flag's value assignment, _set_unknown_flag or the
      validators (e.g. IllegalFlagValueError) is re-raised after rollback.
    """
    fl = self._flags()
    # Previous value of every known flag that was assigned.
    known_flag_vals = {}
    # Previous using_default_value of every known flag whose validators passed.
    known_flag_used_defaults = {}
    try:
      for name, value in attributes.items():
        if name in self.__dict__['__hiddenflags']:
          raise AttributeError(name)
        if name in fl:
          orig = fl[name].value
          fl[name].value = value
          known_flag_vals[name] = orig
        else:
          self._set_unknown_flag(name, value)
      # Validate only after every value is in place.
      for name in known_flag_vals:
        self._assert_validators(fl[name].validators)
        known_flag_used_defaults[name] = fl[name].using_default_value
        fl[name].using_default_value = False
    except:
      # A bare except is deliberate: every failure (even KeyboardInterrupt)
      # gets the rollback below, and is then re-raised unchanged.
      for name, orig in known_flag_vals.items():
        fl[name].value = orig
      for name, orig in known_flag_used_defaults.items():
        fl[name].using_default_value = orig
      # NOTE: We do not attempt to undo unknown flag side effects because we
      # cannot reliably undo the user-configured behavior.
      raise
def validate_all_flags(self) -> None:
"""Verifies whether all flags pass validation.
Raises:
AttributeError: Raised if validators work with a non-existing flag.
IllegalFlagValueError: Raised if validation fails for at least one
validator.
"""
all_validators = set()
for flag in self._flags().values():
all_validators.update(flag.validators)
self._assert_validators(all_validators)
def _assert_validators(
self, validators: Iterable[_validators_classes.Validator]
) -> None:
"""Asserts if all validators in the list are satisfied.
It asserts validators in the order they were created.
Args:
validators: Iterable(validators.Validator), validators to be verified.
Raises:
AttributeError: Raised if validators work with a non-existing flag.
IllegalFlagValueError: Raised if validation fails for at least one
validator.
"""
messages = []
bad_flags = set()
for validator in sorted(
validators, key=lambda validator: validator.insertion_index):
try:
if isinstance(validator, _validators_classes.SingleFlagValidator):
if validator.flag_name in bad_flags:
continue
elif isinstance(validator, _validators_classes.MultiFlagsValidator):
if bad_flags & set(validator.flag_names):
continue
validator.verify(self)
except _exceptions.ValidationError as e:
if isinstance(validator, _validators_classes.SingleFlagValidator):
bad_flags.add(validator.flag_name)
elif isinstance(validator, _validators_classes.MultiFlagsValidator):
bad_flags.update(set(validator.flag_names))
message = validator.print_flags_with_values(self)
messages.append('%s: %s' % (message, str(e)))
if messages:
raise _exceptions.IllegalFlagValueError('\n'.join(messages))
def __delattr__(self, flag_name: Text) -> None:
"""Deletes a previously-defined flag from a flag object.
This method makes sure we can delete a flag by using
del FLAGS.
E.g.,
flags.DEFINE_integer('foo', 1, 'Integer flag.')
del flags.FLAGS.foo
If a flag is also registered by its the other name (long name or short
name), the other name won't be deleted.
Args:
flag_name: str, the name of the flag to be deleted.
Raises:
AttributeError: Raised when there is no registered flag named flag_name.
"""
fl = self._flags()
if flag_name not in fl:
raise AttributeError(flag_name)
flag_obj = fl[flag_name]
del fl[flag_name]
self._cleanup_unregistered_flag_from_module_dicts(flag_obj)
def set_default(self, name: Text, value: Any) -> None:
"""Changes the default value of the named flag object.
The flag's current value is also updated if the flag is currently using
the default value, i.e. not specified in the command line, and not set
by FLAGS.name = value.
Args:
name: str, the name of the flag to modify.
value: The new default value.
Raises:
UnrecognizedFlagError: Raised when there is no registered flag named name.
IllegalFlagValueError: Raised when value is not valid.
"""
fl = self._flags()
if name not in fl:
self._set_unknown_flag(name, value)
return
fl[name]._set_default(value) # pylint: disable=protected-access
self._assert_validators(fl[name].validators)
def __contains__(self, name: Text) -> bool:
"""Returns True if name is a value (flag) in the dict."""
return name in self._flags()
def __len__(self) -> int:
return len(self.__dict__['__flags'])
def __iter__(self) -> Iterator[Text]:
return iter(self._flags())
def __call__(
self, argv: Sequence[Text], known_only: bool = False
) -> List[Text]:
"""Parses flags from argv; stores parsed flags into this FlagValues object.
All unparsed arguments are returned.
Args:
argv: a tuple/list of strings.
known_only: bool, if True, parse and remove known flags; return the rest
untouched. Unknown flags specified by --undefok are not returned.
Returns:
The list of arguments not parsed as options, including argv[0].
Raises:
Error: Raised on any parsing error.
TypeError: Raised on passing wrong type of arguments.
ValueError: Raised on flag value parsing error.
"""
if isinstance(argv, (str, bytes)):
raise TypeError(
'argv should be a tuple/list of strings, not bytes or string.')
if not argv:
raise ValueError(
'argv cannot be an empty list, and must contain the program name as '
'the first element.')
# This pre parses the argv list for --flagfile=<> options.
program_name = argv[0]
args = self.read_flags_from_files(argv[1:], force_gnu=False)
# Parse the arguments.
unknown_flags, unparsed_args = self._parse_args(args, known_only)
# Handle unknown flags by raising UnrecognizedFlagError.
# Note some users depend on us raising this particular error.
for name, value in unknown_flags:
suggestions = _helpers.get_flag_suggestions(name, list(self))
raise _exceptions.UnrecognizedFlagError(
name, value, suggestions=suggestions)
self.mark_as_parsed()
self.validate_all_flags()
return [program_name] + unparsed_args
def __getstate__(self) -> Any:
raise TypeError("can't pickle FlagValues")
def __copy__(self) -> Any:
raise TypeError('FlagValues does not support shallow copies. '
'Use absl.testing.flagsaver or copy.deepcopy instead.')
def __deepcopy__(self, memo) -> Any:
result = object.__new__(type(self))
result.__dict__.update(copy.deepcopy(self.__dict__, memo))
return result
def _set_is_retired_flag_func(self, is_retired_flag_func):
"""Sets a function for checking retired flags.
Do not use it. This is a private absl API used to check retired flags
registered by the absl C++ flags library.
Args:
is_retired_flag_func: Callable(str) -> (bool, bool), a function takes flag
name as parameter, returns a tuple (is_retired, type_is_bool).
"""
self.__dict__['__is_retired_flag_func'] = is_retired_flag_func
  def _parse_args(
      self, args: List[str], known_only: bool
  ) -> Tuple[List[Tuple[Optional[str], Any]], List[str]]:
    """Helper function to do the main argument parsing.
    This function goes through args and does the bulk of the flag parsing.
    It will find the corresponding flag in our flag dictionary, and call its
    .parse() method on the flag value.
    Accepted spellings are ``--name``/``-name``, ``--name=value`` and
    ``--name value``. A boolean flag given without a value is parsed as
    'true', and ``--noname`` parses a boolean flag as 'false'. Names listed in
    ``--undefok`` (plus their ``no``-prefixed forms) are dropped from the
    unknown flags. If a retired-flag function was installed, retired flags
    are logged and skipped.
    Args:
      args: [str], a list of strings with the arguments to parse.
      known_only: bool, if True, parse and remove known flags; return the rest
        untouched. Unknown flags specified by --undefok are not returned.
    Returns:
      A tuple with the following:
          unknown_flags: List of (flag name, arg) for flags we don't know about.
          unparsed_args: List of arguments we did not parse.
    Raises:
       Error: Raised on any parsing error.
       ValueError: Raised on flag value parsing error.
    """
    unparsed_names_and_args = []  # A list of (flag name or None, arg).
    undefok = set()
    retired_flag_func = self.__dict__['__is_retired_flag_func']
    flag_dict = self._flags()
    # One shared iterator: get_value() may consume the following element as a
    # flag's value, and whatever is left after an early break is appended to
    # the unparsed args at the end.
    args = iter(args)
    for arg in args:
      value = None
      def get_value():
        # pylint: disable=cell-var-from-loop
        # `value` is read at call time, i.e. after the "=value" split below:
        # returns the inline value if there was one, else the next argument.
        try:
          return next(args) if value is None else value
        except StopIteration:
          raise _exceptions.Error('Missing value for flag ' + arg)  # pylint: disable=undefined-loop-variable
      if not arg.startswith('-'):
        # A non-argument: default is break, GNU is skip.
        unparsed_names_and_args.append((None, arg))
        if self.is_gnu_getopt():
          continue
        else:
          break
      if arg == '--':
        # "--" ends flag parsing; the "--" itself is kept only when known_only.
        if known_only:
          unparsed_names_and_args.append((None, arg))
        break
      # At this point, arg must start with '-'.
      if arg.startswith('--'):
        arg_without_dashes = arg[2:]
      else:
        arg_without_dashes = arg[1:]
      if '=' in arg_without_dashes:
        name, value = arg_without_dashes.split('=', 1)
      else:
        name, value = arg_without_dashes, None
      if not name:
        # The argument is all dashes (including one dash).
        unparsed_names_and_args.append((None, arg))
        if self.is_gnu_getopt():
          continue
        else:
          break
      # --undefok is a special case.
      if name == 'undefok':
        value = get_value()
        undefok.update(v.strip() for v in value.split(','))
        undefok.update('no' + v.strip() for v in value.split(','))
        continue
      flag = flag_dict.get(name)
      if flag is not None:
        if flag.boolean and value is None:
          value = 'true'
        else:
          value = get_value()
      elif name.startswith('no') and len(name) > 2:
        # Boolean flags can take the form of --noflag, with no value.
        noflag = flag_dict.get(name[2:])
        if noflag is not None and noflag.boolean:
          if value is not None:
            raise ValueError(arg + ' does not take an argument')
          flag = noflag
          value = 'false'
      if retired_flag_func and flag is None:
        is_retired, is_bool = retired_flag_func(name)
        # If we didn't recognize that flag, but it starts with
        # "no" then maybe it was a boolean flag specified in the
        # --nofoo form.
        if not is_retired and name.startswith('no'):
          is_retired, is_bool = retired_flag_func(name[2:])
          is_retired = is_retired and is_bool
        if is_retired:
          if not is_bool and value is None:
            # This happens when a non-bool retired flag is specified
            # in format of "--flag value".
            get_value()
          logging.error(
              'Flag "%s" is retired and should no longer be specified. See '
              'https://abseil.io/tips/90.',
              name,
          )
          continue
      if flag is not None:
        # LINT.IfChange
        flag.parse(value)
        flag.using_default_value = False
        # LINT.ThenChange(../testing/flagsaver.py:flag_override_parsing)
      else:
        unparsed_names_and_args.append((name, arg))
    # Second pass: sort the collected leftovers into positional arguments and
    # unknown flags, now that the full --undefok set is known.
    unknown_flags = []
    unparsed_args = []
    for name, arg in unparsed_names_and_args:
      if name is None:
        # Positional arguments.
        unparsed_args.append(arg)
      elif name in undefok:
        # Remove undefok flags.
        continue
      else:
        # This is an unknown flag.
        if known_only:
          unparsed_args.append(arg)
        else:
          unknown_flags.append((name, arg))
    # Arguments after a break (e.g. after "--") were never visited.
    unparsed_args.extend(list(args))
    return unknown_flags, unparsed_args
def is_parsed(self) -> bool:
"""Returns whether flags were parsed."""
return self.__dict__['__flags_parsed']
def mark_as_parsed(self) -> None:
"""Explicitly marks flags as parsed.
Use this when the caller knows that this FlagValues has been parsed as if
a ``__call__()`` invocation has happened. This is only a public method for
use by things like appcommands which do additional command like parsing.
"""
self.__dict__['__flags_parsed'] = True
def unparse_flags(self) -> None:
"""Unparses all flags to the point before any FLAGS(argv) was called."""
for f in self._flags().values():
f.unparse()
# We log this message before marking flags as unparsed to avoid a
# problem when the logging library causes flags access.
logging.info('unparse_flags() called; flags access will now raise errors.')
self.__dict__['__flags_parsed'] = False
self.__dict__['__unparse_flags_called'] = True
def flag_values_dict(self) -> Dict[Text, Any]:
"""Returns a dictionary that maps flag names to flag values."""
return {name: flag.value for name, flag in self._flags().items()}
def __str__(self):
"""Returns a help string for all known flags."""
return self.get_help()
def get_help(
self, prefix: Text = '', include_special_flags: bool = True
) -> Text:
"""Returns a help string for all known flags.
Args:
prefix: str, per-line output prefix.
include_special_flags: bool, whether to include description of
SPECIAL_FLAGS, i.e. --flagfile and --undefok.
Returns:
str, formatted help message.
"""
flags_by_module = self.flags_by_module_dict()
if flags_by_module:
modules = sorted(flags_by_module)
# Print the help for the main module first, if possible.
main_module = sys.argv[0]
if main_module in modules:
modules.remove(main_module)
modules = [main_module] + modules
return self._get_help_for_modules(modules, prefix, include_special_flags)
else:
output_lines = []
# Just print one long list of flags.
values = self._flags().values()
if include_special_flags:
values = itertools.chain(
values, _helpers.SPECIAL_FLAGS._flags().values() # pylint: disable=protected-access # pytype: disable=attribute-error
)
self._render_flag_list(values, output_lines, prefix)
return '\n'.join(output_lines)
def _get_help_for_modules(self, modules, prefix, include_special_flags):
"""Returns the help string for a list of modules.
Private to absl.flags package.
Args:
modules: List[str], a list of modules to get the help string for.
prefix: str, a string that is prepended to each generated help line.
include_special_flags: bool, whether to include description of
SPECIAL_FLAGS, i.e. --flagfile and --undefok.
"""
output_lines = []
for module in modules:
self._render_our_module_flags(module, output_lines, prefix)
if include_special_flags:
self._render_module_flags(
'absl.flags',
_helpers.SPECIAL_FLAGS._flags().values(), # pylint: disable=protected-access # pytype: disable=attribute-error
output_lines,
prefix,
)
return '\n'.join(output_lines)
def _render_module_flags(self, module, flags, output_lines, prefix=''):
"""Returns a help string for a given module."""
if not isinstance(module, str):
module = module.__name__
output_lines.append('\n%s%s:' % (prefix, module))
self._render_flag_list(flags, output_lines, prefix + ' ')
def _render_our_module_flags(self, module, output_lines, prefix=''):
"""Returns a help string for a given module."""
flags = self.get_flags_for_module(module)
if flags:
self._render_module_flags(module, flags, output_lines, prefix)
def _render_our_module_key_flags(self, module, output_lines, prefix=''):
"""Returns a help string for the key flags of a given module.
Args:
module: module|str, the module to render key flags for.
output_lines: [str], a list of strings. The generated help message lines
will be appended to this list.
prefix: str, a string that is prepended to each generated help line.
"""
key_flags = self.get_key_flags_for_module(module)
if key_flags:
self._render_module_flags(module, key_flags, output_lines, prefix)
def module_help(self, module: Any) -> Text:
"""Describes the key flags of a module.
Args:
module: module|str, the module to describe the key flags for.
Returns:
str, describing the key flags of a module.
"""
helplist = []
self._render_our_module_key_flags(module, helplist)
return '\n'.join(helplist)
def main_module_help(self) -> Text:
"""Describes the key flags of the main module.
Returns:
str, describing the key flags of the main module.
"""
return self.module_help(sys.argv[0])
def _render_flag_list(self, flaglist, output_lines, prefix=' '):
fl = self._flags()
special_fl = _helpers.SPECIAL_FLAGS._flags() # pylint: disable=protected-access # pytype: disable=attribute-error
flaglist = [(flag.name, flag) for flag in flaglist]
flaglist.sort()
flagset = {}
for (name, flag) in flaglist:
# It's possible this flag got deleted or overridden since being
# registered in the per-module flaglist. Check now against the
# canonical source of current flag information, the _flags.
if fl.get(name, None) != flag and special_fl.get(name, None) != flag:
# a different flag is using this name now
continue
# only print help once
if flag in flagset:
continue
flagset[flag] = 1
flaghelp = ''
if flag.short_name:
flaghelp += '-%s,' % flag.short_name
if flag.boolean:
flaghelp += '--[no]%s:' % flag.name
else:
flaghelp += '--%s:' % flag.name
flaghelp += ' '
if flag.help:
flaghelp += flag.help
flaghelp = _helpers.text_wrap(
flaghelp, indent=prefix + ' ', firstline_indent=prefix)
if flag.default_as_str:
flaghelp += '\n'
flaghelp += _helpers.text_wrap(
'(default: %s)' % flag.default_as_str, indent=prefix + ' ')
if flag.parser.syntactic_help:
flaghelp += '\n'
flaghelp += _helpers.text_wrap(
'(%s)' % flag.parser.syntactic_help, indent=prefix + ' ')
output_lines.append(flaghelp)
def get_flag_value(self, name: Text, default: Any) -> Any: # pylint: disable=invalid-name
"""Returns the value of a flag (if not None) or a default value.
Args:
name: str, the name of a flag.
default: Default value to use if the flag value is None.
Returns:
Requested flag value or default.
"""
value = self.__getattr__(name)
if value is not None: # Can't do if not value, b/c value might be '0' or ""
return value
else:
return default
def _is_flag_file_directive(self, flag_string):
"""Checks whether flag_string contain a --flagfile= directive."""
if isinstance(flag_string, str):
if flag_string.startswith('--flagfile='):
return 1
elif flag_string == '--flagfile':
return 1
elif flag_string.startswith('-flagfile='):
return 1
elif flag_string == '-flagfile':
return 1
else:
return 0
return 0
def _extract_filename(self, flagfile_str):
"""Returns filename from a flagfile_str of form -[-]flagfile=filename.
The cases of --flagfile foo and -flagfile foo shouldn't be hitting
this function, as they are dealt with in the level above this
function.
Args:
flagfile_str: str, the flagfile string.
Returns:
str, the filename from a flagfile_str of form -[-]flagfile=filename.
Raises:
Error: Raised when illegal --flagfile is provided.
"""
if flagfile_str.startswith('--flagfile='):
return os.path.expanduser((flagfile_str[(len('--flagfile=')):]).strip())
elif flagfile_str.startswith('-flagfile='):
return os.path.expanduser((flagfile_str[(len('-flagfile=')):]).strip())
else:
raise _exceptions.Error('Hit illegal --flagfile type: %s' % flagfile_str)
def _get_flag_file_lines(self, filename, parsed_file_stack=None):
"""Returns the useful (!=comments, etc) lines from a file with flags.
Args:
filename: str, the name of the flag file.
parsed_file_stack: [str], a list of the names of the files that we have
recursively encountered at the current depth. MUTATED BY THIS FUNCTION
(but the original value is preserved upon successfully returning from
function call).
Returns:
List of strings. See the note below.
NOTE(springer): This function checks for a nested --flagfile=
tag and handles the lower file recursively. It returns a list of
all the lines that _could_ contain command flags. This is
EVERYTHING except whitespace lines and comments (lines starting
with '#' or '//').
"""
# For consistency with the cpp version, ignore empty values.
if not filename:
return []
if parsed_file_stack is None:
parsed_file_stack = []
# We do a little safety check for reparsing a file we've already encountered
# at a previous depth.
if filename in parsed_file_stack:
sys.stderr.write('Warning: Hit circular flagfile dependency. Ignoring'
' flagfile: %s\n' % (filename,))
return []
else:
parsed_file_stack.append(filename)
line_list = [] # All line from flagfile.
flag_line_list = [] # Subset of lines w/o comments, blanks, flagfile= tags.
try:
file_obj = open(filename, 'r')
except IOError as e_msg:
raise _exceptions.CantOpenFlagFileError(
'ERROR:: Unable to open flagfile: %s' % e_msg)
with file_obj:
line_list = file_obj.readlines()
# This is where we check each line in the file we just read.
for line in line_list:
if line.isspace():
pass
# Checks for comment (a line that starts with '#').
elif line.startswith('#') or line.startswith('//'):
pass
# Checks for a nested "--flagfile=" flag in the current file.
# If we find one, recursively parse down into that file.
elif self._is_flag_file_directive(line):
sub_filename = self._extract_filename(line)
included_flags = self._get_flag_file_lines(
sub_filename, parsed_file_stack=parsed_file_stack)
flag_line_list.extend(included_flags)
else:
# Any line that's not a comment or a nested flagfile should get
# copied into 2nd position. This leaves earlier arguments
# further back in the list, thus giving them higher priority.
flag_line_list.append(line.strip())
parsed_file_stack.pop()
return flag_line_list
def read_flags_from_files(
    self, argv: Sequence[Text], force_gnu: bool = True
) -> List[Text]:
  """Processes command line args, but also allow args to be read from file.

  Args:
    argv: [str], a list of strings, usually sys.argv[1:], which may contain
      one or more flagfile directives of the form --flagfile="./filename".
      Note that the name of the program (sys.argv[0]) should be omitted.
    force_gnu: bool, if False, --flagfile parsing obeys the
      FLAGS.is_gnu_getopt() value. If True, ignore the value and always follow
      gnu_getopt semantics.

  Returns:
    A new list which has the original list combined with what we read
    from any flagfile(s).

  Raises:
    IllegalFlagValueError: Raised when --flagfile is provided with no
      argument.

  This function is called by FLAGS(argv).
  It scans the input list for a flag that looks like:
  --flagfile=<filename>. Then it opens <filename>, reads all valid key
  and value pairs and inserts them into the input list in exactly the
  place where the --flagfile arg is found.

  Note that your application's flags are still defined the usual way
  using absl.flags DEFINE_flag() type functions.

  Notes (assuming we're getting a commandline of some sort as our input):

  * For duplicate flags, the last one we hit should "win".
  * Since flags that appear later win, a flagfile's settings can be "weak"
    if the --flagfile comes at the beginning of the argument sequence,
    and it can be "strong" if the --flagfile comes at the end.
  * A further "--flagfile=<otherfile.cfg>" CAN be nested in a flagfile.
    It will be expanded in exactly the spot where it is found.
  * In a flagfile, a line beginning with # or // is a comment.
  * Entirely blank lines _should_ be ignored.
  * Scanning stops at '--', and (unless gnu_getopt semantics apply) at the
    first non-flag argument; everything after that point is copied through
    verbatim without flagfile expansion.
  """
  rest_of_args = argv
  new_argv = []
  while rest_of_args:
    # Pop the next argument off the front of the remaining input.
    current_arg = rest_of_args[0]
    rest_of_args = rest_of_args[1:]
    if self._is_flag_file_directive(current_arg):
      # This handles the case of -(-)flagfile foo. In this case the
      # next arg really is part of this one.
      if current_arg == '--flagfile' or current_arg == '-flagfile':
        if not rest_of_args:
          raise _exceptions.IllegalFlagValueError(
              '--flagfile with no argument')
        flag_filename = os.path.expanduser(rest_of_args[0])
        rest_of_args = rest_of_args[1:]
      else:
        # This handles the case of (-)-flagfile=foo.
        flag_filename = self._extract_filename(current_arg)
      # Splice the file's lines in place of the directive.
      new_argv.extend(self._get_flag_file_lines(flag_filename))
    else:
      new_argv.append(current_arg)
      # Stop parsing after '--', like getopt and gnu_getopt.
      if current_arg == '--':
        break
      # Stop parsing after a non-flag, like getopt.
      if not current_arg.startswith('-'):
        if not force_gnu and not self.__dict__['__use_gnu_getopt']:
          break
      else:
        if ('=' not in current_arg and rest_of_args and
            not rest_of_args[0].startswith('-')):
          # If this is an occurrence of a legitimate --x y, skip the value
          # so that it won't be mistaken for a standalone arg.
          fl = self._flags()
          name = current_arg.lstrip('-')
          if name in fl and not fl[name].boolean:
            current_arg = rest_of_args[0]
            rest_of_args = rest_of_args[1:]
            new_argv.append(current_arg)
  # Anything left over (after a '--' or a getopt-style stop) is copied through
  # unchanged.
  if rest_of_args:
    new_argv.extend(rest_of_args)
  return new_argv
def flags_into_string(self) -> Text:
  """Serializes every flag that has a value into flagfile text.

  Flags whose value is None are omitted. Each remaining flag contributes
  one ``flag.serialize()`` line terminated by a newline.

  NOTE: MUST mirror the behavior of the C++ CommandlineFlagsIntoString
  from https://github.com/gflags/gflags.

  Returns:
    str, the string with the flags assignments from this FlagValues object.
    The flags are ordered by (module_name, flag_name).
  """
  pieces = []
  for _, module_flags in sorted(self.flags_by_module_dict().items()):
    for flag in sorted(module_flags, key=lambda f: f.name):
      if flag.value is None:
        continue
      pieces.append(flag.serialize() + '\n')
  return ''.join(pieces)
def append_flags_into_file(self, filename: Text) -> None:
  """Appends the output of flags_into_string() to `filename`.

  The file is opened in append mode (and created if missing), so the result
  uses flagfile format and earlier contents are kept.

  NOTE: MUST mirror the behavior of the C++ AppendFlagsIntoFile
  from https://github.com/gflags/gflags.

  Args:
    filename: str, name of the file.
  """
  with open(filename, 'a') as destination:
    serialized_flags = self.flags_into_string()
    destination.write(serialized_flags)
def write_help_in_xml_format(self, outfile: Optional[TextIO] = None) -> None:
  """Outputs flag documentation in XML format.

  NOTE: We use element names that are consistent with those used by
  the C++ command-line flag library, from
  https://github.com/gflags/gflags.
  We also use a few new elements (e.g., the one that marks key flags), but
  we do not interfere / overlap with existing XML elements used by the C++
  library. Please maintain this consistency.

  The document has an ``AllFlags`` root holding a ``program`` element, a
  ``usage`` element and then one element per flag, ordered by declaring
  module name and then by flag name.

  Args:
    outfile: File object we write to. Default None means sys.stdout.
  """
  xml_doc = minidom.Document()
  root = xml_doc.createElement('AllFlags')
  xml_doc.appendChild(root)

  program_path = sys.argv[0]
  root.appendChild(
      _helpers.create_xml_dom_element(
          xml_doc, 'program', os.path.basename(program_path)))

  # Prefer the main module's docstring (with '%s' replaced by the program
  # path); fall back to a generic usage line when it has none.
  main_docstring = sys.modules['__main__'].__doc__
  if main_docstring:
    usage_text = main_docstring.replace('%s', program_path)
  else:
    usage_text = '\nUSAGE: %s [flags]\n' % program_path
  root.appendChild(
      _helpers.create_xml_dom_element(xml_doc, 'usage', usage_text))

  # Flags that are key flags of the main program get is_key=True below.
  main_key_flags = self.get_key_flags_for_module(program_path)

  by_module = self.flags_by_module_dict()
  for module_name in sorted(by_module.keys()):
    # Sorting (name, flag) pairs orders each module's flags by name.
    for _, flag in sorted((f.name, f) for f in by_module[module_name]):
      flag_is_key = flag in main_key_flags
      root.appendChild(
          flag._create_xml_dom_element(  # pylint: disable=protected-access
              xml_doc,
              module_name,
              is_key=flag_is_key))

  destination = outfile or sys.stdout
  destination.write(
      xml_doc.toprettyxml(indent=' ', encoding='utf-8').decode('utf-8'))
  destination.flush()
def _check_method_name_conflicts(self, name: str, flag: Flag):
  """Rejects flags whose long or short name is on the banned-name list.

  The check is skipped entirely when the flag opts in via
  ``allow_using_method_names``.

  Raises:
    FlagNameConflictsWithMethodError: if `name` or the flag's short name is
      listed in this object's '__banned_flag_names'.
  """
  if flag.allow_using_method_names:
    return
  short_name = flag.short_name
  names_to_check = {name} if short_name is None else {name, short_name}
  for candidate in names_to_check:
    if candidate not in self.__dict__['__banned_flag_names']:
      continue
    raise _exceptions.FlagNameConflictsWithMethodError(
        'Cannot define a flag named "{name}". It conflicts with a method '
        'on class "{class_name}". To allow defining it, use '
        'allow_using_method_names and access the flag value with '
        "FLAGS['{name}'].value. FLAGS.{name} returns the method, "
        'not the flag value.'.format(
            name=candidate, class_name=type(self).__name__))
# Module-level default FlagValues instance. It serves as the default
# `flag_values` argument of helpers such as the validator functions, and
# resolve_flag_ref() treats it as the "not customized" value.
FLAGS = FlagValues()
class FlagHolder(Generic[_T]):
  """Holds a defined flag.

  This facilitates a cleaner api around global state. Instead of::

      flags.DEFINE_integer('foo', ...)
      flags.DEFINE_integer('bar', ...)

      def method():
        # prints parsed value of 'foo' flag
        print(flags.FLAGS.foo)
        # runtime error due to typo or possibly bad coding style.
        print(flags.FLAGS.baz)

  it encourages code like::

      _FOO_FLAG = flags.DEFINE_integer('foo', ...)
      _BAR_FLAG = flags.DEFINE_integer('bar', ...)

      def method():
        print(_FOO_FLAG.value)
        print(_BAR_FLAG.value)

  since the name of the flag appears only once in the source code.

  Comparing a holder with ``==`` or using it in a boolean context raises
  TypeError, to catch code that forgot to read ``.value``.
  """

  value: _T

  def __init__(
      self,
      flag_values: FlagValues,
      flag: Flag[_T],
      ensure_non_none_value: bool = False,
  ):
    """Constructs a FlagHolder instance providing typesafe access to flag.

    Args:
      flag_values: The container the flag is registered to.
      flag: The flag object for this flag.
      ensure_non_none_value: If True, reading ``value`` raises
        IllegalFlagValueError whenever the flag's value is None. If False
        (the default), a None value is returned as-is.
    """
    self._flagvalues = flag_values
    # We take the entire flag object, but only keep the name. Why?
    # - We want FlagHolder[T] to be generic container
    # - flag_values contains all flags, so has no reference to T.
    # - typecheckers don't like to see a generic class where none of the ctor
    #   arguments refer to the generic type.
    self._name = flag.name
    # We intentionally do NOT check if the default value is None.
    # This allows future use of this for "required flags with None default"
    self._ensure_non_none_value = ensure_non_none_value

  def __eq__(self, other):
    # Equality on the holder itself is almost always a bug; point users to
    # `.value` instead.
    raise TypeError(
        "unsupported operand type(s) for ==: '{0}' and '{1}' "
        "(did you mean to use '{0}.value' instead?)".format(
            type(self).__name__, type(other).__name__))

  def __bool__(self):
    # Truth-testing the holder would always be True; reject it for the same
    # reason as __eq__.
    raise TypeError(
        "bool() not supported for instances of type '{0}' "
        "(did you mean to use '{0}.value' instead?)".format(
            type(self).__name__))

  # Python 2 spelling of __bool__, kept for compatibility.
  __nonzero__ = __bool__

  @property
  def name(self) -> Text:
    """Returns the name of the held flag."""
    return self._name

  @property
  def value(self) -> _T:
    """Returns the value of the flag.

    If ``_ensure_non_none_value`` is ``True``, then return value is not
    ``None``.

    Raises:
      UnparsedFlagAccessError: if flag parsing has not finished.
      IllegalFlagValueError: if value is None unexpectedly.
    """
    val = getattr(self._flagvalues, self._name)
    if self._ensure_non_none_value and val is None:
      raise _exceptions.IllegalFlagValueError(
          'Unexpected None value for flag %s' % self._name)
    return val

  @property
  def default(self) -> _T:
    """Returns the default value of the flag."""
    return self._flagvalues[self._name].default

  @property
  def present(self) -> bool:
    """Returns True if the flag was parsed from command-line flags."""
    return bool(self._flagvalues[self._name].present)

  def serialize(self) -> Text:
    """Returns a serialized representation of the flag."""
    return self._flagvalues[self._name].serialize()
def resolve_flag_ref(
    flag_ref: Union[str, FlagHolder], flag_values: FlagValues
) -> Tuple[str, FlagValues]:
  """Normalizes a flag reference into a (name, FlagValues) pair.

  A plain string is returned unchanged together with `flag_values`. A
  FlagHolder contributes its own name and its own FlagValues.

  Raises:
    ValueError: if `flag_ref` is a FlagHolder and `flag_values` was set to
      something other than FLAGS or the holder's own FlagValues.
  """
  if not isinstance(flag_ref, FlagHolder):
    return flag_ref, flag_values
  holder_values = flag_ref._flagvalues  # pylint: disable=protected-access
  if flag_values != FLAGS and flag_values != holder_values:
    raise ValueError(
        'flag_values must not be customized when operating on a FlagHolder')
  return flag_ref.name, holder_values
def resolve_flag_refs(
    flag_refs: Sequence[Union[str, FlagHolder]], flag_values: FlagValues
) -> Tuple[List[str], FlagValues]:
  """Helper to validate and resolve flag reference list arguments.

  Args:
    flag_refs: flag names and/or FlagHolder instances.
    flag_values: FlagValues used for the plain-string entries.

  Returns:
    (names, fv): the flag names in input order, and the single FlagValues
    instance they all belong to. When `flag_refs` is empty, `fv` is
    `flag_values`.

  Raises:
    ValueError: if the entries resolve to more than one FlagValues instance.
  """
  fv = None
  names = []
  for ref in flag_refs:
    if isinstance(ref, FlagHolder):
      newfv = ref._flagvalues  # pylint: disable=protected-access
      name = ref.name
    else:
      newfv = flag_values
      name = ref
    # Compare against None explicitly: a truthiness test on a FlagValues
    # (which may be falsy, e.g. if it is container-like and empty) would
    # silently skip this consistency check.
    if fv is not None and fv != newfv:
      raise ValueError(
          'multiple FlagValues instances used in invocation. '
          'FlagHolders must be registered to the same FlagValues instance as '
          'do flag names, if provided.')
    fv = newfv
    names.append(name)
  if fv is None:
    # Empty input: honour the declared return type rather than returning None.
    fv = flag_values
  return names, fv
abseil-py-2.1.0/absl/flags/_helpers.py 0000664 0000000 0000000 00000033401 14551576331 0017602 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal helper functions for Abseil Python flags library."""
import os
import re
import struct
import sys
import textwrap
import types
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Set
from xml.dom import minidom
# pylint: disable=g-import-not-at-top
try:
import fcntl
except ImportError:
fcntl = None
try:
# Importing termios will fail on non-unix platforms.
import termios
except ImportError:
termios = None
# pylint: enable=g-import-not-at-top
_DEFAULT_HELP_WIDTH = 80 # Default width of help output.
# Minimal "sane" width of help output. We assume that any value below 40 is
# unreasonable.
_MIN_HELP_WIDTH = 40
# Define the allowed error rate in an input string to get suggestions.
#
# We lean towards a high threshold because we tend to be matching a phrase,
# and the simple algorithm used here is geared towards correcting word
# spellings.
#
# For manual testing, consider " --list" which produced a large number
# of spurious suggestions when we used "least_errors > 0.5" instead of
# "least_erros >= 0.5".
_SUGGESTION_ERROR_RATE_THRESHOLD = 0.50
# Characters that cannot appear or are highly discouraged in an XML 1.0
# document. (See http://www.w3.org/TR/REC-xml/#charsets or
# https://en.wikipedia.org/wiki/Valid_characters_in_XML#XML_1.0)
_ILLEGAL_XML_CHARS_REGEX = re.compile(
u'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x84\x86-\x9f\ud800-\udfff\ufffe\uffff]')
# This is a set of module ids for the modules that disclaim key flags.
# This module is explicitly added to this set so that we never consider it to
# define key flag.
disclaim_module_ids: Set[int] = set([id(sys.modules[__name__])])
# Define special flags here so that help may be generated for them.
# NOTE: Please do NOT use SPECIAL_FLAGS from outside flags module.
# Initialized inside flagvalues.py.
# NOTE: This cannot be annotated as its actual FlagValues type since this would
# create a circular dependency.
SPECIAL_FLAGS: Any = None
# This points to the flags module, initialized in flags/__init__.py.
# This should only be used in adopt_module_key_flags to take SPECIAL_FLAGS into
# account.
FLAGS_MODULE: types.ModuleType = None
class _ModuleObjectAndName(NamedTuple):
  """Module object and name.

  Returned by get_module_object_and_name() and
  get_calling_module_object_and_name(); unpacks as ``(module, module_name)``.

  Fields:
  - module: object, module object. May be None when the module could not be
    found in sys.modules.
  - module_name: str, module name.
  """
  module: types.ModuleType
  module_name: str
def get_module_object_and_name(
    globals_dict: Dict[str, Any]
) -> _ModuleObjectAndName:
  """Identifies the module that owns a given globals dictionary.

  Args:
    globals_dict: A dictionary that should correspond to an environment
      providing the values of the globals.

  Returns:
    _ModuleObjectAndName - pair of module object & module name. The module
    is looked up in sys.modules by the dict's ``__name__`` and is None if it
    is not found there; the name is None when ``__name__`` is absent. For the
    ``__main__`` module, sys.argv[0] is reported as the (more informative)
    name.
  """
  module_name = globals_dict.get('__name__', None)
  module_object = sys.modules.get(module_name, None)
  display_name = sys.argv[0] if module_name == '__main__' else module_name
  return _ModuleObjectAndName(module_object, display_name)
def get_calling_module_object_and_name() -> _ModuleObjectAndName:
  """Returns the module that's calling into this module.

  We generally use this function to get the name of the module calling a
  DEFINE_foo... function. Frames whose module is in `disclaim_module_ids`, or
  whose module name cannot be determined, are skipped.

  Returns:
    The module object that called into this one.

  Raises:
    AssertionError: Raised when no calling module could be identified.
  """
  # Walk outward from our caller via f_back. Unlike calling
  # sys._getframe(depth) for increasing depths, this is linear in stack depth
  # and ends cleanly at the outermost frame (sys._getframe would instead raise
  # ValueError there, not the documented AssertionError).
  frame = sys._getframe(1)  # pylint: disable=protected-access
  try:
    while frame is not None:
      module, module_name = get_module_object_and_name(frame.f_globals)
      if id(module) not in disclaim_module_ids and module_name is not None:
        return _ModuleObjectAndName(module, module_name)
      frame = frame.f_back
  finally:
    # Drop the frame reference to avoid keeping a frame <-> locals cycle alive.
    del frame
  raise AssertionError('No module was found')
def get_calling_module() -> str:
  """Returns only the name part of get_calling_module_object_and_name()."""
  _, calling_module_name = get_calling_module_object_and_name()
  return calling_module_name
def create_xml_dom_element(
    doc: minidom.Document, name: str, value: Any
) -> minidom.Element:
  """Builds a ``<name>text</name>`` element from an arbitrary value.

  Args:
    doc: minidom.Document, the DOM document it should create nodes from.
    name: str, the tag of XML element.
    value: object, whose string representation will be used
      as the value of the XML element. Illegal or highly discouraged xml 1.0
      characters are stripped.

  Returns:
    An instance of minidom.Element.
  """
  if isinstance(value, bool):
    # Render booleans like the C++ flag library does: lowercase.
    text = str(value).lower()
  else:
    text = str(value)
  cleaned_text = _ILLEGAL_XML_CHARS_REGEX.sub(u'', text)
  element = doc.createElement(name)
  element.appendChild(doc.createTextNode(cleaned_text))
  return element
def get_help_width() -> int:
  """Returns the integer width of help lines that is used in TextWrap.

  When stdout is a terminal and fcntl/termios are available, the terminal's
  column count is used if it is at least _MIN_HELP_WIDTH. Otherwise the
  COLUMNS environment variable is consulted. _DEFAULT_HELP_WIDTH is returned
  when stdout is not a terminal, when the terminal query fails, or when
  COLUMNS is not a valid integer.
  """
  if not sys.stdout.isatty() or termios is None or fcntl is None:
    return _DEFAULT_HELP_WIDTH
  try:
    data = fcntl.ioctl(sys.stdout, termios.TIOCGWINSZ, b'1234')
    columns = struct.unpack('hh', data)[1]
    # Emacs mode returns 0.
    # Here we assume that any value below _MIN_HELP_WIDTH is unreasonable.
    if columns >= _MIN_HELP_WIDTH:
      return columns
    # int() accepts the integer default unchanged; a malformed COLUMNS value
    # raises ValueError, which is handled below instead of crashing help.
    return int(os.getenv('COLUMNS', _DEFAULT_HELP_WIDTH))
  except (TypeError, ValueError, OSError, struct.error):
    return _DEFAULT_HELP_WIDTH
def get_flag_suggestions(
    attempt: Optional[str], longopt_list: Sequence[str]
) -> List[str]:
  """Returns helpful similar matches for an invalid flag.

  Args:
    attempt: str or None, the (possibly misspelled) flag name to match.
    longopt_list: candidate option names. Anything from the first '=' in an
      entry onward is ignored.

  Returns:
    The candidates tied for the smallest Damerau-Levenshtein distance to
    `attempt`, where only the first len(attempt) characters of each candidate
    are compared. Returns [] if `attempt` is None or shorter than 3
    characters, if there are no candidates, or if the best distance is at
    least _SUGGESTION_ERROR_RATE_THRESHOLD * len(attempt).
  """
  # Don't suggest on missing or very short strings, or if no longopts are
  # specified. `attempt` is typed Optional, so None is handled explicitly
  # instead of letting len(None) raise TypeError.
  if attempt is None or len(attempt) <= 2 or not longopt_list:
    return []

  option_names = [v.split('=')[0] for v in longopt_list]

  # Find close approximations in flag prefixes.
  # This also handles the case where the flag is spelled right but ambiguous.
  distances = [(_damerau_levenshtein(attempt, option[0:len(attempt)]), option)
               for option in option_names]
  # t[0] is distance, and sorting by t[1] allows us to have stable output.
  distances.sort()

  least_errors, _ = distances[0]
  # Don't suggest excessively bad matches.
  if least_errors >= _SUGGESTION_ERROR_RATE_THRESHOLD * len(attempt):
    return []

  suggestions = []
  for errors, name in distances:
    if errors == least_errors:
      suggestions.append(name)
    else:
      break
  return suggestions
def _damerau_levenshtein(a, b):
"""Returns Damerau-Levenshtein edit distance from a to b."""
memo = {}
def distance(x, y):
"""Recursively defined string distance with memoization."""
if (x, y) in memo:
return memo[x, y]
if not x:
d = len(y)
elif not y:
d = len(x)
else:
d = min(
distance(x[1:], y) + 1, # correct an insertion error
distance(x, y[1:]) + 1, # correct a deletion error
distance(x[1:], y[1:]) + (x[0] != y[0])) # correct a wrong character
if len(x) >= 2 and len(y) >= 2 and x[0] == y[1] and x[1] == y[0]:
# Correct a transposition.
t = distance(x[2:], y[2:]) + 1
if d > t:
d = t
memo[x, y] = d
return d
return distance(a, b)
def text_wrap(
    text: str,
    length: Optional[int] = None,
    indent: str = '',
    firstline_indent: Optional[str] = None,
) -> str:
  """Wraps `text` paragraph by paragraph to at most `length` columns.

  Tabs are expanded to 4 spaces, each line of the input is stripped and
  wrapped as its own paragraph, and whitespace-only input lines become empty
  output lines, so existing line breaks are kept.

  Args:
    text: str, text to wrap.
    length: int, maximum length of a line, includes indentation.
      If this is None then use get_help_width()
    indent: str, indent for all but first line.
    firstline_indent: str, indent for first line; if None, fall back to indent.

  Returns:
    str, the wrapped text.

  Raises:
    ValueError: Raised if indent or firstline_indent not shorter than length.
  """
  width = get_help_width() if length is None else length
  body_indent = '' if indent is None else indent
  first_indent = body_indent if firstline_indent is None else firstline_indent

  if len(body_indent) >= width:
    raise ValueError('Length of indent exceeds length')
  if len(first_indent) >= width:
    raise ValueError('Length of first line indent exceeds length')

  expanded = text.expandtabs(4)

  # Only the very first paragraph gets `first_indent`; every later paragraph
  # starts with `body_indent`.
  first_wrapper = textwrap.TextWrapper(
      width=width, initial_indent=first_indent, subsequent_indent=body_indent)
  later_wrapper = textwrap.TextWrapper(
      width=width, initial_indent=body_indent, subsequent_indent=body_indent)

  # textwrap treats embedded newlines as ordinary whitespace, so each line is
  # wrapped separately as the textwrap docs recommend.
  output_lines = []
  for index, raw_line in enumerate(expanded.splitlines()):
    paragraph = raw_line.strip()
    wrapper = first_wrapper if index == 0 else later_wrapper
    if paragraph:
      output_lines.extend(wrapper.wrap(paragraph))
    else:
      output_lines.append('')  # Keep empty lines.
  return '\n'.join(output_lines)
def flag_dict_to_args(
    flag_map: Dict[str, Any], multi_flags: Optional[Set[str]] = None
) -> Iterable[str]:
  """Convert a dict of values into process call parameters.

  This method is used to convert a dictionary into a sequence of parameters
  for a binary that parses arguments using this module.

  Args:
    flag_map: dict, a mapping where the keys are flag names (strings).
      values are treated according to their type:

      * If value is ``None``, then only the name is emitted.
      * If value is ``True``, then only the name is emitted.
      * If value is ``False``, then only the name prepended with 'no' is
        emitted.
      * If value is a string (or bytes) then ``--name=value`` is emitted.
      * If value is a collection, this will emit
        ``--name=value1,value2,value3``, unless the flag name is in
        ``multi_flags``, in which case this will emit
        ``--name=value1 --name=value2 --name=value3``.
      * Everything else is converted to string and passed as such.

    multi_flags: set, names (strings) of flags that should be treated as
      multi-flags.

  Yields:
    sequence of string suitable for a subprocess execution.
  """
  for flag_name, flag_value in flag_map.items():
    if flag_value is None:
      yield '--%s' % flag_name
      continue
    if isinstance(flag_value, bool):
      yield ('--%s' if flag_value else '--no%s') % flag_name
      continue
    if isinstance(flag_value, (bytes, str)):
      # Strings must not be iterated character by character below.
      yield '--%s=%s' % (flag_name, flag_value)
      continue
    # Try to treat the value as a collection; non-iterables raise TypeError
    # and fall back to plain formatting.
    try:
      if multi_flags and flag_name in multi_flags:
        for item in flag_value:
          yield '--%s=%s' % (flag_name, str(item))
      else:
        joined = ','.join(str(item) for item in flag_value)
        yield '--%s=%s' % (flag_name, joined)
    except TypeError:
      yield '--%s=%s' % (flag_name, flag_value)
def trim_docstring(docstring: str) -> str:
  """Strips the common leading indentation from a triple-quoted string.

  Implements the docstring-trimming algorithm given in PEP 257:
  https://www.python.org/dev/peps/pep-0257/.

  Args:
    docstring: str, a python docstring.

  Returns:
    str, docstring with indentation removed.
  """
  if not docstring:
    return ''

  # Sentinel meaning "no indented non-blank line found".
  no_margin = 1 << 29
  # Expand tabs (default tab size) before measuring indentation.
  lines = docstring.expandtabs().splitlines()

  # The first line and blank lines do not count toward the common margin.
  margin = min(
      (len(line) - len(line.lstrip()) for line in lines[1:] if line.lstrip()),
      default=no_margin)

  trimmed = [lines[0].strip()]
  if margin < no_margin:
    trimmed.extend(line[margin:].rstrip() for line in lines[1:])

  # Drop blank lines at both ends.
  end = len(trimmed)
  while end and not trimmed[end - 1]:
    end -= 1
  start = 0
  while start < end and not trimmed[start]:
    start += 1
  return '\n'.join(trimmed[start:end])
def doc_to_help(doc: str) -> str:
  """Reformats a ``__doc__`` string into help text.

  Surrounding whitespace is removed, whitespace-only lines are emptied, the
  common indentation is stripped via trim_docstring(), and any newline that
  sits between two non-whitespace characters is replaced by a space so that
  hard-wrapped sentences become one line. Blank lines (paragraph breaks) and
  newlines before indented lines are preserved.
  """
  # Only outer whitespace is stripped here; trim_docstring() handles the
  # per-line indentation.
  result = doc.strip()
  # Turn lines made up solely of spaces/tabs into empty lines (the lines are
  # kept, so paragraph breaks survive).
  result = re.compile('^[ \t]+$', re.M).sub('', result)
  result = trim_docstring(result)
  # Join "word\nword" into "word word"; newlines touching whitespace stay.
  result = re.sub(r'(?<=\S)\n(?=\S)', ' ', result, flags=re.M)
  return result
abseil-py-2.1.0/absl/flags/_validators.py 0000664 0000000 0000000 00000033500 14551576331 0020310 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module to enforce different constraints on flags.
Flags validators can be registered using following functions / decorators::
flags.register_validator
@flags.validator
flags.register_multi_flags_validator
@flags.multi_flags_validator
Four convenience functions are also provided for common flag constraints::
flags.mark_flag_as_required
flags.mark_flags_as_required
flags.mark_flags_as_mutual_exclusive
flags.mark_bool_flags_as_mutual_exclusive
See their docstring in this module for a usage manual.
Do NOT import this module directly. Import the flags package and use the
aliases defined at the package level instead.
"""
import warnings
from absl.flags import _exceptions
from absl.flags import _flagvalues
from absl.flags import _validators_classes
def register_validator(flag_name,
                       checker,
                       message='Flag validation failed',
                       flag_values=_flagvalues.FLAGS):
  """Adds a constraint, which will be enforced during program execution.

  The constraint is validated when flags are initially parsed, and after each
  change of the corresponding flag's value.

  Args:
    flag_name: str | FlagHolder, name or holder of the flag to be checked.
        Positional-only parameter.
    checker: callable, a function to validate the flag.

        * input - A single positional argument: The value of the corresponding
          flag (string, boolean, etc.  This value will be passed to checker
          by the library).
        * output - bool, True if validator constraint is satisfied.
          If constraint is not satisfied, it should either ``return False`` or
          ``raise flags.ValidationError(desired_error_message)``.

    message: str, error text to be shown to the user if checker returns False.
        If checker raises flags.ValidationError, message from the raised
        error will be shown.
    flag_values: flags.FlagValues, optional FlagValues instance to validate
        against.

  Raises:
    AttributeError: Raised when flag_name is not registered as a valid flag
        name.
    ValueError: Raised when flag_values is non-default and does not match the
        FlagValues of the provided FlagHolder instance.
  """
  resolved_name, resolved_values = _flagvalues.resolve_flag_ref(
      flag_name, flag_values)
  single_validator = _validators_classes.SingleFlagValidator(
      resolved_name, checker, message)
  _add_validator(resolved_values, single_validator)
def validator(flag_name, message='Flag validation failed',
              flag_values=_flagvalues.FLAGS):
  """A function decorator for defining a flag validator.

  Registers the decorated function as a validator for flag_name, e.g.::

      @flags.validator('foo')
      def _CheckFoo(foo):
        ...

  See :func:`register_validator` for the specification of checker function.

  Args:
    flag_name: str | FlagHolder, name or holder of the flag to be checked.
        Positional-only parameter.
    message: str, error text to be shown to the user if checker returns False.
        If checker raises flags.ValidationError, message from the raised
        error will be shown.
    flag_values: flags.FlagValues, optional FlagValues instance to validate
        against.

  Returns:
    A function decorator that registers its function argument as a validator
    and returns that function unchanged.

  Raises:
    AttributeError: Raised when flag_name is not registered as a valid flag
        name.
  """

  def _register_and_return(function):
    register_validator(
        flag_name, function, message=message, flag_values=flag_values)
    return function

  return _register_and_return
def register_multi_flags_validator(flag_names,
                                   multi_flags_checker,
                                   message='Flags validation failed',
                                   flag_values=_flagvalues.FLAGS):
  """Adds a constraint to multiple flags.

  The constraint is validated when flags are initially parsed, and after each
  change of the corresponding flag's value.

  Args:
    flag_names: [str | FlagHolder], a list of the flag names or holders to be
        checked. Positional-only parameter.
    multi_flags_checker: callable, a function to validate the flag.

        * input - dict, with keys() being flag_names, and value for each key
            being the value of the corresponding flag (string, boolean, etc).
        * output - bool, True if validator constraint is satisfied.
            If constraint is not satisfied, it should either return False or
            raise flags.ValidationError.

    message: str, error text to be shown to the user if checker returns False.
        If checker raises flags.ValidationError, message from the raised
        error will be shown.
    flag_values: flags.FlagValues, optional FlagValues instance to validate
        against.

  Raises:
    AttributeError: Raised when a flag is not registered as a valid flag name.
    ValueError: Raised when multiple FlagValues are used in the same
        invocation. This can occur when FlagHolders have different `_flagvalues`
        or when str-type flag_names entries are present and the `flag_values`
        argument does not match that of provided FlagHolder(s).
  """
  resolved_names, resolved_values = _flagvalues.resolve_flag_refs(
      flag_names, flag_values)
  multi_validator = _validators_classes.MultiFlagsValidator(
      resolved_names, multi_flags_checker, message)
  _add_validator(resolved_values, multi_validator)
def multi_flags_validator(flag_names,
                          message='Flag validation failed',
                          flag_values=_flagvalues.FLAGS):
  """A function decorator for defining a multi-flag validator.

  Registers the decorated function as a validator for flag_names, e.g.::

      @flags.multi_flags_validator(['foo', 'bar'])
      def _CheckFooBar(flags_dict):
        ...

  See :func:`register_multi_flags_validator` for the specification of checker
  function.

  Args:
    flag_names: [str | FlagHolder], a list of the flag names or holders to be
        checked. Positional-only parameter.
    message: str, error text to be shown to the user if checker returns False.
        If checker raises flags.ValidationError, message from the raised
        error will be shown.
    flag_values: flags.FlagValues, optional FlagValues instance to validate
        against.

  Returns:
    A function decorator that registers its function argument as a validator
    and returns that function unchanged.

  Raises:
    AttributeError: Raised when a flag is not registered as a valid flag name.
  """

  def _register_and_return(function):
    register_multi_flags_validator(
        flag_names, function, message=message, flag_values=flag_values)
    return function

  return _register_and_return
def mark_flag_as_required(flag_name, flag_values=_flagvalues.FLAGS):
  """Registers a validator requiring that a flag's value is not None.

  The validator follows the usual validator rules. Note that it passes for
  any non-``None`` value, including ``False``, ``0`` (zero), ``''`` (empty
  string) and so on.

  If your module might be imported by others, and you only wish to make the
  flag required when the module is directly executed, call this function like
  this::

    if __name__ == '__main__':
      flags.mark_flag_as_required('your_flag_name')
      app.run()

  Args:
    flag_name: str | FlagHolder, name or holder of the flag.
      Positional-only parameter.
    flag_values: flags.FlagValues, optional :class:`~absl.flags.FlagValues`
      instance where the flag is defined.

  Raises:
    AttributeError: Raised when flag_name is not registered as a valid flag
      name.
    ValueError: Raised when flag_values is non-default and does not match the
      FlagValues of the provided FlagHolder instance.
  """
  flag_name, flag_values = _flagvalues.resolve_flag_ref(flag_name, flag_values)

  has_non_none_default = flag_values[flag_name].default is not None
  if has_non_none_default:
    # A non-None default already satisfies the "not None" check registered
    # below, so the requirement is effectively a no-op; warn the caller.
    warnings.warn(
        'Flag --%s has a non-None default value; therefore, '
        'mark_flag_as_required will pass even if flag is not specified in the '
        'command line!' % flag_name,
        stacklevel=2)

  def _is_set(value):
    return value is not None

  register_validator(
      flag_name,
      _is_set,
      message='Flag --{} must have a value other than None.'.format(flag_name),
      flag_values=flag_values)
def mark_flags_as_required(flag_names, flag_values=_flagvalues.FLAGS):
  """Ensures that each of the given flags is not None during program execution.

  This calls :func:`mark_flag_as_required` once per entry of flag_names, in
  order. NOTE: flag_names is iterated directly, so pass a sequence of names;
  a single string would be iterated character by character.

  If your module might be imported by others, and you only wish to make the
  flags required when the module is directly executed, call this function
  like this::

    if __name__ == '__main__':
      flags.mark_flags_as_required(['flag1', 'flag2', 'flag3'])
      app.run()

  Args:
    flag_names: Sequence[str | FlagHolder], names or holders of the flags.
    flag_values: flags.FlagValues, optional FlagValues instance where the flags
      are defined.

  Raises:
    AttributeError: If any of flag name has not already been defined as a flag.
    ValueError: Propagated from mark_flag_as_required when a FlagHolder does
      not belong to flag_values.
  """
  for name_or_holder in flag_names:
    mark_flag_as_required(name_or_holder, flag_values=flag_values)
def mark_flags_as_mutual_exclusive(flag_names, required=False,
                                   flag_values=_flagvalues.FLAGS):
  """Registers a validator allowing at most one of flag_names to be non-None.

  Important note: the validator only checks whether each flag value is
  ``None``; it cannot tell a default value from an explicitly given one.
  Therefore it does not make sense for flags whose default is anything other
  than None, including other false values (e.g. ``False``, ``0``, ``''``,
  ``[]``). That includes multi flags with a default value of ``[]`` instead of
  None. A warning is emitted at registration time for every listed flag with
  a non-None default.

  Args:
    flag_names: [str | FlagHolder], names or holders of flags.
      Positional-only parameter.
    required: bool. If true, exactly one of the flags must have a value other
      than None. Otherwise, at most one of the flags can have a value other
      than None, and it is valid for all of the flags to be None.
    flag_values: flags.FlagValues, optional FlagValues instance where the flags
      are defined.

  Raises:
    ValueError: Raised when multiple FlagValues are used in the same
      invocation. This can occur when FlagHolders have different `_flagvalues`
      or when str-type flag_names entries are present and the `flag_values`
      argument does not match that of provided FlagHolder(s).
  """
  flag_names, flag_values = _flagvalues.resolve_flag_refs(
      flag_names, flag_values)

  for name in flag_names:
    if flag_values[name].default is None:
      continue
    warnings.warn(
        'Flag --{} has a non-None default value. That does not make sense '
        'with mark_flags_as_mutual_exclusive, which checks whether the '
        'listed flags have a value other than None.'.format(name),
        stacklevel=2)

  quantifier = 'Exactly' if required else 'At most'

  def validate_mutual_exclusion(flags_dict):
    set_count = len([v for v in flags_dict.values() if v is not None])
    if set_count == 1:
      return True
    if set_count == 0 and not required:
      return True
    raise _exceptions.ValidationError(
        '{} one of ({}) must have a value other than None.'.format(
            quantifier, ', '.join(flag_names)))

  register_multi_flags_validator(
      flag_names, validate_mutual_exclusion, flag_values=flag_values)
def mark_bool_flags_as_mutual_exclusive(flag_names, required=False,
                                        flag_values=_flagvalues.FLAGS):
  """Registers a validator allowing at most one of the boolean flags to be True.

  Args:
    flag_names: [str | FlagHolder], names or holders of flags.
      Positional-only parameter.
    required: bool. If true, exactly one flag must be True. Otherwise, at most
      one flag can be True, and it is valid for all flags to be False.
    flag_values: flags.FlagValues, optional FlagValues instance where the flags
      are defined.

  Raises:
    ValueError: Raised when multiple FlagValues are used in the same
      invocation. This can occur when FlagHolders have different `_flagvalues`
      or when str-type flag_names entries are present and the `flag_values`
      argument does not match that of provided FlagHolder(s).
    flags.ValidationError: Raised immediately (at registration time) if any
      listed flag is not a boolean flag.
  """
  flag_names, flag_values = _flagvalues.resolve_flag_refs(
      flag_names, flag_values)

  # Reject non-boolean flags up front, reporting the first offender.
  for name in flag_names:
    if flag_values[name].boolean:
      continue
    raise _exceptions.ValidationError(
        'Flag --{} is not Boolean, which is required for flags used in '
        'mark_bool_flags_as_mutual_exclusive.'.format(name))

  quantifier = 'Exactly' if required else 'At most'

  def validate_boolean_mutual_exclusion(flags_dict):
    true_count = len([v for v in flags_dict.values() if v])
    if true_count == 1:
      return True
    if true_count == 0 and not required:
      return True
    raise _exceptions.ValidationError(
        '{} one of ({}) must be True.'.format(
            quantifier, ', '.join(flag_names)))

  register_multi_flags_validator(
      flag_names, validate_boolean_mutual_exclusion, flag_values=flag_values)
def _add_validator(fv, validator_instance):
"""Register new flags validator to be checked.
Args:
fv: flags.FlagValues, the FlagValues instance to add the validator.
validator_instance: validators.Validator, the validator to add.
Raises:
KeyError: Raised when validators work with a non-existing flag.
"""
for flag_name in validator_instance.get_flags_names():
fv[flag_name].validators.append(validator_instance)
abseil-py-2.1.0/absl/flags/_validators_classes.py 0000664 0000000 0000000 00000014015 14551576331 0022025 0 ustar 00root root 0000000 0000000 # Copyright 2021 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines *private* classes used for flag validators.
Do NOT import this module. DO NOT use anything from this module. They are
private APIs.
"""
from absl.flags import _exceptions
class Validator(object):
  """Base class for flag validators.

  Users should NOT subclass or overload these classes; use the public
  flags.register_* / flags.mark_* functions instead.
  """

  # Class-wide counter. Each new validator takes the next value as its
  # insertion_index, which records the order validators were created in.
  validators_count = 0

  def __init__(self, checker, message):
    """Stores the checker and message and assigns an insertion index.

    Args:
      checker: function to verify the constraint.
        Input of this method varies, see SingleFlagValidator and
        multi_flags_validator for a detailed description.
      message: str, error message to be shown to the user.
    """
    self.checker = checker
    self.message = message
    Validator.validators_count += 1
    # Records registration order (see validators_count above).
    self.insertion_index = Validator.validators_count

  def verify(self, flag_values):
    """Checks the constraint against the given flag values.

    The flags library calls this method to verify the validator's constraint.

    Args:
      flag_values: flags.FlagValues, the FlagValues instance to get flags from.

    Raises:
      flags.ValidationError: Raised if the checker returns a false value. A
        checker may also raise its own exception, which propagates unchanged.
    """
    checker_input = self._get_input_to_checker_function(flag_values)
    satisfied = self.checker(checker_input)
    if satisfied:
      return
    raise _exceptions.ValidationError(self.message)

  def get_flags_names(self):
    """Returns the names of the flags checked by this validator.

    Returns:
      [string], names of the flags.
    """
    raise NotImplementedError('This method should be overloaded')

  def print_flags_with_values(self, flag_values):
    # Subclasses return a human-readable description of the checked flags.
    raise NotImplementedError('This method should be overloaded')

  def _get_input_to_checker_function(self, flag_values):
    """Builds the argument that is passed to the checker.

    Args:
      flag_values: flags.FlagValues, containing all flags.

    Returns:
      The input to be given to checker. The return type depends on the specific
      validator.
    """
    raise NotImplementedError('This method should be overloaded')
class SingleFlagValidator(Validator):
  """Validator behind the register_validator() function.

  Checks that one flag's value passes a checker function. The checker gets the
  flag value and returns True when it is acceptable; otherwise it either
  returns False or raises an Exception.
  """

  def __init__(self, flag_name, checker, message):
    """Creates a validator for a single flag.

    Args:
      flag_name: string, name of the flag.
      checker: function to verify the validator.
        input  - value of the corresponding flag (string, boolean, etc).
        output - bool, True if validator constraint is satisfied.
          If constraint is not satisfied, it should either return False or
          raise flags.ValidationError(desired_error_message).
      message: str, error message to be shown to the user if validator's
        condition is not satisfied.
    """
    super().__init__(checker, message)
    self.flag_name = flag_name

  def get_flags_names(self):
    return [self.flag_name]

  def print_flags_with_values(self, flag_values):
    current_value = flag_values[self.flag_name].value
    return 'flag --%s=%s' % (self.flag_name, current_value)

  def _get_input_to_checker_function(self, flag_values):
    """Returns the current value of the validated flag.

    Args:
      flag_values: flags.FlagValues, the FlagValues instance to get flags from.

    Returns:
      object, the input to be given to checker.
    """
    return flag_values[self.flag_name].value
class MultiFlagsValidator(Validator):
  """Validator behind the register_multi_flags_validator() function.

  Checks that several flags' values jointly pass a common checker function.
  The checker gets a dict of flag values and returns True when they are
  acceptable; otherwise it either returns False or raises an Exception.
  """

  def __init__(self, flag_names, checker, message):
    """Creates a validator spanning several flags.

    Args:
      flag_names: [str], containing names of the flags used by checker.
      checker: function to verify the validator.
        input  - dict, with keys() being flag_names, and value for each
          key being the value of the corresponding flag (string, boolean,
          etc).
        output - bool, True if validator constraint is satisfied.
          If constraint is not satisfied, it should either return False or
          raise flags.ValidationError(desired_error_message).
      message: str, error message to be shown to the user if validator's
        condition is not satisfied
    """
    super().__init__(checker, message)
    self.flag_names = flag_names

  def _get_input_to_checker_function(self, flag_values):
    """Maps each validated flag name to its current value.

    Args:
      flag_values: flags.FlagValues, the FlagValues instance to get flags from.

    Returns:
      dict, with keys() being self.flag_names, and value for each key
        being the value of the corresponding flag (string, boolean, etc).
    """
    return {name: flag_values[name].value for name in self.flag_names}

  def print_flags_with_values(self, flag_values):
    rendered = ['%s=%s' % (name, flag_values[name].value)
                for name in self.flag_names]
    return 'flags ' + ', '.join(rendered)

  def get_flags_names(self):
    return self.flag_names
abseil-py-2.1.0/absl/flags/argparse_flags.py 0000664 0000000 0000000 00000034225 14551576331 0020766 0 ustar 00root root 0000000 0000000 # Copyright 2018 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides argparse integration with absl.flags.
``argparse_flags.ArgumentParser`` is a drop-in replacement for
:class:`argparse.ArgumentParser`. It takes care of collecting and defining absl
flags in :mod:`argparse`.
Here is a simple example::
# Assume the following absl.flags is defined in another module:
#
# from absl import flags
# flags.DEFINE_string('echo', None, 'The echo message.')
#
parser = argparse_flags.ArgumentParser(
description='A demo of absl.flags and argparse integration.')
parser.add_argument('--header', help='Header message to print.')
# The parser will also accept the absl flag `--echo`.
# The `header` value is available as `args.header` just like a regular
# argparse flag. The absl flag `--echo` continues to be available via
# `absl.flags.FLAGS` if you want to access it.
args = parser.parse_args()
# Example usages:
# ./program --echo='A message.' --header='A header'
# ./program --header 'A header' --echo 'A message.'
Here is another example that demonstrates subparsers::
parser = argparse_flags.ArgumentParser(description='A subcommands demo.')
parser.add_argument('--header', help='The header message to print.')
subparsers = parser.add_subparsers(help='The command to execute.')
roll_dice_parser = subparsers.add_parser(
'roll_dice', help='Roll a dice.',
# By default, absl flags can also be specified after the sub-command.
# To only allow them before sub-command, pass
# `inherited_absl_flags=None`.
inherited_absl_flags=None)
roll_dice_parser.add_argument('--num_faces', type=int, default=6)
roll_dice_parser.set_defaults(command=roll_dice)
shuffle_parser = subparsers.add_parser('shuffle', help='Shuffle inputs.')
shuffle_parser.add_argument(
'inputs', metavar='I', nargs='+', help='Inputs to shuffle.')
shuffle_parser.set_defaults(command=shuffle)
args = parser.parse_args(argv[1:])
args.command(args)
# Example usages:
# ./program --echo='A message.' roll_dice --num_faces=6
# ./program shuffle --echo='A message.' 1 2 3 4
There are several differences between :mod:`absl.flags` and
:mod:`~absl.flags.argparse_flags`:
1. Flags defined with absl.flags are parsed differently when using the
argparse parser. Notably:
1) absl.flags allows both single-dash and double-dash for any flag, and
doesn't distinguish them; argparse_flags only allows double-dash for
flag's regular name, and single-dash for flag's ``short_name``.
2) Boolean flags in absl.flags can be specified with ``--bool``,
``--nobool``, as well as ``--bool=true/false`` (though not recommended);
in argparse_flags, only ``--bool`` and ``--nobool`` are allowed.
2. Help related flag differences:
1) absl.flags does not define help flags, absl.app does that; argparse_flags
defines help flags unless passed with ``add_help=False``.
2) absl.app supports ``--helpxml``; argparse_flags does not.
3) argparse_flags supports ``-h``; absl.app does not.
"""
import argparse
import sys
from absl import flags
_BUILT_IN_FLAGS = frozenset({
'help',
'helpshort',
'helpfull',
'helpxml',
'flagfile',
'undefok',
})
class ArgumentParser(argparse.ArgumentParser):
  """Custom ArgumentParser class to support special absl flags.

  On top of normal argparse behavior, when ``inherited_absl_flags`` is not
  None this parser defines each inherited absl flag (except names in
  _BUILT_IN_FLAGS) as an argparse option, expands ``--flagfile``, honors
  ``--undefok``, and runs absl flag validation after parsing.
  """

  def __init__(self, **kwargs):
    """Initializes ArgumentParser.

    Args:
      **kwargs: same as argparse.ArgumentParser, except:
          1. It also accepts `inherited_absl_flags`: the absl flags to inherit.
             The default is the global absl.flags.FLAGS instance. Pass None to
             ignore absl flags.
          2. The `prefix_chars` argument must be the default value '-'.

    Raises:
      ValueError: Raised when prefix_chars is not '-'.
    """
    prefix_chars = kwargs.get('prefix_chars', '-')
    if prefix_chars != '-':
      raise ValueError(
          'argparse_flags.ArgumentParser only supports "-" as the prefix '
          'character, found "{}".'.format(prefix_chars))

    # Remove inherited_absl_flags before calling super: it is not an
    # argparse.ArgumentParser keyword argument.
    self._inherited_absl_flags = kwargs.pop('inherited_absl_flags', flags.FLAGS)
    # Now call super to initialize argparse.ArgumentParser before calling
    # add_argument in _define_absl_flags.
    super(ArgumentParser, self).__init__(**kwargs)

    if self.add_help:
      # -h and --help are defined in super.
      # Also add the --helpshort and --helpfull flags.
      self.add_argument(
          # Action 'help' defines a similar flag to -h/--help.
          '--helpshort', action='help',
          default=argparse.SUPPRESS, help=argparse.SUPPRESS)
      self.add_argument(
          '--helpfull', action=_HelpFullAction,
          default=argparse.SUPPRESS, help='show full help message and exit')

    if self._inherited_absl_flags is not None:
      # --undefok is hidden from help. With default=SUPPRESS, the attribute is
      # only set on the namespace when the flag is actually given.
      self.add_argument(
          '--undefok', default=argparse.SUPPRESS, help=argparse.SUPPRESS)
      self._define_absl_flags(self._inherited_absl_flags)

  def parse_known_args(self, args=None, namespace=None):
    """Parses known args, then applies absl-specific post-processing.

    When absl flags are inherited, this also expands --flagfile before
    parsing. After parsing, it drops leftover args named by --undefok, removes
    ``undefok`` from the namespace (if present), marks the absl flags as
    parsed and validates them; an IllegalFlagValueError is reported via
    self.error().

    Args:
      args: list of str, arguments to parse; defaults to ``sys.argv[1:]``.
      namespace: optional argparse.Namespace to populate.

    Returns:
      A (namespace, remaining_args) tuple, as argparse does.
    """
    if args is None:
      args = sys.argv[1:]
    if self._inherited_absl_flags is not None:
      # Handle --flagfile.
      # Explicitly specify force_gnu=True, since argparse behaves like
      # gnu_getopt: flags can be specified after positional arguments.
      args = self._inherited_absl_flags.read_flags_from_files(
          args, force_gnu=True)

    # Remember any pre-existing namespace.undefok; the unique sentinel object
    # distinguishes "attribute absent" from any real value.
    undefok_missing = object()
    undefok = getattr(namespace, 'undefok', undefok_missing)

    namespace, args = super(ArgumentParser, self).parse_known_args(
        args, namespace)

    # For Python <= 2.7.8: https://bugs.python.org/issue9351, a bug where
    # sub-parsers don't preserve existing namespace attributes.
    # Restore the undefok attribute if a sub-parser dropped it.
    if undefok is not undefok_missing:
      namespace.undefok = undefok

    if self._inherited_absl_flags is not None:
      # Handle --undefok. At this point, `args` only contains unknown flags,
      # so it won't strip defined flags that are also specified with --undefok.
      # For Python <= 2.7.8: https://bugs.python.org/issue9351, a bug where
      # sub-parsers don't preserve existing namespace attributes. The undefok
      # attribute might not exist because a subparser dropped it.
      if hasattr(namespace, 'undefok'):
        args = _strip_undefok_args(namespace.undefok, args)
        # absl flags are not exposed in the Namespace object. See Namespace:
        # https://docs.python.org/3/library/argparse.html#argparse.Namespace.
        del namespace.undefok
      self._inherited_absl_flags.mark_as_parsed()
      try:
        self._inherited_absl_flags.validate_all_flags()
      except flags.IllegalFlagValueError as e:
        self.error(str(e))

    return namespace, args

  def _define_absl_flags(self, absl_flags):
    """Defines flags from absl_flags.

    Names in _BUILT_IN_FLAGS are skipped. Flags that are not key flags of the
    main module (sys.argv[0]) are defined with their help text suppressed.
    """
    key_flags = set(absl_flags.get_key_flags_for_module(sys.argv[0]))
    for name in absl_flags:
      if name in _BUILT_IN_FLAGS:
        # Do not inherit built-in flags.
        continue
      flag_instance = absl_flags[name]
      # Each flag with a short_name appears in FLAGS twice (once per name), so
      # only define it when the dictionary key is equal to the regular name.
      if name == flag_instance.name:
        # Suppress the flag in the help short message if it's not a main
        # module's key flag.
        suppress = flag_instance not in key_flags
        self._define_absl_flag(flag_instance, suppress)

  def _define_absl_flag(self, flag_instance, suppress):
    """Defines a flag from the flag_instance.

    Args:
      flag_instance: absl.flags.Flag, the absl flag to expose to argparse.
      suppress: bool, if True the flag's help text is argparse.SUPPRESS.
    """
    flag_name = flag_instance.name
    short_name = flag_instance.short_name
    argument_names = ['--' + flag_name]
    if short_name:
      argument_names.insert(0, '-' + short_name)
    if suppress:
      helptext = argparse.SUPPRESS
    else:
      # argparse help string uses %-formatting. Escape the literal %'s.
      helptext = flag_instance.help.replace('%', '%%')
    if flag_instance.boolean:
      # Only add the `no` form to the long name.
      argument_names.append('--no' + flag_name)
      self.add_argument(
          *argument_names, action=_BooleanFlagAction, help=helptext,
          metavar=flag_instance.name.upper(),
          flag_instance=flag_instance)
    else:
      self.add_argument(
          *argument_names, action=_FlagAction, help=helptext,
          metavar=flag_instance.name.upper(),
          flag_instance=flag_instance)
class _FlagAction(argparse.Action):
"""Action class for Abseil non-boolean flags."""
def __init__(
self,
option_strings,
dest,
help, # pylint: disable=redefined-builtin
metavar,
flag_instance,
default=argparse.SUPPRESS):
"""Initializes _FlagAction.
Args:
option_strings: See argparse.Action.
dest: Ignored. The flag is always defined with dest=argparse.SUPPRESS.
help: See argparse.Action.
metavar: See argparse.Action.
flag_instance: absl.flags.Flag, the absl flag instance.
default: Ignored. The flag always uses dest=argparse.SUPPRESS so it
doesn't affect the parsing result.
"""
del dest
self._flag_instance = flag_instance
super(_FlagAction, self).__init__(
option_strings=option_strings,
dest=argparse.SUPPRESS,
help=help,
metavar=metavar)
def __call__(self, parser, namespace, values, option_string=None):
"""See https://docs.python.org/3/library/argparse.html#action-classes."""
self._flag_instance.parse(values)
self._flag_instance.using_default_value = False
class _BooleanFlagAction(argparse.Action):
"""Action class for Abseil boolean flags."""
def __init__(
self,
option_strings,
dest,
help, # pylint: disable=redefined-builtin
metavar,
flag_instance,
default=argparse.SUPPRESS):
"""Initializes _BooleanFlagAction.
Args:
option_strings: See argparse.Action.
dest: Ignored. The flag is always defined with dest=argparse.SUPPRESS.
help: See argparse.Action.
metavar: See argparse.Action.
flag_instance: absl.flags.Flag, the absl flag instance.
default: Ignored. The flag always uses dest=argparse.SUPPRESS so it
doesn't affect the parsing result.
"""
del dest, default
self._flag_instance = flag_instance
flag_names = [self._flag_instance.name]
if self._flag_instance.short_name:
flag_names.append(self._flag_instance.short_name)
self._flag_names = frozenset(flag_names)
super(_BooleanFlagAction, self).__init__(
option_strings=option_strings,
dest=argparse.SUPPRESS,
nargs=0, # Does not accept values, only `--bool` or `--nobool`.
help=help,
metavar=metavar)
def __call__(self, parser, namespace, values, option_string=None):
"""See https://docs.python.org/3/library/argparse.html#action-classes."""
if not isinstance(values, list) or values:
raise ValueError('values must be an empty list.')
if option_string.startswith('--'):
option = option_string[2:]
else:
option = option_string[1:]
if option in self._flag_names:
self._flag_instance.parse('true')
else:
if not option.startswith('no') or option[2:] not in self._flag_names:
raise ValueError('invalid option_string: ' + option_string)
self._flag_instance.parse('false')
self._flag_instance.using_default_value = False
class _HelpFullAction(argparse.Action):
"""Action class for --helpfull flag."""
def __init__(self, option_strings, dest, default, help): # pylint: disable=redefined-builtin
"""Initializes _HelpFullAction.
Args:
option_strings: See argparse.Action.
dest: Ignored. The flag is always defined with dest=argparse.SUPPRESS.
default: Ignored.
help: See argparse.Action.
"""
del dest, default
super(_HelpFullAction, self).__init__(
option_strings=option_strings,
dest=argparse.SUPPRESS,
default=argparse.SUPPRESS,
nargs=0,
help=help)
def __call__(self, parser, namespace, values, option_string=None):
"""See https://docs.python.org/3/library/argparse.html#action-classes."""
# This only prints flags when help is not argparse.SUPPRESS.
# It includes user defined argparse flags, as well as main module's
# key absl flags. Other absl flags use argparse.SUPPRESS, so they aren't
# printed here.
parser.print_help()
absl_flags = parser._inherited_absl_flags # pylint: disable=protected-access
if absl_flags is not None:
modules = sorted(absl_flags.flags_by_module_dict())
main_module = sys.argv[0]
if main_module in modules:
# The main module flags are already printed in parser.print_help().
modules.remove(main_module)
print(absl_flags._get_help_for_modules( # pylint: disable=protected-access
modules, prefix='', include_special_flags=True))
parser.exit()
def _strip_undefok_args(undefok, args):
  """Returns args with the flags named in --undefok removed.

  Args:
    undefok: str, the comma-separated value given to --undefok (may be empty).
    args: list of str, the leftover (unknown) command-line arguments.

  Returns:
    list of str. If ``undefok`` is empty, ``args`` is returned unchanged.
    Otherwise this is a new list without the entries whose flag name, or its
    ``no``-prefixed boolean form, appears in ``undefok``.
  """
  if undefok:
    undefok_names = set(name.strip() for name in undefok.split(','))
    # Ignore empty entries such as those produced by "--undefok=foo," or
    # "foo,,bar". An empty name would otherwise match the bare "-" and "--"
    # arguments, and its "no" form would match "--no", so they would be
    # silently dropped.
    undefok_names.discard('')
    undefok_names |= set('no' + name for name in undefok_names)
    # Remove undefok flags.
    args = [arg for arg in args if not _is_undefok(arg, undefok_names)]
  return args
def _is_undefok(arg, undefok_names):
"""Returns whether we can ignore arg based on a set of undefok flag names."""
if not arg.startswith('-'):
return False
if arg.startswith('--'):
arg_without_dash = arg[2:]
else:
arg_without_dash = arg[1:]
if '=' in arg_without_dash:
name, _ = arg_without_dash.split('=', 1)
else:
name = arg_without_dash
if name in undefok_names:
return True
return False
abseil-py-2.1.0/absl/flags/tests/ 0000775 0000000 0000000 00000000000 14551576331 0016570 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/absl/flags/tests/__init__.py 0000664 0000000 0000000 00000001110 14551576331 0020672 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
abseil-py-2.1.0/absl/flags/tests/_argument_parser_test.py 0000664 0000000 0000000 00000014544 14551576331 0023546 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Additional tests for flag argument parsers.
Most of the argument parsers are covered in the flags_test.py.
"""
import enum
from absl.flags import _argument_parser
from absl.testing import absltest
from absl.testing import parameterized
class ArgumentParserTest(absltest.TestCase):
  """Tests for the base _argument_parser.ArgumentParser and parser caching."""

  def test_instance_cache(self):
    # Two argument-less FloatParser() constructions must yield the very same
    # object (i.e. instances are cached).
    parser1 = _argument_parser.FloatParser()
    parser2 = _argument_parser.FloatParser()
    self.assertIs(parser1, parser2)

  def test_parse_wrong_type(self):
    parser = _argument_parser.ArgumentParser()
    with self.assertRaises(TypeError):
      parser.parse(0)  # type: ignore
    # bytes are rejected as well. (The former `bytes is not str` guard was a
    # Python 2 leftover and is always true on Python 3.)
    with self.assertRaises(TypeError):
      parser.parse(b'')  # type: ignore
class BooleanParserTest(absltest.TestCase):
  """Tests for _argument_parser.BooleanParser."""

  def setUp(self):
    super().setUp()
    self.parser = _argument_parser.BooleanParser()

  # ---- Inputs that parse successfully. ----

  def test_parse_str(self):
    self.assertTrue(self.parser.parse('true'))

  def test_parse_unicode(self):
    self.assertTrue(self.parser.parse(u'true'))

  def test_parse_str_false(self):
    self.assertFalse(self.parser.parse('false'))

  def test_parse_integer(self):
    self.assertTrue(self.parser.parse(1))

  # ---- Inputs that must be rejected. ----

  def test_parse_bytes(self):
    with self.assertRaises(TypeError):
      self.parser.parse(b'true')  # type: ignore

  def test_parse_wrong_type(self):
    with self.assertRaises(TypeError):
      self.parser.parse(1.234)  # type: ignore

  def test_parse_invalid_integer(self):
    with self.assertRaises(ValueError):
      self.parser.parse(-1)

  def test_parse_invalid_str(self):
    with self.assertRaises(ValueError):
      self.parser.parse('nottrue')
class FloatParserTest(absltest.TestCase):
  """Tests for _argument_parser.FloatParser."""

  def setUp(self):
    # Chain to the base setUp (as BooleanParserTest does) so any setup the
    # base TestCase performs still runs.
    super().setUp()
    self.parser = _argument_parser.FloatParser()

  def test_parse_string(self):
    self.assertEqual(1.5, self.parser.parse('1.5'))

  def test_parse_wrong_type(self):
    # A bool value is rejected with TypeError.
    with self.assertRaises(TypeError):
      self.parser.parse(False)  # type: ignore
class IntegerParserTest(absltest.TestCase):
  """Tests for _argument_parser.IntegerParser."""

  def setUp(self):
    # Chain to the base setUp (as BooleanParserTest does) so any setup the
    # base TestCase performs still runs.
    super().setUp()
    self.parser = _argument_parser.IntegerParser()

  def test_parse_string(self):
    self.assertEqual(1, self.parser.parse('1'))

  def test_parse_wrong_type(self):
    # Floats and bools are both rejected with TypeError.
    with self.assertRaises(TypeError):
      self.parser.parse(1e2)  # type: ignore
    with self.assertRaises(TypeError):
      self.parser.parse(False)  # type: ignore
class EnumParserTest(absltest.TestCase):
  """Tests for _argument_parser.EnumParser (string-valued enums)."""

  def _fruit_parser(self):
    # A fresh parser accepting exactly 'apple' and 'banana'.
    return _argument_parser.EnumParser(['apple', 'banana'])

  def test_empty_values(self):
    with self.assertRaises(ValueError):
      _argument_parser.EnumParser([])

  def test_parse(self):
    self.assertEqual('apple', self._fruit_parser().parse('apple'))

  def test_parse_not_found(self):
    fruit_parser = self._fruit_parser()
    with self.assertRaises(ValueError):
      fruit_parser.parse('orange')
class Fruit(enum.Enum):
  """Two-member enum used by the EnumClassParser and serializer tests."""
  APPLE = 1
  BANANA = 2
class EmptyEnum(enum.Enum):
  """Enum with no members; EnumClassParser is expected to reject it."""
class MixedCaseEnum(enum.Enum):
  """Enum whose member names collide when compared case-insensitively.

  ``APPLE`` and ``apple`` are distinct members, so a case-insensitive
  EnumClassParser over this enum is expected to reject it as a duplicate.
  """
  APPLE = 1
  BANANA = 2
  apple = 3
class EnumClassParserTest(parameterized.TestCase):
  """Tests for _argument_parser.EnumClassParser."""

  def test_requires_enum(self):
    with self.assertRaises(TypeError):
      _argument_parser.EnumClassParser(['apple', 'banana'])  # type: ignore

  def test_requires_non_empty_enum_class(self):
    with self.assertRaises(ValueError):
      _argument_parser.EnumClassParser(EmptyEnum)

  def test_case_sensitive_rejects_duplicates(self):
    # MixedCaseEnum has both APPLE and apple: accepted when case-sensitive,
    # rejected as a duplicate once case is ignored.
    unused_normal_parser = _argument_parser.EnumClassParser(MixedCaseEnum)
    with self.assertRaisesRegex(ValueError, 'Duplicate.+apple'):
      _argument_parser.EnumClassParser(MixedCaseEnum, case_sensitive=False)

  def test_parse_string(self):
    parser = _argument_parser.EnumClassParser(Fruit)
    self.assertEqual(Fruit.APPLE, parser.parse('APPLE'))

  def test_parse_string_case_sensitive(self):
    parser = _argument_parser.EnumClassParser(Fruit)
    with self.assertRaises(ValueError):
      parser.parse('apple')

  @parameterized.parameters('APPLE', 'apple', 'Apple')
  def test_parse_string_case_insensitive(self, value):
    parser = _argument_parser.EnumClassParser(Fruit, case_sensitive=False)
    self.assertIs(Fruit.APPLE, parser.parse(value))

  def test_parse_literal(self):
    parser = _argument_parser.EnumClassParser(Fruit)
    self.assertEqual(Fruit.APPLE, parser.parse(Fruit.APPLE))

  def test_parse_not_found(self):
    parser = _argument_parser.EnumClassParser(Fruit)
    with self.assertRaises(ValueError):
      parser.parse('ORANGE')

  @parameterized.parameters((Fruit.BANANA, False, 'BANANA'),
                            (Fruit.BANANA, True, 'banana'))
  def test_serialize_parse(self, value, lowercase, expected):
    serializer = _argument_parser.EnumClassSerializer(lowercase=lowercase)
    parser = _argument_parser.EnumClassParser(
        Fruit, case_sensitive=not lowercase)
    serialized = serializer.serialize(value)
    self.assertEqual(expected, serialized)
    # Round trip: parsing the serializer's own output must give back `value`.
    self.assertEqual(value, parser.parse(serialized))
class SerializerTest(parameterized.TestCase):
  """Tests for the list serializers in _argument_parser."""

  def test_csv_serializer(self):
    plus_serializer = _argument_parser.CsvListSerializer('+')
    joined = plus_serializer.serialize(['foo', 'bar'])
    self.assertEqual(joined, 'foo+bar')

  @parameterized.parameters([
      {'lowercase': False, 'expected': 'APPLE+BANANA'},
      {'lowercase': True, 'expected': 'apple+banana'},
  ])
  def test_enum_class_list_serializer(self, lowercase, expected):
    serializer = _argument_parser.EnumClassListSerializer(
        list_sep='+', lowercase=lowercase)
    self.assertEqual(
        expected, serializer.serialize([Fruit.APPLE, Fruit.BANANA]))
class HelperFunctionsTest(absltest.TestCase):
  """Tests for private helper functions in _argument_parser."""

  def test_is_integer_type(self):
    is_integer = _argument_parser._is_integer_type
    self.assertTrue(is_integer(1))
    # bool subclasses int (isinstance(False, int) is True), yet the helper
    # must not report it as an integer.
    self.assertFalse(is_integer(False))
# Run this module's tests when it is executed directly.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/flags/tests/_flag_test.py 0000664 0000000 0000000 00000020453 14551576331 0021255 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Additional tests for Flag classes.
Most of the Flag classes are covered in flags_test.py.
"""
import copy
import enum
import pickle
from absl.flags import _argument_parser
from absl.flags import _exceptions
from absl.flags import _flag
from absl.testing import absltest
from absl.testing import parameterized
class FlagTest(absltest.TestCase):
  """Direct tests of the base _flag.Flag class."""

  def setUp(self):
    super().setUp()
    # Plain string flag shared by the tests below; its default is 'apple'.
    self.flag = _flag.Flag(
        _argument_parser.ArgumentParser(),
        _argument_parser.ArgumentSerializer(),
        'fruit', 'apple', 'help')

  def test_default_unparsed(self):
    # default_unparsed must echo the default exactly as it was passed in,
    # whether that was a string or an already-typed value.
    cases = (
        (_argument_parser.ArgumentParser, 'fruit', 'apple'),
        (_argument_parser.IntegerParser, 'number', '1'),
        (_argument_parser.IntegerParser, 'number', 1),
    )
    for parser_cls, flag_name, default in cases:
      flag = _flag.Flag(
          parser_cls(),
          _argument_parser.ArgumentSerializer(),
          flag_name, default, 'help')
      self.assertEqual(default, flag.default_unparsed)

  def test_no_truthiness(self):
    with self.assertRaises(TypeError):
      is_truthy = bool(self.flag)
      if is_truthy:
        self.fail('Flag instances must raise rather than be truthy.')

  def test_set_default_overrides_current_value(self):
    fruit = self.flag
    self.assertEqual('apple', fruit.value)
    fruit._set_default('orange')
    # The flag is still on its default, so the new default becomes the value.
    self.assertEqual('orange', fruit.value)

  def test_set_default_overrides_current_value_when_not_using_default(self):
    fruit = self.flag
    fruit.using_default_value = False
    self.assertEqual('apple', fruit.value)
    fruit._set_default('orange')
    # Once the flag is marked as explicitly set, a new default is not applied.
    self.assertEqual('apple', fruit.value)

  def test_pickle(self):
    self.assertRaisesRegex(
        TypeError, "can't pickle Flag objects", pickle.dumps, self.flag)

  def test_copy(self):
    self.flag.value = 'orange'
    with self.assertRaisesRegex(TypeError,
                                'Flag does not support shallow copies'):
      copy.copy(self.flag)
    clone = copy.deepcopy(self.flag)
    self.assertEqual(clone.value, 'orange')
    # Changing the deep copy must not leak back into the original.
    clone.value = 'mango'
    self.assertEqual(clone.value, 'mango')
    self.assertEqual(self.flag.value, 'orange')
class BooleanFlagTest(parameterized.TestCase):
  """Tests for _flag.BooleanFlag."""

  @parameterized.parameters(
      ('', '(no help available)'),
      ('Is my test brilliant?', 'Is my test brilliant?'),
  )
  def test_help_text(self, helptext_input, helptext_output):
    # An empty help string is replaced by a placeholder.
    bool_flag = _flag.BooleanFlag('a_bool', False, helptext_input)
    self.assertEqual(helptext_output, bool_flag.help)
class EnumFlagTest(parameterized.TestCase):
  """Tests for the string-valued _flag.EnumFlag."""

  @parameterized.parameters(
      ('', ': (no help available)'),
      ('Type of fruit.', ': Type of fruit.'),
  )
  def test_help_text(self, helptext_input, helptext_output):
    allowed = ['apple', 'orange']
    enum_flag = _flag.EnumFlag('fruit', 'apple', helptext_input, allowed)
    self.assertEqual(helptext_output, enum_flag.help)

  def test_empty_values(self):
    # An EnumFlag must be given at least one allowed value.
    with self.assertRaises(ValueError):
      _flag.EnumFlag('fruit', None, 'help', [])
class Fruit(enum.Enum):
  """Two-member enum used as the enum_class in the EnumClassFlag tests."""
  APPLE = 1
  ORANGE = 2
class EmptyEnum(enum.Enum):
  """Enum with no members, used to check that empty enum classes fail."""
  pass
class EnumClassFlagTest(parameterized.TestCase):
  """Tests for _flag.EnumClassFlag, whose value is a member of an enum."""

  @parameterized.parameters(
      ('', ': (no help available)'),
      ('Type of fruit.', ': Type of fruit.'),
  )
  def test_help_text_case_insensitive(self, helptext_input, helptext_output):
    fruit_flag = _flag.EnumClassFlag('fruit', None, helptext_input, Fruit)
    self.assertEqual(helptext_output, fruit_flag.help)

  @parameterized.parameters(
      ('', ': (no help available)'),
      ('Type of fruit.', ': Type of fruit.'),
  )
  def test_help_text_case_sensitive(self, helptext_input, helptext_output):
    fruit_flag = _flag.EnumClassFlag(
        'fruit', None, helptext_input, Fruit, case_sensitive=True)
    self.assertEqual(helptext_output, fruit_flag.help)

  def test_requires_enum(self):
    # A list of names is not an enum class.
    with self.assertRaises(TypeError):
      _flag.EnumClassFlag(
          'fruit', None, 'help', ['apple', 'orange'])  # type: ignore

  def test_requires_non_empty_enum_class(self):
    with self.assertRaises(ValueError):
      _flag.EnumClassFlag('empty', None, 'help', EmptyEnum)

  def test_accepts_literal_default(self):
    fruit_flag = _flag.EnumClassFlag(
        'fruit', Fruit.APPLE, 'A sample enum flag.', Fruit)
    self.assertEqual(Fruit.APPLE, fruit_flag.value)

  def test_accepts_string_default(self):
    # A member name given as a string is converted to the member.
    fruit_flag = _flag.EnumClassFlag(
        'fruit', 'ORANGE', 'A sample enum flag.', Fruit)
    self.assertEqual(Fruit.ORANGE, fruit_flag.value)

  def test_case_sensitive_rejects_default_with_wrong_case(self):
    with self.assertRaises(_exceptions.IllegalFlagValueError):
      _flag.EnumClassFlag(
          'fruit', 'oranGe', 'A sample enum flag.', Fruit, case_sensitive=True)

  def test_case_insensitive_accepts_string_default(self):
    fruit_flag = _flag.EnumClassFlag(
        'fruit', 'oranGe', 'A sample enum flag.', Fruit, case_sensitive=False)
    self.assertEqual(Fruit.ORANGE, fruit_flag.value)

  def test_default_value_does_not_exist(self):
    # BANANA is not a member of Fruit.
    with self.assertRaises(_exceptions.IllegalFlagValueError):
      _flag.EnumClassFlag('fruit', 'BANANA', 'help', Fruit)
class MultiEnumClassFlagTest(parameterized.TestCase):
  """Tests for _flag.MultiEnumClassFlag, whose value is a list of members."""

  @parameterized.named_parameters(
      ('NoHelpSupplied', '', ': (no help available);\n ' +
       'repeat this option to specify a list of values', False),
      ('WithHelpSupplied', 'Type of fruit.',
       ': Type of fruit.;\n ' +
       'repeat this option to specify a list of values', True))
  def test_help_text(self, helptext_input, helptext_output, case_sensitive):
    f = _flag.MultiEnumClassFlag(
        'fruit', None, helptext_input, Fruit, case_sensitive=case_sensitive)
    self.assertEqual(helptext_output, f.help)

  def test_requires_enum(self):
    # A list of names is not an enum class.
    with self.assertRaises(TypeError):
      _flag.MultiEnumClassFlag('fruit', None, 'help', ['apple', 'orange'])  # type: ignore

  def test_requires_non_empty_enum_class(self):
    with self.assertRaises(ValueError):
      _flag.MultiEnumClassFlag('empty', None, 'help', EmptyEnum)

  def test_rejects_wrong_case_when_case_sensitive(self):
    # 'Orange' does not match the member name ORANGE exactly. Plain
    # assertRaises is used: the previous assertRaisesRegex(..., '') pattern
    # matched any message and only suggested a message check that never ran.
    with self.assertRaises(_exceptions.IllegalFlagValueError):
      _flag.MultiEnumClassFlag(
          'fruit', ['APPLE', 'Orange'],
          'A sample enum flag.',
          Fruit,
          case_sensitive=True)

  def test_accepts_case_insensitive(self):
    # Without case_sensitive, mixed casing of the same name is accepted.
    f = _flag.MultiEnumClassFlag('fruit', ['apple', 'APPLE'],
                                 'A sample enum flag.', Fruit)
    self.assertListEqual([Fruit.APPLE, Fruit.APPLE], f.value)

  def test_accepts_literal_default(self):
    # A single member default is wrapped into a one-element list.
    f = _flag.MultiEnumClassFlag('fruit', Fruit.APPLE, 'A sample enum flag.',
                                 Fruit)
    self.assertListEqual([Fruit.APPLE], f.value)

  def test_accepts_list_of_literal_default(self):
    f = _flag.MultiEnumClassFlag('fruit', [Fruit.APPLE, Fruit.ORANGE],
                                 'A sample enum flag.', Fruit)
    self.assertListEqual([Fruit.APPLE, Fruit.ORANGE], f.value)

  def test_accepts_string_default(self):
    f = _flag.MultiEnumClassFlag('fruit', 'ORANGE', 'A sample enum flag.',
                                 Fruit)
    self.assertListEqual([Fruit.ORANGE], f.value)

  def test_accepts_list_of_string_default(self):
    # Order of the default list is preserved.
    f = _flag.MultiEnumClassFlag('fruit', ['ORANGE', 'APPLE'],
                                 'A sample enum flag.', Fruit)
    self.assertListEqual([Fruit.ORANGE, Fruit.APPLE], f.value)

  def test_default_value_does_not_exist(self):
    # BANANA is not a member of Fruit.
    with self.assertRaises(_exceptions.IllegalFlagValueError):
      _flag.MultiEnumClassFlag('fruit', 'BANANA', 'help', Fruit)
# Run this module's tests when it is executed directly.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/flags/tests/_flagvalues_test.py 0000664 0000000 0000000 00000100640 14551576331 0022472 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for flags.FlagValues class."""
import collections
import copy
import pickle
import types
from unittest import mock
from absl import logging
from absl.flags import _defines
from absl.flags import _exceptions
from absl.flags import _flagvalues
from absl.flags import _helpers
from absl.flags import _validators
from absl.flags.tests import module_foo
from absl.testing import absltest
from absl.testing import parameterized
class FlagValuesTest(absltest.TestCase):
def test_bool_flags(self):
  """--flag, --flag=<bool> and --noflag parse; --noflag=<bool> is invalid."""

  def fresh_flag_values():
    # Each case gets its own FlagValues with a single boolean 'nothing'.
    flag_values = _flagvalues.FlagValues()
    _defines.DEFINE_boolean('nothing', None, '', flag_values=flag_values)
    return flag_values

  accepted = (
      ('--nothing', True),
      ('--nothing=true', True),
      ('--nothing=false', False),
      ('--nonothing', False),
  )
  for argument, want in accepted:
    flag_values = fresh_flag_values()
    flag_values(('./program', argument))
    self.assertIs(want, flag_values.nothing)

  # The negated spelling does not take an explicit value.
  for argument in ('--nonothing=true', '--nonothing=false'):
    flag_values = fresh_flag_values()
    with self.assertRaises(ValueError):
      flag_values(('./program', argument))
def test_boolean_flag_parser_gets_string_argument(self):
  """The boolean parser receives a string value, even for --flag/--noflag."""
  arg_to_parsed_text = {
      '--nothing': 'true',
      '--nothing=true': 'true',
      '--nothing=false': 'false',
      '--nonothing': 'false',
  }
  for arg, expected in arg_to_parsed_text.items():
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_boolean('nothing', None, '', flag_values=fv)
    nothing_parser = fv['nothing'].parser
    with mock.patch.object(nothing_parser, 'parse') as mock_parse:
      fv(('./program', arg))
      mock_parse.assert_called_once_with(expected)
def test_unregistered_flags_are_cleaned_up(self):
  """Redefining or deleting flags keeps the per-module registries in sync.

  Checks flags_by_module_dict, flags_by_module_id_dict and
  key_flags_by_module_dict after flags are overridden (allow_override=True)
  and after their long and short names are deleted.
  """
  fv = _flagvalues.FlagValues()
  # Every flag below is defined from this (the calling) module.
  module, module_name = _helpers.get_calling_module_object_and_name()
  # Define first flag.
  _defines.DEFINE_integer('cores', 4, '', flag_values=fv, short_name='c')
  old_cores_flag = fv['cores']
  fv.register_key_flag_for_module(module_name, old_cores_flag)
  self.assertEqual(fv.flags_by_module_dict(),
                   {module_name: [old_cores_flag]})
  self.assertEqual(fv.flags_by_module_id_dict(),
                   {id(module): [old_cores_flag]})
  self.assertEqual(fv.key_flags_by_module_dict(),
                   {module_name: [old_cores_flag]})
  # Redefine the same flag (same long and short name).
  _defines.DEFINE_integer(
      'cores', 4, '', flag_values=fv, short_name='c', allow_override=True)
  new_cores_flag = fv['cores']
  self.assertNotEqual(old_cores_flag, new_cores_flag)
  # The replaced flag object disappears from the module registries.
  self.assertEqual(fv.flags_by_module_dict(),
                   {module_name: [new_cores_flag]})
  self.assertEqual(fv.flags_by_module_id_dict(),
                   {id(module): [new_cores_flag]})
  # old_cores_flag is removed from key flags, and the new_cores_flag is
  # not automatically added because it must be registered explicitly.
  self.assertEqual(fv.key_flags_by_module_dict(), {module_name: []})
  # Define a new flag but with the same short_name.
  _defines.DEFINE_integer(
      'changelist',
      0,
      '',
      flag_values=fv,
      short_name='c',
      allow_override=True)
  old_changelist_flag = fv['changelist']
  fv.register_key_flag_for_module(module_name, old_changelist_flag)
  # The short named flag -c is overridden to be the old_changelist_flag.
  self.assertEqual(fv['c'], old_changelist_flag)
  self.assertNotEqual(fv['c'], new_cores_flag)
  # new_cores_flag is still listed: its long name 'cores' still refers to it.
  self.assertEqual(fv.flags_by_module_dict(),
                   {module_name: [new_cores_flag, old_changelist_flag]})
  self.assertEqual(fv.flags_by_module_id_dict(),
                   {id(module): [new_cores_flag, old_changelist_flag]})
  self.assertEqual(fv.key_flags_by_module_dict(),
                   {module_name: [old_changelist_flag]})
  # Define a flag only with the same long name.
  _defines.DEFINE_integer(
      'changelist',
      0,
      '',
      flag_values=fv,
      short_name='l',
      allow_override=True)
  new_changelist_flag = fv['changelist']
  self.assertNotEqual(old_changelist_flag, new_changelist_flag)
  # old_changelist_flag stays listed because it still owns the short name -c.
  self.assertEqual(fv.flags_by_module_dict(),
                   {module_name: [new_cores_flag,
                                  old_changelist_flag,
                                  new_changelist_flag]})
  self.assertEqual(fv.flags_by_module_id_dict(),
                   {id(module): [new_cores_flag,
                                 old_changelist_flag,
                                 new_changelist_flag]})
  self.assertEqual(fv.key_flags_by_module_dict(),
                   {module_name: [old_changelist_flag]})
  # Delete the new changelist's long name, it should still be registered
  # because of its short name.
  del fv.changelist
  self.assertNotIn('changelist', fv)
  self.assertEqual(fv.flags_by_module_dict(),
                   {module_name: [new_cores_flag,
                                  old_changelist_flag,
                                  new_changelist_flag]})
  self.assertEqual(fv.flags_by_module_id_dict(),
                   {id(module): [new_cores_flag,
                                 old_changelist_flag,
                                 new_changelist_flag]})
  self.assertEqual(fv.key_flags_by_module_dict(),
                   {module_name: [old_changelist_flag]})
  # Delete the new changelist's short name, it should be removed.
  del fv.l
  self.assertNotIn('l', fv)
  self.assertEqual(fv.flags_by_module_dict(),
                   {module_name: [new_cores_flag,
                                  old_changelist_flag]})
  self.assertEqual(fv.flags_by_module_id_dict(),
                   {id(module): [new_cores_flag,
                                 old_changelist_flag]})
  self.assertEqual(fv.key_flags_by_module_dict(),
                   {module_name: [old_changelist_flag]})
def _test_find_module_or_id_defining_flag(self, test_id):
  """Shared body of the find_module[_id]_defining_flag tests.

  Flags are defined both from this module and, via module_name=, on behalf
  of the _flagvalues module. After each (re)definition or deletion the test
  checks which module (or module id) each long and short name resolves to.

  Args:
    test_id: True to test find_module_id_defining_flag, False to test
      find_module_defining_flag.
  """
  fv = _flagvalues.FlagValues()
  current_module, current_module_name = (
      _helpers.get_calling_module_object_and_name())
  # _flagvalues serves as an arbitrary "other" module to define flags for.
  alt_module_name = _flagvalues.__name__
  # Choose the expected values and the lookup function for this mode.
  if test_id:
    current_module_or_id = id(current_module)
    alt_module_or_id = id(_flagvalues)
    testing_fn = fv.find_module_id_defining_flag
  else:
    current_module_or_id = current_module_name
    alt_module_or_id = alt_module_name
    testing_fn = fv.find_module_defining_flag
  # Define first flag.
  _defines.DEFINE_integer('cores', 4, '', flag_values=fv, short_name='c')
  module_or_id_cores = testing_fn('cores')
  self.assertEqual(module_or_id_cores, current_module_or_id)
  module_or_id_c = testing_fn('c')
  self.assertEqual(module_or_id_c, current_module_or_id)
  # Redefine the same flag in another module.
  _defines.DEFINE_integer(
      'cores',
      4,
      '',
      flag_values=fv,
      module_name=alt_module_name,
      short_name='c',
      allow_override=True)
  module_or_id_cores = testing_fn('cores')
  self.assertEqual(module_or_id_cores, alt_module_or_id)
  module_or_id_c = testing_fn('c')
  self.assertEqual(module_or_id_c, alt_module_or_id)
  # Define a new flag but with the same short_name.
  _defines.DEFINE_integer(
      'changelist',
      0,
      '',
      flag_values=fv,
      short_name='c',
      allow_override=True)
  # 'cores' keeps its defining module even though it lost the short name -c.
  module_or_id_cores = testing_fn('cores')
  self.assertEqual(module_or_id_cores, alt_module_or_id)
  module_or_id_changelist = testing_fn('changelist')
  self.assertEqual(module_or_id_changelist, current_module_or_id)
  module_or_id_c = testing_fn('c')
  self.assertEqual(module_or_id_c, current_module_or_id)
  # Define a flag in another module only with the same long name.
  _defines.DEFINE_integer(
      'changelist',
      0,
      '',
      flag_values=fv,
      module_name=alt_module_name,
      short_name='l',
      allow_override=True)
  module_or_id_cores = testing_fn('cores')
  self.assertEqual(module_or_id_cores, alt_module_or_id)
  module_or_id_changelist = testing_fn('changelist')
  self.assertEqual(module_or_id_changelist, alt_module_or_id)
  # -c still belongs to the earlier 'changelist' defined from this module.
  module_or_id_c = testing_fn('c')
  self.assertEqual(module_or_id_c, current_module_or_id)
  module_or_id_l = testing_fn('l')
  self.assertEqual(module_or_id_l, alt_module_or_id)
  # Delete the changelist long name: looking up 'changelist' now yields None,
  # while its short name -l (and the unrelated -c) still resolve.
  del fv.changelist
  module_or_id_changelist = testing_fn('changelist')
  self.assertIsNone(module_or_id_changelist)
  module_or_id_c = testing_fn('c')
  self.assertEqual(module_or_id_c, current_module_or_id)
  module_or_id_l = testing_fn('l')
  self.assertEqual(module_or_id_l, alt_module_or_id)
def test_find_module_defining_flag(self):
  """Runs the shared lookup test against module names."""
  self._test_find_module_or_id_defining_flag(False)
def test_find_module_id_defining_flag(self):
  """Runs the shared lookup test against module ids."""
  self._test_find_module_or_id_defining_flag(True)
def test_set_default(self):
  """set_default rejects unknown flags and updates a still-default value."""
  registry = _flagvalues.FlagValues()
  registry.mark_as_parsed()
  # Nothing named 'changelist' exists yet.
  with self.assertRaises(_exceptions.UnrecognizedFlagError):
    registry.set_default('changelist', 1)
  _defines.DEFINE_integer('changelist', 0, 'help', flag_values=registry)
  self.assertEqual(0, registry.changelist)
  registry.set_default('changelist', 2)
  self.assertEqual(2, registry.changelist)
def test_default_gnu_getopt_value(self):
  """A freshly constructed FlagValues has GNU-style getopt enabled."""
  fresh_values = _flagvalues.FlagValues()
  self.assertTrue(fresh_values.is_gnu_getopt())
def test_known_only_flags_in_gnustyle(self):
  """known_only=True removes only defined (or undefok) flags from argv."""
  full_cmd = '0 --f1=v1 cmd --f2 v2 --b1 --f3 v3 --nob2'

  def check(argv, defined_py_flags, expected_argv):
    fv = _flagvalues.FlagValues()
    fv.set_gnu_getopt(True)
    for name in defined_py_flags:
      # Names starting with 'b' become booleans, all others strings.
      if name.startswith('b'):
        _defines.DEFINE_boolean(name, False, 'help', flag_values=fv)
      else:
        _defines.DEFINE_string(name, 'default', 'help', flag_values=fv)
    remaining = fv(argv, known_only=True)
    self.assertEqual(expected_argv, remaining)

  # No flags defined: argv passes through untouched.
  check(argv=full_cmd.split(' '),
        defined_py_flags=[],
        expected_argv=full_cmd.split(' '))
  check(argv=full_cmd.split(' '),
        defined_py_flags=['f1'],
        expected_argv='0 cmd --f2 v2 --b1 --f3 v3 --nob2'.split(' '))
  check(argv=full_cmd.split(' '),
        defined_py_flags=['f2'],
        expected_argv='0 --f1=v1 cmd --b1 --f3 v3 --nob2'.split(' '))
  check(argv=full_cmd.split(' '),
        defined_py_flags=['b1'],
        expected_argv='0 --f1=v1 cmd --f2 v2 --f3 v3 --nob2'.split(' '))
  check(argv=full_cmd.split(' '),
        defined_py_flags=['f3'],
        expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --nob2'.split(' '))
  check(argv=full_cmd.split(' '),
        defined_py_flags=['b2'],
        expected_argv='0 --f1=v1 cmd --f2 v2 --b1 --f3 v3'.split(' '))
  check(argv=('0 --f1=v1 cmd --undefok=f1 --f2 v2 --b1 '
              '--f3 v3 --nob2').split(' '),
        defined_py_flags=['b2'],
        expected_argv='0 cmd --f2 v2 --b1 --f3 v3'.split(' '))
  # v2 survives here: undefok only swallows values given as --flag=value.
  check(argv=('0 --f1=v1 cmd --undefok f1,f2 --f2 v2 --b1 '
              '--f3 v3 --nob2').split(' '),
        defined_py_flags=['b2'],
        expected_argv='0 cmd v2 --b1 --f3 v3'.split(' '))
def test_invalid_flag_name(self):
  """Flag names must be non-empty strings that contain no spaces."""
  # Register into a private FlagValues rather than the global FLAGS, so a
  # regression (a bad name wrongly accepted) cannot leak a bogus flag into
  # global state shared with other tests.
  fv = _flagvalues.FlagValues()
  for bad_name in ('test ', ' test', 'te st', '', 1):
    with self.assertRaises(_exceptions.Error):
      _defines.DEFINE_boolean(bad_name, 0, '', flag_values=fv)  # type: ignore
def test_len(self):
  """len() counts registered names; truthiness follows emptiness."""
  registry = _flagvalues.FlagValues()
  self.assertEmpty(registry)
  self.assertFalse(registry)
  _defines.DEFINE_boolean('boolean', False, 'help', flag_values=registry)
  self.assertLen(registry, 1)
  self.assertTrue(registry)
  # With short_name, a single definition adds two entries ('bool' and 'b').
  _defines.DEFINE_boolean(
      'bool', False, 'help', short_name='b', flag_values=registry)
  self.assertLen(registry, 3)
  self.assertTrue(registry)
def test_pickle(self):
  """FlagValues refuses to be pickled."""
  registry = _flagvalues.FlagValues()
  self.assertRaisesRegex(
      TypeError, "can't pickle FlagValues", pickle.dumps, registry)
def test_copy(self):
  """Shallow copies are refused; deep copies are independent."""
  original = _flagvalues.FlagValues()
  _defines.DEFINE_integer('answer', 0, 'help', flag_values=original)
  original(['', '--answer=1'])
  with self.assertRaisesRegex(TypeError,
                              'FlagValues does not support shallow copies'):
    copy.copy(original)
  duplicate = copy.deepcopy(original)
  self.assertEqual(duplicate.answer, 1)
  duplicate.answer = 42
  self.assertEqual(duplicate.answer, 42)
  # The source object is unaffected by changes to its deep copy.
  self.assertEqual(original.answer, 1)
def test_conflicting_flags(self):
  """A flag named after a FlagValues method needs allow_using_method_names."""
  registry = _flagvalues.FlagValues()
  with self.assertRaises(_exceptions.FlagNameConflictsWithMethodError):
    _defines.DEFINE_boolean(
        'is_gnu_getopt', False, 'help', flag_values=registry)
  _defines.DEFINE_boolean(
      'is_gnu_getopt', False, 'help', flag_values=registry,
      allow_using_method_names=True)
  self.assertFalse(registry['is_gnu_getopt'].value)
  # Attribute access still resolves to the method, not the flag value.
  self.assertIsInstance(registry.is_gnu_getopt, types.MethodType)
def test_get_flags_for_module(self):
  """get_flags_for_module accepts a module name or a module object."""
  fv = _flagvalues.FlagValues()
  _defines.DEFINE_string('foo', None, 'help', flag_values=fv)
  module_foo.define_flags(fv)
  # NOTE(review): 'foo' is expected under '__main__', i.e. this assumes the
  # test module runs as the main script.
  main_names = {flag.name for flag in fv.get_flags_for_module('__main__')}
  self.assertEqual({'foo'}, main_names)
  foo_names = {flag.name for flag in fv.get_flags_for_module(module_foo)}
  self.assertEqual({'tmod_foo_bool', 'tmod_foo_int', 'tmod_foo_str'},
                   foo_names)
def test_get_help(self):
  """get_help groups flags by defining module and honors its options."""
  fv = _flagvalues.FlagValues()
  # With no flags defined, only the special --flagfile/--undefok are listed.
  self.assertMultiLineEqual('''\
--flagfile: Insert flag definitions from the given file into the command line.
(default: '')
--undefok: comma-separated list of flag names that it is okay to specify on the
command line even if the program does not define a flag with that name.
IMPORTANT: flags in this list that have arguments MUST use the --flag=value
format.
(default: '')''', fv.get_help())
  module_foo.define_flags(fv)
  # After defining module_foo's flags, output is grouped per module, with
  # the special flags listed under absl.flags.
  self.assertMultiLineEqual('''
absl.flags.tests.module_bar:
--tmod_bar_t: Sample int flag.
(default: '4')
(an integer)
--tmod_bar_u: Sample int flag.
(default: '5')
(an integer)
--tmod_bar_v: Sample int flag.
(default: '6')
(an integer)
--[no]tmod_bar_x: Boolean flag.
(default: 'true')
--tmod_bar_y: String flag.
(default: 'default')
--[no]tmod_bar_z: Another boolean flag from module bar.
(default: 'false')
absl.flags.tests.module_foo:
--[no]tmod_foo_bool: Boolean flag from module foo.
(default: 'true')
--tmod_foo_int: Sample int flag.
(default: '3')
(an integer)
--tmod_foo_str: String flag.
(default: 'default')
absl.flags:
--flagfile: Insert flag definitions from the given file into the command line.
(default: '')
--undefok: comma-separated list of flag names that it is okay to specify on
the command line even if the program does not define a flag with that name.
IMPORTANT: flags in this list that have arguments MUST use the --flag=value
format.
(default: '')''', fv.get_help())
  # prefix= is prepended to every line of the output.
  self.assertMultiLineEqual('''
xxxxabsl.flags.tests.module_bar:
xxxx --tmod_bar_t: Sample int flag.
xxxx (default: '4')
xxxx (an integer)
xxxx --tmod_bar_u: Sample int flag.
xxxx (default: '5')
xxxx (an integer)
xxxx --tmod_bar_v: Sample int flag.
xxxx (default: '6')
xxxx (an integer)
xxxx --[no]tmod_bar_x: Boolean flag.
xxxx (default: 'true')
xxxx --tmod_bar_y: String flag.
xxxx (default: 'default')
xxxx --[no]tmod_bar_z: Another boolean flag from module bar.
xxxx (default: 'false')
xxxxabsl.flags.tests.module_foo:
xxxx --[no]tmod_foo_bool: Boolean flag from module foo.
xxxx (default: 'true')
xxxx --tmod_foo_int: Sample int flag.
xxxx (default: '3')
xxxx (an integer)
xxxx --tmod_foo_str: String flag.
xxxx (default: 'default')
xxxxabsl.flags:
xxxx --flagfile: Insert flag definitions from the given file into the command
xxxx line.
xxxx (default: '')
xxxx --undefok: comma-separated list of flag names that it is okay to specify
xxxx on the command line even if the program does not define a flag with that
xxxx name. IMPORTANT: flags in this list that have arguments MUST use the
xxxx --flag=value format.
xxxx (default: '')''', fv.get_help(prefix='xxxx'))
  # include_special_flags=False drops the absl.flags section entirely.
  self.assertMultiLineEqual('''
absl.flags.tests.module_bar:
--tmod_bar_t: Sample int flag.
(default: '4')
(an integer)
--tmod_bar_u: Sample int flag.
(default: '5')
(an integer)
--tmod_bar_v: Sample int flag.
(default: '6')
(an integer)
--[no]tmod_bar_x: Boolean flag.
(default: 'true')
--tmod_bar_y: String flag.
(default: 'default')
--[no]tmod_bar_z: Another boolean flag from module bar.
(default: 'false')
absl.flags.tests.module_foo:
--[no]tmod_foo_bool: Boolean flag from module foo.
(default: 'true')
--tmod_foo_int: Sample int flag.
(default: '3')
(an integer)
--tmod_foo_str: String flag.
(default: 'default')''', fv.get_help(include_special_flags=False))
def test_str(self):
  """str(FlagValues) is the same text as get_help(), before and after."""
  fv = _flagvalues.FlagValues()
  rendered_empty = str(fv)
  self.assertEqual(rendered_empty, fv.get_help())
  module_foo.define_flags(fv)
  rendered_full = str(fv)
  self.assertEqual(rendered_full, fv.get_help())
def test_empty_argv(self):
  """Parsing an empty argv (no program name) raises ValueError."""
  registry = _flagvalues.FlagValues()
  empty_argv = []
  with self.assertRaises(ValueError):
    registry(empty_argv)
def test_invalid_argv(self):
  """A bare str/bytes argv (not a sequence of args) raises TypeError."""
  registry = _flagvalues.FlagValues()
  for bogus_argv in ('./program', b'./program'):
    with self.assertRaises(TypeError):
      registry(bogus_argv)  # type: ignore
def test_flags_dir(self):
  """dir(FlagValues) lists the defined flag names in sorted order."""
  flag_values = _flagvalues.FlagValues()
  description = 'Description'
  bool_name, string_name, float_name = 'bool_flag', 'string_flag', 'float_flag'
  _defines.DEFINE_boolean(
      bool_name, None, description, flag_values=flag_values)
  _defines.DEFINE_string(
      string_name, None, description, flag_values=flag_values)
  expected = sorted([bool_name, string_name])
  self.assertEqual(expected, dir(flag_values))
  _defines.DEFINE_float(
      float_name, None, description, flag_values=flag_values)
  expected = sorted([bool_name, string_name, float_name])
  self.assertEqual(expected, dir(flag_values))
def test_flags_into_string_deterministic(self):
  """flags_into_string output does not depend on internal ordering."""
  flag_values = _flagvalues.FlagValues()
  _defines.DEFINE_string(
      'fa', 'x', '', flag_values=flag_values, module_name='mb')
  _defines.DEFINE_string(
      'fb', 'x', '', flag_values=flag_values, module_name='mb')
  _defines.DEFINE_string(
      'fc', 'x', '', flag_values=flag_values, module_name='ma')
  _defines.DEFINE_string(
      'fd', 'x', '', flag_values=flag_values, module_name='ma')
  # Expected: module 'ma' before 'mb', and flags sorted by name within each.
  expected = ('--fc=x\n'
              '--fd=x\n'
              '--fa=x\n'
              '--fb=x\n')
  # Reverse both the module order and each module's flag list, then install
  # the result directly under the private '__flags_by_module' key in
  # __dict__ (presumably the storage behind flags_by_module_dict()), so the
  # output is only correct if flags_into_string sorts on its own.
  flags_by_module_items = sorted(
      flag_values.flags_by_module_dict().items(), reverse=True)
  for _, module_flags in flags_by_module_items:
    module_flags.sort(reverse=True)
  flag_values.__dict__['__flags_by_module'] = collections.OrderedDict(
      flags_by_module_items)
  actual = flag_values.flags_into_string()
  self.assertEqual(expected, actual)
def test_validate_all_flags(self):
  """validate_all_flags enforces mark_flag_as_required."""
  registry = _flagvalues.FlagValues()
  _defines.DEFINE_string('name', None, '', flag_values=registry)
  _validators.mark_flag_as_required('name', flag_values=registry)
  # A required flag left at None fails validation.
  with self.assertRaises(_exceptions.IllegalFlagValueError):
    registry.validate_all_flags()
  registry.name = 'test'
  # Once it has a value, validation passes (no exception).
  registry.validate_all_flags()
class FlagValuesLoggingTest(absltest.TestCase):
  """Test to make sure logging.* functions won't recurse.

  Logging may and does happen before flags initialization. We need to make
  sure that any warnings thrown by flagvalues do not result in unlimited
  recursion.
  """

  def test_logging_do_not_recurse(self):
    # The test passes as long as these calls return; unbounded recursion
    # would surface as a RecursionError.
    logging.info('test info')
    try:
      raise ValueError('test exception')
    except ValueError:
      # Logged while an exception is being handled.
      logging.exception('test message')
class FlagSubstrMatchingTests(parameterized.TestCase):
  """Checks that abbreviations of flag names are not accepted.

  The FlagValues under test defines --strf and --boolf. Prefixes such as
  --st or --boo must be reported as unrecognized flags.
  """

  def _get_test_flag_values(self):
    """Returns a new FlagValues defining string 'strf' and boolean 'boolf'."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_string('strf', '', '', flag_values=fv)
    _defines.DEFINE_boolean('boolf', 0, '', flag_values=fv)
    return fv

  # argv tuples that must always make parsing raise UnrecognizedFlagError.
  FAIL_TEST_CASES = [
      ('./program', '--boo', '0'),
      ('./program', '--boo=true', '0'),
      ('./program', '--boo=0'),
      ('./program', '--noboo'),
      ('./program', '--st=blah'),
      ('./program', '--st=de'),
      ('./program', '--st=blah', '--boo'),
      ('./program', '--st=blah', 'unused'),
      ('./program', '--st=--blah'),
      ('./program', '--st', '--blah'),
  ]

  @parameterized.parameters(FAIL_TEST_CASES)
  def test_raise(self, *argv):
    """Abbreviated names are rejected with default getopt handling."""
    fv = self._get_test_flag_values()
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      fv(argv)

  @parameterized.parameters(
      [*FAIL_TEST_CASES, ('./program', 'unused', '--st=blah')])
  def test_gnu_getopt_raise(self, *argv):
    """Abbreviated names are also rejected with GNU-style getopt."""
    fv = self._get_test_flag_values()
    fv.set_gnu_getopt()
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      fv(argv)
class SettingUnknownFlagTest(absltest.TestCase):
  """Tests assigning to flags that were never defined on a FlagValues."""

  def setUp(self):
    # Zero-argument super(), consistent with FlagTest.setUp in this file.
    super().setUp()
    # Number of times set_undef has been invoked.
    self.setter_called = 0

  def set_undef(self, unused_name, unused_val):
    """Unknown-flag setter that only counts its invocations."""
    self.setter_called += 1

  def test_raise_on_undefined(self):
    new_flags = _flagvalues.FlagValues()
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      new_flags.undefined_flag = 0

  def test_not_raise(self):
    # With a registered setter, the assignment goes to the setter instead.
    new_flags = _flagvalues.FlagValues()
    new_flags._register_unknown_flag_setter(self.set_undef)
    new_flags.undefined_flag = 0
    self.assertEqual(self.setter_called, 1)

  def test_not_raise_on_undefined_if_undefok(self):
    new_flags = _flagvalues.FlagValues()
    args = ['0', '--foo', '--bar=1', '--undefok=foo,bar']
    unparsed = new_flags(args, known_only=True)
    self.assertEqual(['0'], unparsed)

  def test_re_raise_undefined(self):
    # A NameError raised by the setter surfaces as UnrecognizedFlagError.
    def setter(unused_name, unused_val):
      raise NameError()
    new_flags = _flagvalues.FlagValues()
    new_flags._register_unknown_flag_setter(setter)
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      new_flags.undefined_flag = 0

  def test_re_raise_invalid(self):
    # A ValueError raised by the setter surfaces as IllegalFlagValueError.
    def setter(unused_name, unused_val):
      raise ValueError()
    new_flags = _flagvalues.FlagValues()
    new_flags._register_unknown_flag_setter(setter)
    with self.assertRaises(_exceptions.IllegalFlagValueError):
      new_flags.undefined_flag = 0
class SetAttributesTest(absltest.TestCase):
  """Tests FlagValues._set_attributes, which assigns several flags at once."""

  def setUp(self):
    # Zero-argument super(), consistent with FlagTest.setUp in this file.
    super().setUp()
    self.new_flags = _flagvalues.FlagValues()
    _defines.DEFINE_boolean(
        'defined_flag', None, '', flag_values=self.new_flags)
    _defines.DEFINE_boolean(
        'another_defined_flag', None, '', flag_values=self.new_flags)
    # Number of times set_undef has been invoked.
    self.setter_called = 0

  def set_undef(self, unused_name, unused_val):
    """Unknown-flag setter that only counts its invocations."""
    self.setter_called += 1

  def test_two_defined_flags(self):
    self.new_flags._set_attributes(
        defined_flag=False, another_defined_flag=False)
    self.assertEqual(self.setter_called, 0)

  def test_one_defined_one_undefined_flag(self):
    # Without an unknown-flag setter, one undefined name fails the call.
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)

  def test_register_unknown_flag_setter(self):
    # With a setter, only the undefined name is routed to it.
    self.new_flags._register_unknown_flag_setter(self.set_undef)
    self.new_flags._set_attributes(defined_flag=False, undefined_flag=0)
    self.assertEqual(self.setter_called, 1)
class FlagsDashSyntaxTest(absltest.TestCase):
  """Long and short flag names accept one or two dashes, but not three."""

  def setUp(self):
    # Zero-argument super(), consistent with FlagTest.setUp in this file.
    super().setUp()
    self.fv = _flagvalues.FlagValues()
    _defines.DEFINE_string(
        'long_name', 'default', 'help', flag_values=self.fv, short_name='s')

  def test_long_name_one_dash(self):
    self.fv(['./program', '-long_name=new'])
    self.assertEqual('new', self.fv.long_name)

  def test_long_name_two_dashes(self):
    self.fv(['./program', '--long_name=new'])
    self.assertEqual('new', self.fv.long_name)

  def test_long_name_three_dashes(self):
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      self.fv(['./program', '---long_name=new'])

  def test_short_name_one_dash(self):
    self.fv(['./program', '-s=new'])
    self.assertEqual('new', self.fv.s)

  def test_short_name_two_dashes(self):
    self.fv(['./program', '--s=new'])
    self.assertEqual('new', self.fv.s)

  def test_short_name_three_dashes(self):
    with self.assertRaises(_exceptions.UnrecognizedFlagError):
      self.fv(['./program', '---s=new'])
class UnparseFlagsTest(absltest.TestCase):
def test_using_default_value_none(self):
  """using_default_value flips on parse and back on unparse (None default)."""
  fv = _flagvalues.FlagValues()
  _defines.DEFINE_string('default_none', None, 'help', flag_values=fv)
  self.assertTrue(fv['default_none'].using_default_value)
  for explicit in ('notNone', 'alsoNotNone'):
    fv(['', '--default_none=' + explicit])
    self.assertFalse(fv['default_none'].using_default_value)
    fv.unparse_flags()
    self.assertTrue(fv['default_none'].using_default_value)
def test_using_default_value_not_none(self):
  """Any explicit value, even one equal to the default, is non-default."""
  fv = _flagvalues.FlagValues()
  _defines.DEFINE_string('default_foo', 'foo', 'help', flag_values=fv)
  fv.mark_as_parsed()
  self.assertTrue(fv['default_foo'].using_default_value)
  # 'foo' equals the default but was passed explicitly.
  for explicit in ('foo', 'notFoo'):
    fv(['', '--default_foo=' + explicit])
    self.assertFalse(fv['default_foo'].using_default_value)
  fv.unparse_flags()
  self.assertTrue(fv['default_foo'].using_default_value)
  fv(['', '--default_foo=alsoNotFoo'])
  self.assertFalse(fv['default_foo'].using_default_value)
def test_allow_overwrite_false(self):
fv = _flagvalues.FlagValues()
_defines.DEFINE_string(
'default_none', None, 'help', allow_overwrite=False, flag_values=fv)
_defines.DEFINE_string(
'default_foo', 'foo', 'help', allow_overwrite=False, flag_values=fv)
fv.mark_as_parsed()
self.assertEqual('foo', fv.default_foo)
self.assertIsNone(fv.default_none)
fv(['', '--default_foo=notFoo', '--default_none=notNone'])
self.assertEqual('notFoo', fv.default_foo)
self.assertEqual('notNone', fv.default_none)
fv.unparse_flags()
self.assertEqual('foo', fv['default_foo'].value)
self.assertIsNone(fv['default_none'].value)
fv(['', '--default_foo=alsoNotFoo', '--default_none=alsoNotNone'])
self.assertEqual('alsoNotFoo', fv.default_foo)
self.assertEqual('alsoNotNone', fv.default_none)
def test_multi_string_default_none(self):
fv = _flagvalues.FlagValues()
_defines.DEFINE_multi_string('foo', None, 'help', flag_values=fv)
fv.mark_as_parsed()
self.assertIsNone(fv.foo)
fv(['', '--foo=aa'])
self.assertEqual(['aa'], fv.foo)
fv.unparse_flags()
self.assertIsNone(fv['foo'].value)
fv(['', '--foo=bb', '--foo=cc'])
self.assertEqual(['bb', 'cc'], fv.foo)
fv.unparse_flags()
self.assertIsNone(fv['foo'].value)
def test_multi_string_default_string(self):
fv = _flagvalues.FlagValues()
_defines.DEFINE_multi_string('foo', 'xyz', 'help', flag_values=fv)
expected_default = ['xyz']
fv.mark_as_parsed()
self.assertEqual(expected_default, fv.foo)
fv(['', '--foo=aa'])
self.assertEqual(['aa'], fv.foo)
fv.unparse_flags()
self.assertEqual(expected_default, fv['foo'].value)
fv(['', '--foo=bb', '--foo=cc'])
self.assertEqual(['bb', 'cc'], fv['foo'].value)
fv.unparse_flags()
self.assertEqual(expected_default, fv['foo'].value)
def test_multi_string_default_list(self):
fv = _flagvalues.FlagValues()
_defines.DEFINE_multi_string(
'foo', ['xx', 'yy', 'zz'], 'help', flag_values=fv)
expected_default = ['xx', 'yy', 'zz']
fv.mark_as_parsed()
self.assertEqual(expected_default, fv.foo)
fv(['', '--foo=aa'])
self.assertEqual(['aa'], fv.foo)
fv.unparse_flags()
self.assertEqual(expected_default, fv['foo'].value)
fv(['', '--foo=bb', '--foo=cc'])
self.assertEqual(['bb', 'cc'], fv.foo)
fv.unparse_flags()
self.assertEqual(expected_default, fv['foo'].value)
class UnparsedFlagAccessTest(absltest.TestCase):
  """Reading a flag value while unparsed raises UnparsedFlagAccessError."""

  def _new_flag_values(self, name, default):
    """Returns a fresh FlagValues holding one string flag `name`."""
    fv = _flagvalues.FlagValues()
    _defines.DEFINE_string(name, default, 'help', flag_values=fv)
    return fv

  def test_unparsed_flag_access(self):
    fv = self._new_flag_values('name', 'default')
    with self.assertRaises(_exceptions.UnparsedFlagAccessError):
      _ = fv.name

  def test_hasattr_raises_in_py3(self):
    fv = self._new_flag_values('name', 'default')
    # Python 3's hasattr() only swallows AttributeError, so this error
    # propagates out of the hasattr() call.
    with self.assertRaises(_exceptions.UnparsedFlagAccessError):
      _ = hasattr(fv, 'name')

  def test_unparsed_flags_access_raises_after_unparse_flags(self):
    fv = self._new_flag_values('a_str', 'default_value')
    fv.mark_as_parsed()
    self.assertEqual(fv.a_str, 'default_value')
    fv.unparse_flags()
    with self.assertRaises(_exceptions.UnparsedFlagAccessError):
      _ = fv.a_str
class FlagHolderTest(absltest.TestCase):
  """Tests the FlagHolder object returned by the _defines.DEFINE_* calls."""

  def setUp(self):
    super().setUp()
    self.fv = _flagvalues.FlagValues()
    self.name_flag = _defines.DEFINE_string(
        'name', 'default', 'help', flag_values=self.fv)

  def parse_flags(self, *argv):
    """Unparses self.fv, then parses `argv` after a fake program name."""
    self.fv.unparse_flags()
    self.fv(['binary_name'] + list(argv))

  def test_name(self):
    self.assertEqual('name', self.name_flag.name)

  def test_value_before_flag_parsing(self):
    with self.assertRaises(_exceptions.UnparsedFlagAccessError):
      _ = self.name_flag.value

  def test_value_returns_default_value_if_not_explicitly_set(self):
    self.parse_flags()
    self.assertEqual('default', self.name_flag.value)

  def test_value_returns_explicitly_set_value(self):
    self.parse_flags('--name=new_value')
    self.assertEqual('new_value', self.name_flag.value)

  def test_present_returns_false_before_flag_parsing(self):
    # Unlike .value, .present may be read before parsing without raising.
    self.assertFalse(self.name_flag.present)

  def test_present_returns_false_if_not_explicitly_set(self):
    self.parse_flags()
    self.assertFalse(self.name_flag.present)

  def test_present_returns_true_if_explicitly_set(self):
    self.parse_flags('--name=new_value')
    self.assertTrue(self.name_flag.present)

  def test_serializes_flag(self):
    self.parse_flags('--name=new_value')
    self.assertEqual('--name=new_value', self.name_flag.serialize())

  def test_allow_override(self):
    # Both definitions opt into allow_override, so redefining --int_flag is
    # permitted; both holders must then report the value parsed from argv.
    # allow_override is a boolean option: pass True rather than the int 1.
    first = _defines.DEFINE_integer(
        'int_flag', 1, 'help', flag_values=self.fv, allow_override=True)
    second = _defines.DEFINE_integer(
        'int_flag', 2, 'help', flag_values=self.fv, allow_override=True)
    self.parse_flags('--int_flag=3')
    self.assertEqual(3, first.value)
    self.assertEqual(3, second.value)
    self.assertTrue(first.present)
    self.assertTrue(second.present)

  # The next three tests check that a FlagHolder refuses ==, reflected == and
  # bool(); presumably this catches `FLAG == x` written where
  # `FLAG.value == x` was meant.
  def test_eq(self):
    with self.assertRaises(TypeError):
      self.name_flag == 'value'  # pylint: disable=pointless-statement

  def test_eq_reflection(self):
    with self.assertRaises(TypeError):
      'value' == self.name_flag  # pylint: disable=pointless-statement

  def test_bool(self):
    with self.assertRaises(TypeError):
      bool(self.name_flag)
if __name__ == '__main__':
  # Executed as a script: hand control to absltest's runner.
  absltest.main()
abseil-py-2.1.0/absl/flags/tests/_helpers_test.py 0000664 0000000 0000000 00000013732 14551576331 0022010 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unittests for helpers module."""
import sys
from absl.flags import _helpers
from absl.flags.tests import module_bar
from absl.flags.tests import module_foo
from absl.testing import absltest
class FlagSuggestionTest(absltest.TestCase):
  """Tests the edit-distance helper and "did you mean" flag suggestions."""

  def setUp(self):
    # Candidate option names handed to get_flag_suggestions() by each test.
    self.longopts = """
        fsplit-ivs-in-unroller=
        fsplit-wide-types=
        fstack-protector=
        fstack-protector-all=
        fstrict-aliasing=
        fstrict-overflow=
        fthread-jumps=
        ftracer
        ftree-bit-ccp
        ftree-builtin-call-dce
        ftree-ccp
        ftree-ch
    """.split()

  def _suggest(self, attempt):
    """Returns suggestions for `attempt` against self.longopts."""
    return _helpers.get_flag_suggestions(attempt, self.longopts)

  def test_damerau_levenshtein_id(self):
    distance = _helpers._damerau_levenshtein('asdf', 'asdf')
    self.assertEqual(0, distance)

  def test_damerau_levenshtein_empty(self):
    # Distance to or from '' equals the other string's length.
    for left, right, expected in (('', 'kites', 5), ('kitten', '', 6)):
      self.assertEqual(expected, _helpers._damerau_levenshtein(left, right))

  def test_damerau_levenshtein_commutative(self):
    for left, right in (('kitten', 'kites'), ('kites', 'kitten')):
      self.assertEqual(2, _helpers._damerau_levenshtein(left, right))

  def test_damerau_levenshtein_transposition(self):
    # Swapping two adjacent characters costs a single edit.
    self.assertEqual(1, _helpers._damerau_levenshtein('kitten', 'ktiten'))

  def test_mispelled_suggestions(self):
    self.assertEqual(['fstack-protector-all'],
                     self._suggest('fstack_protector_all'))

  def test_ambiguous_prefix_suggestion(self):
    self.assertEqual(['fstack-protector', 'fstack-protector-all'],
                     self._suggest('fstack'))

  def test_misspelled_ambiguous_prefix_suggestion(self):
    self.assertEqual(['fstack-protector', 'fstack-protector-all'],
                     self._suggest('stack'))

  def test_crazy_suggestion(self):
    self.assertEqual([], self._suggest('asdfasdgasdfa'))

  def test_suggestions_are_sorted(self):
    expected = sorted(['aab', 'aac', 'aad'])
    # Feed the candidates in reverse order; the output must still be sorted.
    candidates = expected[::-1]
    self.assertEqual(expected, _helpers.get_flag_suggestions('aaa', candidates))
class GetCallingModuleTest(absltest.TestCase):
  """Test whether we correctly determine the module which defines the flag."""

  def test_get_calling_module(self):
    """get_calling_module() names the calling module, including under exec."""
    # Called from this file, the result is sys.argv[0]; the helper functions
    # in module_foo / module_bar are expected to report their dotted names.
    self.assertEqual(_helpers.get_calling_module(), sys.argv[0])
    self.assertEqual(module_foo.get_module_name(),
                     'absl.flags.tests.module_foo')
    self.assertEqual(module_bar.get_module_name(),
                     'absl.flags.tests.module_bar')
    # We execute the following exec statements for their side effect
    # (i.e., not raising an error). They emphasize the case that not
    # all code resides in one of the imported modules: Python is a
    # really dynamic language, where we can dynamically construct some
    # code and execute it.
    code = ('from absl.flags import _helpers\n'
            'module_name = _helpers.get_calling_module()')
    exec(code)  # pylint: disable=exec-used
    # The next two exec statements execute code with a global environment
    # that is different from the global environment of any imported
    # module.
    exec(code, {})  # pylint: disable=exec-used
    # vars(self) returns a dictionary corresponding to the symbol
    # table of the self object. dict(...) makes a distinct copy of
    # this dictionary, such that any new symbol definition by the
    # exec-ed code (the `_helpers` import and `module_name = ...`) does
    # not affect the symbol table of self.
    exec(code, dict(vars(self)))  # pylint: disable=exec-used
    # The next check is more involved: it verifies not only that
    # get_calling_module does not crash inside exec-ed code, but also
    # that it returns the expected value: code run via exec is attributed
    # to the module that called exec. We check it twice: first by
    # calling exec from the main module, then by calling it from
    # module_bar.
    global_dict = {}
    exec(code, global_dict)  # pylint: disable=exec-used
    self.assertEqual(global_dict['module_name'],
                     sys.argv[0])
    global_dict = {}
    module_bar.execute_code(code, global_dict)
    self.assertEqual(global_dict['module_name'],
                     'absl.flags.tests.module_bar')

  def test_get_calling_module_with_iteritems_error(self):
    """get_calling_module() must not rely on sys.modules.iteritems()."""
    # This test checks that get_calling_module is using
    # sys.modules.items(), instead of .iteritems().
    orig_sys_modules = sys.modules
    # Mock sys.modules: simulates the error produced by importing a module
    # in parallel with our iteration over sys.modules.iteritems().
    class SysModulesMock(dict):

      def __init__(self, original_content):
        dict.__init__(self, original_content)

      def iteritems(self):
        # Any dictionary method is fine, but not .iteritems().
        raise RuntimeError('dictionary changed size during iteration')

    sys.modules = SysModulesMock(orig_sys_modules)
    try:
      # get_calling_module should still work as expected:
      self.assertEqual(_helpers.get_calling_module(), sys.argv[0])
      self.assertEqual(module_foo.get_module_name(),
                       'absl.flags.tests.module_foo')
    finally:
      # Always restore the real module table, even if an assertion fails.
      sys.modules = orig_sys_modules
if __name__ == '__main__':
  # Executed as a script: hand control to absltest's runner.
  absltest.main()
abseil-py-2.1.0/absl/flags/tests/_validators_test.py 0000664 0000000 0000000 00000110774 14551576331 0022522 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing that flags validators framework does work.
This file tests that each flag validator is called when it should be, and that
a failing validator raises an exception, etc.
"""
import warnings
from absl.flags import _defines
from absl.flags import _exceptions
from absl.flags import _flagvalues
from absl.flags import _validators
from absl.testing import absltest
class SingleFlagValidatorTest(absltest.TestCase):
  """Tests _validators.register_validator() and the validator decorator."""

  def setUp(self):
    super().setUp()
    self.flag_values = _flagvalues.FlagValues()
    # Every checker below appends the value it was asked to validate here.
    self.call_args = []

  def _accept(self, value):
    """Checker that records `value` and always accepts it."""
    self.call_args.append(value)
    return True

  def _define_checked_flag(self, checker):
    """Defines integer --test_flag (default None) guarded by `checker`."""
    _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        'test_flag',
        checker,
        message='Errors happen',
        flag_values=self.flag_values)

  def _parse_empty_then_assign_two(self):
    """Parses no args, then sets --test_flag=2; checks the checker saw both."""
    self.flag_values(('./program',))
    self.assertIsNone(self.flag_values.test_flag)
    self.flag_values.test_flag = 2
    self.assertEqual(2, self.flag_values.test_flag)
    self.assertEqual([None, 2], self.call_args)

  def test_success(self):
    self._define_checked_flag(self._accept)
    self._parse_empty_then_assign_two()

  def test_success_holder(self):
    holder = _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    _validators.register_validator(
        holder,
        self._accept,
        message='Errors happen',
        flag_values=self.flag_values)
    self._parse_empty_then_assign_two()

  def test_success_holder_infer_flagvalues(self):
    holder = _defines.DEFINE_integer(
        'test_flag', None, 'Usual integer flag', flag_values=self.flag_values)
    # flag_values is omitted here and must be inferred from the holder.
    _validators.register_validator(
        holder, self._accept, message='Errors happen')
    self._parse_empty_then_assign_two()

  def test_default_value_not_used_success(self):
    self._define_checked_flag(self._accept)
    self.flag_values(('./program', '--test_flag=1'))
    self.assertEqual(1, self.flag_values.test_flag)
    # Only the explicitly passed value is validated, not the None default.
    self.assertEqual([1], self.call_args)

  def test_validator_not_called_when_other_flag_is_changed(self):
    fv = self.flag_values
    _defines.DEFINE_integer(
        'test_flag', 1, 'Usual integer flag', flag_values=fv)
    _defines.DEFINE_integer(
        'other_flag', 2, 'Other integer flag', flag_values=fv)
    _validators.register_validator(
        'test_flag', self._accept, message='Errors happen', flag_values=fv)
    fv(('./program',))
    self.assertEqual(1, fv.test_flag)
    fv.other_flag = 3
    self.assertEqual([1], self.call_args)

  def test_exception_raised_if_checker_fails(self):
    def accepts_only_one(value):
      self.call_args.append(value)
      return value == 1

    self._define_checked_flag(accepts_only_one)
    self.flag_values(('./program', '--test_flag=1'))
    with self.assertRaises(_exceptions.IllegalFlagValueError) as ctx:
      self.flag_values.test_flag = 2
    self.assertEqual('flag --test_flag=2: Errors happen', str(ctx.exception))
    self.assertEqual([1, 2], self.call_args)

  def test_exception_raised_if_checker_raises_exception(self):
    def accepts_only_one(value):
      self.call_args.append(value)
      if value != 1:
        raise _exceptions.ValidationError('Specific message')
      return True

    self._define_checked_flag(accepts_only_one)
    self.flag_values(('./program', '--test_flag=1'))
    with self.assertRaises(_exceptions.IllegalFlagValueError) as ctx:
      self.flag_values.test_flag = 2
    # The ValidationError text replaces the registered message.
    self.assertEqual(
        'flag --test_flag=2: Specific message', str(ctx.exception))
    self.assertEqual([1, 2], self.call_args)

  def test_error_message_when_checker_returns_false_on_start(self):
    def reject(value):
      self.call_args.append(value)
      return False

    self._define_checked_flag(reject)
    with self.assertRaises(_exceptions.IllegalFlagValueError) as ctx:
      self.flag_values(('./program', '--test_flag=1'))
    self.assertEqual('flag --test_flag=1: Errors happen', str(ctx.exception))
    self.assertEqual([1], self.call_args)

  def test_error_message_when_checker_raises_exception_on_start(self):
    def reject(value):
      self.call_args.append(value)
      raise _exceptions.ValidationError('Specific message')

    self._define_checked_flag(reject)
    with self.assertRaises(_exceptions.IllegalFlagValueError) as ctx:
      self.flag_values(('./program', '--test_flag=1'))
    self.assertEqual(
        'flag --test_flag=1: Specific message', str(ctx.exception))
    self.assertEqual([1], self.call_args)

  def test_validators_checked_in_order(self):
    def required(x):
      self.calls.append('required')
      return x is not None

    def even(x):
      self.calls.append('even')
      return x % 2 == 0

    # Validators must run in registration order, whichever order that is.
    for order, expected in (((required, even), ['required', 'even']),
                            ((even, required), ['even', 'required'])):
      self.calls = []
      self._define_flag_and_validators(*order)
      self.assertEqual(expected, self.calls)

  def _define_flag_and_validators(self, first_validator, second_validator):
    """Registers two validators on a fresh flag and triggers them by parsing."""
    local_flags = _flagvalues.FlagValues()
    _defines.DEFINE_integer(
        'test_flag', 2, 'test flag', flag_values=local_flags)
    for validator_fn in (first_validator, second_validator):
      _validators.register_validator(
          'test_flag', validator_fn, message='', flag_values=local_flags)
    local_flags(('./program',))

  def test_validator_as_decorator(self):
    _defines.DEFINE_integer(
        'test_flag', None, 'Simple integer flag', flag_values=self.flag_values)

    @_validators.validator('test_flag', flag_values=self.flag_values)
    def checker(x):
      self.call_args.append(x)
      return True

    self._parse_empty_then_assign_two()
    # The decorator must hand back the original, still-callable function.
    self.assertTrue(checker(3))
    self.assertEqual([None, 2, 3], self.call_args)

  def test_mismatching_flagvalues(self):
    foreign_holder = _defines.DEFINE_integer(
        'test_flag',
        None,
        'Usual integer flag',
        flag_values=_flagvalues.FlagValues())
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        'flag_values must not be customized when operating on a FlagHolder'):
      _validators.register_validator(
          foreign_holder,
          self._accept,
          message='Errors happen',
          flag_values=self.flag_values)
class MultiFlagsValidatorTest(absltest.TestCase):
  """Tests validators that inspect several flags at once."""

  def setUp(self):
    super().setUp()
    self.flag_values = _flagvalues.FlagValues()
    # Each checker appends the {flag_name: value} dict it receives.
    self.call_args = []
    self.foo_holder = _defines.DEFINE_integer(
        'foo', 1, 'Usual integer flag', flag_values=self.flag_values)
    self.bar_holder = _defines.DEFINE_integer(
        'bar', 2, 'Usual integer flag', flag_values=self.flag_values)

  def _accept(self, flags_dict):
    """Records `flags_dict` and accepts it."""
    self.call_args.append(flags_dict)
    return True

  def _all_distinct(self, flags_dict):
    """Records `flags_dict`; accepts only if no two flags share a value."""
    self.call_args.append(flags_dict)
    values = flags_dict.values()
    return len(set(values)) == len(values)

  def _check_accepting_flow(self):
    """Parses --bar=2, then sets foo=3, checking every validation call."""
    self.flag_values(('./program', '--bar=2'))
    self.assertEqual(1, self.flag_values.foo)
    self.assertEqual(2, self.flag_values.bar)
    self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)
    self.flag_values.foo = 3
    self.assertEqual(3, self.flag_values.foo)
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 2}],
                     self.call_args)

  def _check_rejected_collision(self, expected_message):
    """Parses defaults, then makes bar equal to foo and expects rejection."""
    self.flag_values(('./program',))
    with self.assertRaises(_exceptions.IllegalFlagValueError) as ctx:
      self.flag_values.bar = 1
    self.assertEqual(expected_message, str(ctx.exception))
    self.assertEqual([{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}],
                     self.call_args)

  def test_success(self):
    _validators.register_multi_flags_validator(
        ['foo', 'bar'], self._accept, flag_values=self.flag_values)
    self._check_accepting_flow()

  def test_success_holder(self):
    _validators.register_multi_flags_validator(
        [self.foo_holder, self.bar_holder],
        self._accept,
        flag_values=self.flag_values)
    self._check_accepting_flow()

  def test_success_holder_infer_flagvalues(self):
    # flag_values is omitted and must be inferred from the holders.
    _validators.register_multi_flags_validator(
        [self.foo_holder, self.bar_holder], self._accept)
    self._check_accepting_flow()

  def test_validator_not_called_when_other_flag_is_changed(self):
    _defines.DEFINE_integer(
        'other_flag', 3, 'Other integer flag', flag_values=self.flag_values)
    _validators.register_multi_flags_validator(
        ['foo', 'bar'], self._accept, flag_values=self.flag_values)
    self.flag_values(('./program',))
    self.flag_values.other_flag = 3
    self.assertEqual([{'foo': 1, 'bar': 2}], self.call_args)

  def test_exception_raised_if_checker_fails(self):
    _validators.register_multi_flags_validator(
        ['foo', 'bar'],
        self._all_distinct,
        message='Errors happen',
        flag_values=self.flag_values)
    self._check_rejected_collision('flags foo=1, bar=1: Errors happen')

  def test_exception_raised_if_checker_raises_exception(self):
    def checker(flags_dict):
      if not self._all_distinct(flags_dict):
        raise _exceptions.ValidationError('Specific message')
      return True

    _validators.register_multi_flags_validator(
        ['foo', 'bar'],
        checker,
        message='Errors happen',
        flag_values=self.flag_values)
    self._check_rejected_collision('flags foo=1, bar=1: Specific message')

  def test_decorator(self):
    @_validators.multi_flags_validator(
        ['foo', 'bar'], message='Errors happen', flag_values=self.flag_values)
    def checker(flags_dict):  # pylint: disable=unused-variable
      return self._all_distinct(flags_dict)

    self._check_rejected_collision('flags foo=1, bar=1: Errors happen')

  def test_mismatching_flagvalues(self):
    foreign_holder = _defines.DEFINE_integer(
        'other_flag',
        3,
        'Other integer flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.register_multi_flags_validator(
          [self.foo_holder, self.bar_holder, foreign_holder],
          self._all_distinct,
          message='Errors happen',
          flag_values=self.flag_values)
class MarkFlagsAsMutualExclusiveTest(absltest.TestCase):
  """Tests _validators.mark_flags_as_mutual_exclusive()."""

  def setUp(self):
    super().setUp()
    fv = self.flag_values = _flagvalues.FlagValues()
    self.flag_one_holder = _defines.DEFINE_string(
        'flag_one', None, 'flag one', flag_values=fv)
    self.flag_two_holder = _defines.DEFINE_string(
        'flag_two', None, 'flag two', flag_values=fv)
    # The remaining None-default flags are only referenced by name.
    for define, name, help_text in (
        (_defines.DEFINE_string, 'flag_three', 'flag three'),
        (_defines.DEFINE_integer, 'int_flag_one', 'int flag one'),
        (_defines.DEFINE_integer, 'int_flag_two', 'int flag two'),
        (_defines.DEFINE_multi_string, 'multi_flag_one', 'multi flag one'),
        (_defines.DEFINE_multi_string, 'multi_flag_two', 'multi flag two'),
    ):
      define(name, None, help_text, flag_values=fv)
    # The only flag with a non-None default (used by the warning test).
    _defines.DEFINE_boolean(
        'flag_not_none', False, 'false default', flag_values=fv)

  def _mark_flags_as_mutually_exclusive(self, flag_names, required):
    _validators.mark_flags_as_mutual_exclusive(
        flag_names, flag_values=self.flag_values, required=required)

  def _assert_parse_fails(self, argv, expected_message):
    """Expects parsing `argv` to raise IllegalFlagValueError with this text."""
    with self.assertRaisesWithLiteralMatch(
        _exceptions.IllegalFlagValueError, expected_message):
      self.flag_values(argv)

  def _assert_nothing_set(self, flags):
    """Marks `flags` optionally exclusive; parses an empty command line."""
    self._mark_flags_as_mutually_exclusive(flags, False)
    self.flag_values(('./program',))
    self.assertIsNone(self.flag_values.flag_one)
    self.assertIsNone(self.flag_values.flag_two)

  def test_no_flags_present(self):
    self._assert_nothing_set(['flag_one', 'flag_two'])

  def test_no_flags_present_holder(self):
    self._assert_nothing_set([self.flag_one_holder, self.flag_two_holder])

  def test_no_flags_present_mixed(self):
    self._assert_nothing_set([self.flag_one_holder, 'flag_two'])

  def test_no_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
    self._assert_parse_fails(
        ('./program',),
        'flags flag_one=None, flag_two=None: '
        'Exactly one of (flag_one, flag_two) must have a value other than '
        'None.')

  def test_one_flag_present(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], False)
    self.flag_values(('./program', '--flag_one=1'))
    self.assertEqual('1', self.flag_values.flag_one)

  def test_one_flag_present_required(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
    self.flag_values(('./program', '--flag_two=2'))
    self.assertEqual('2', self.flag_values.flag_two)

  def test_one_flag_zero_required(self):
    # 0 is falsy but not None, so it still counts as a provided value.
    self._mark_flags_as_mutually_exclusive(
        ['int_flag_one', 'int_flag_two'], True)
    self.flag_values(('./program', '--int_flag_one=0'))
    self.assertEqual(0, self.flag_values.int_flag_one)

  def test_mutual_exclusion_with_extra_flags(self):
    self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_two'], True)
    self.flag_values(('./program', '--flag_two=2', '--flag_three=3'))
    self.assertEqual('2', self.flag_values.flag_two)
    self.assertEqual('3', self.flag_values.flag_three)

  def test_mutual_exclusion_with_zero(self):
    self._mark_flags_as_mutually_exclusive(
        ['int_flag_one', 'int_flag_two'], False)
    self._assert_parse_fails(
        ('./program', '--int_flag_one=0', '--int_flag_two=0'),
        'flags int_flag_one=0, int_flag_two=0: '
        'At most one of (int_flag_one, int_flag_two) must have a value other '
        'than None.')

  def test_multiple_flags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two', 'flag_three'], False)
    self._assert_parse_fails(
        ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3'),
        'flags flag_one=1, flag_two=2, flag_three=3: '
        'At most one of (flag_one, flag_two, flag_three) must have a value '
        'other than None.')

  def test_multiple_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['flag_one', 'flag_two', 'flag_three'], True)
    self._assert_parse_fails(
        ('./program', '--flag_one=1', '--flag_two=2', '--flag_three=3'),
        'flags flag_one=1, flag_two=2, flag_three=3: '
        'Exactly one of (flag_one, flag_two, flag_three) must have a value '
        'other than None.')

  def test_no_multiflags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], False)
    self.flag_values(('./program',))
    self.assertIsNone(self.flag_values.multi_flag_one)
    self.assertIsNone(self.flag_values.multi_flag_two)

  def test_no_multistring_flags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self._assert_parse_fails(
        ('./program',),
        'flags multi_flag_one=None, multi_flag_two=None: '
        'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')

  def test_one_multiflag_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self.flag_values(('./program', '--multi_flag_one=1'))
    self.assertEqual(['1'], self.flag_values.multi_flag_one)

  def test_one_multiflag_present_repeated(self):
    # Repeating the same multi flag does not violate the exclusion.
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self.flag_values(('./program', '--multi_flag_one=1', '--multi_flag_one=1b'))
    self.assertEqual(['1', '1b'], self.flag_values.multi_flag_one)

  def test_multiple_multiflags_present(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], False)
    self._assert_parse_fails(
        ('./program', '--multi_flag_one=1', '--multi_flag_two=2'),
        "flags multi_flag_one=['1'], multi_flag_two=['2']: "
        'At most one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')

  def test_multiple_multiflags_present_required(self):
    self._mark_flags_as_mutually_exclusive(
        ['multi_flag_one', 'multi_flag_two'], True)
    self._assert_parse_fails(
        ('./program', '--multi_flag_one=1', '--multi_flag_two=2'),
        "flags multi_flag_one=['1'], multi_flag_two=['2']: "
        'Exactly one of (multi_flag_one, multi_flag_two) must have a value '
        'other than None.')

  def test_flag_default_not_none_warning(self):
    with warnings.catch_warnings(record=True) as recorded:
      warnings.simplefilter('always')
      self._mark_flags_as_mutually_exclusive(['flag_one', 'flag_not_none'],
                                             False)
    self.assertLen(recorded, 1)
    self.assertIn('--flag_not_none has a non-None default value',
                  str(recorded[0].message))

  def test_multiple_flagvalues(self):
    foreign_holder = _defines.DEFINE_boolean(
        'other_flagvalues',
        False,
        'other ',
        flag_values=_flagvalues.FlagValues())
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        'multiple FlagValues instances used in invocation. '
        'FlagHolders must be registered to the same FlagValues instance as '
        'do flag names, if provided.'):
      self._mark_flags_as_mutually_exclusive(
          [self.flag_one_holder, foreign_holder], False)
class MarkBoolFlagsAsMutualExclusiveTest(absltest.TestCase):
def setUp(self):
  """Defines two False-default bools, one True-default bool and an int."""
  super().setUp()
  fv = self.flag_values = _flagvalues.FlagValues()
  self.false_1_holder = _defines.DEFINE_boolean(
      'false_1', False, 'default false 1', flag_values=fv)
  self.false_2_holder = _defines.DEFINE_boolean(
      'false_2', False, 'default false 2', flag_values=fv)
  self.true_1_holder = _defines.DEFINE_boolean(
      'true_1', True, 'default true 1', flag_values=fv)
  # A non-boolean flag, exercised by test_non_bool_flag.
  self.non_bool_holder = _defines.DEFINE_integer(
      'non_bool', None, 'non bool', flag_values=fv)
def _mark_bool_flags_as_mutually_exclusive(self, flag_names, required):
_validators.mark_bool_flags_as_mutual_exclusive(
flag_names, required=required, flag_values=self.flag_values)
def test_no_flags_present(self):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
self.flag_values(('./program',))
self.assertEqual(False, self.flag_values.false_1)
self.assertEqual(False, self.flag_values.false_2)
def test_no_flags_present_holder(self):
self._mark_bool_flags_as_mutually_exclusive(
[self.false_1_holder, self.false_2_holder], False)
self.flag_values(('./program',))
self.assertEqual(False, self.flag_values.false_1)
self.assertEqual(False, self.flag_values.false_2)
def test_no_flags_present_mixed(self):
self._mark_bool_flags_as_mutually_exclusive(
[self.false_1_holder, 'false_2'], False)
self.flag_values(('./program',))
self.assertEqual(False, self.flag_values.false_1)
self.assertEqual(False, self.flag_values.false_2)
def test_no_flags_present_required(self):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], True)
argv = ('./program',)
expected = (
'flags false_1=False, false_2=False: '
'Exactly one of (false_1, false_2) must be True.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
def test_no_flags_present_with_default_true_required(self):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'true_1'], True)
self.flag_values(('./program',))
self.assertEqual(False, self.flag_values.false_1)
self.assertEqual(True, self.flag_values.true_1)
def test_two_flags_true(self):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'false_2'], False)
argv = ('./program', '--false_1', '--false_2')
expected = (
'flags false_1=True, false_2=True: At most one of (false_1, '
'false_2) must be True.')
self.assertRaisesWithLiteralMatch(_exceptions.IllegalFlagValueError,
expected, self.flag_values, argv)
def test_non_bool_flag(self):
expected = ('Flag --non_bool is not Boolean, which is required for flags '
'used in mark_bool_flags_as_mutual_exclusive.')
with self.assertRaisesWithLiteralMatch(_exceptions.ValidationError,
expected):
self._mark_bool_flags_as_mutually_exclusive(['false_1', 'non_bool'],
False)
def test_multiple_flagvalues(self):
other_bool_holder = _defines.DEFINE_boolean(
'other_bool', False, 'other bool', flag_values=_flagvalues.FlagValues())
expected = (
'multiple FlagValues instances used in invocation. '
'FlagHolders must be registered to the same FlagValues instance as '
'do flag names, if provided.')
with self.assertRaisesWithLiteralMatch(ValueError, expected):
self._mark_bool_flags_as_mutually_exclusive(
[self.false_1_holder, other_bool_holder], False)
class MarkFlagAsRequiredTest(absltest.TestCase):
  """Tests for _validators.mark_flag_as_required."""

  def setUp(self):
    super().setUp()
    self.flag_values = _flagvalues.FlagValues()

  def _define_string_flag(self, default):
    """Defines --string_flag on self.flag_values and returns its holder."""
    return _defines.DEFINE_string(
        'string_flag', default, 'string flag', flag_values=self.flag_values)

  def test_success(self):
    self._define_string_flag(None)
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    self.flag_values(('./program', '--string_flag=value'))
    self.assertEqual('value', self.flag_values.string_flag)

  def test_success_holder(self):
    holder = self._define_string_flag(None)
    _validators.mark_flag_as_required(holder, flag_values=self.flag_values)
    self.flag_values(('./program', '--string_flag=value'))
    self.assertEqual('value', self.flag_values.string_flag)

  def test_success_holder_infer_flagvalues(self):
    # No flag_values argument: it must be taken from the holder itself.
    holder = self._define_string_flag(None)
    _validators.mark_flag_as_required(holder)
    self.flag_values(('./program', '--string_flag=value'))
    self.assertEqual('value', self.flag_values.string_flag)

  def test_catch_none_as_default(self):
    self._define_string_flag(None)
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    pattern = (
        r'flag --string_flag=None: Flag --string_flag must have a value other '
        r'than None\.')
    with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, pattern):
      self.flag_values(('./program',))

  def test_catch_setting_none_after_program_start(self):
    self._define_string_flag('value')
    _validators.mark_flag_as_required(
        'string_flag', flag_values=self.flag_values)
    self.flag_values(('./program',))
    self.assertEqual('value', self.flag_values.string_flag)
    expected = ('flag --string_flag=None: Flag --string_flag must have a value '
                'other than None.')
    # Assigning None after parsing must also trigger the validator.
    with self.assertRaises(_exceptions.IllegalFlagValueError) as context:
      self.flag_values.string_flag = None
    self.assertEqual(expected, str(context.exception))

  def test_flag_default_not_none_warning(self):
    _defines.DEFINE_string(
        'flag_not_none', '', 'empty default', flag_values=self.flag_values)
    with warnings.catch_warnings(record=True) as recorded:
      warnings.simplefilter('always')
      _validators.mark_flag_as_required(
          'flag_not_none', flag_values=self.flag_values)
    self.assertLen(recorded, 1)
    self.assertIn('--flag_not_none has a non-None default value',
                  str(recorded[0].message))

  def test_mismatching_flagvalues(self):
    foreign_holder = _defines.DEFINE_string(
        'string_flag', 'value', 'string flag',
        flag_values=_flagvalues.FlagValues())
    expected = (
        'flag_values must not be customized when operating on a FlagHolder')
    with self.assertRaisesWithLiteralMatch(ValueError, expected):
      _validators.mark_flag_as_required(
          foreign_holder, flag_values=self.flag_values)
class MarkFlagsAsRequiredTest(absltest.TestCase):
  """Tests for _validators.mark_flags_as_required and fail-fast validation."""

  def setUp(self):
    super(MarkFlagsAsRequiredTest, self).setUp()
    self.flag_values = _flagvalues.FlagValues()

  def test_success(self):
    _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    flag_names = ['string_flag_1', 'string_flag_2']
    _validators.mark_flags_as_required(flag_names, flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    self.assertEqual('value_2', self.flag_values.string_flag_2)

  def test_success_holders(self):
    flag_1_holder = _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    flag_2_holder = _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    _validators.mark_flags_as_required([flag_1_holder, flag_2_holder],
                                       flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1', '--string_flag_2=value_2')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    self.assertEqual('value_2', self.flag_values.string_flag_2)

  def test_catch_none_as_default(self):
    _defines.DEFINE_string(
        'string_flag_1', None, 'string flag 1', flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2', None, 'string flag 2', flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
    # Only string_flag_1 is supplied, so string_flag_2 stays None and fails.
    argv = ('./program', '--string_flag_1=value_1')
    expected = (
        r'flag --string_flag_2=None: Flag --string_flag_2 must have a value '
        r'other than None\.')
    with self.assertRaisesRegex(_exceptions.IllegalFlagValueError, expected):
      self.flag_values(argv)

  def test_catch_setting_none_after_program_start(self):
    _defines.DEFINE_string(
        'string_flag_1',
        'value_1',
        'string flag 1',
        flag_values=self.flag_values)
    _defines.DEFINE_string(
        'string_flag_2',
        'value_2',
        'string flag 2',
        flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['string_flag_1', 'string_flag_2'], flag_values=self.flag_values)
    argv = ('./program', '--string_flag_1=value_1')
    self.flag_values(argv)
    self.assertEqual('value_1', self.flag_values.string_flag_1)
    expected = (
        'flag --string_flag_1=None: Flag --string_flag_1 must have a value '
        'other than None.')
    # The required-validator also runs on assignment after parsing.
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values.string_flag_1 = None
    self.assertEqual(expected, str(cm.exception))

  def test_catch_multiple_flags_as_none_at_program_start(self):
    _defines.DEFINE_float(
        'float_flag_1',
        None,
        'float flag 1',
        flag_values=self.flag_values)
    _defines.DEFINE_float(
        'float_flag_2',
        None,
        'float flag 2',
        flag_values=self.flag_values)
    _validators.mark_flags_as_required(
        ['float_flag_1', 'float_flag_2'], flag_values=self.flag_values)
    argv = ('./program', '')
    # Both violations are reported together, one per line.
    expected = (
        'flag --float_flag_1=None: Flag --float_flag_1 must have a value '
        'other than None.\n'
        'flag --float_flag_2=None: Flag --float_flag_2 must have a value '
        'other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))

  def test_fail_fast_single_flag_and_skip_remaining_validators(self):
    def raise_unexpected_error(x):
      del x
      raise _exceptions.ValidationError('Should not be raised.')

    _defines.DEFINE_float(
        'flag_1', None, 'flag 1', flag_values=self.flag_values)
    _defines.DEFINE_float(
        'flag_2', 4.2, 'flag 2', flag_values=self.flag_values)
    _validators.mark_flag_as_required('flag_1', flag_values=self.flag_values)
    _validators.register_validator(
        'flag_1', raise_unexpected_error, flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
                                               raise_unexpected_error,
                                               flag_values=self.flag_values)
    argv = ('./program', '')
    # Only the required-flag failure may appear; the message of every other
    # validator registered above ('Should not be raised.') must be absent.
    expected = (
        'flag --flag_1=None: Flag --flag_1 must have a value other than None.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))

  def test_fail_fast_multi_flag_and_skip_remaining_validators(self):
    def raise_expected_error(x):
      del x
      raise _exceptions.ValidationError('Expected error.')

    def raise_unexpected_error(x):
      del x
      raise _exceptions.ValidationError('Got unexpected error.')

    _defines.DEFINE_float(
        'flag_1', 5.1, 'flag 1', flag_values=self.flag_values)
    _defines.DEFINE_float(
        'flag_2', 10.0, 'flag 2', flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_1', 'flag_2'],
                                               raise_expected_error,
                                               flag_values=self.flag_values)
    _validators.register_multi_flags_validator(['flag_2', 'flag_1'],
                                               raise_unexpected_error,
                                               flag_values=self.flag_values)
    _validators.register_validator(
        'flag_1', raise_unexpected_error, flag_values=self.flag_values)
    _validators.register_validator(
        'flag_2', raise_unexpected_error, flag_values=self.flag_values)
    argv = ('./program', '')
    # Only the first-registered multi-flag validator's error is expected.
    expected = ('flags flag_1=5.1, flag_2=10.0: Expected error.')
    with self.assertRaises(_exceptions.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    self.assertEqual(expected, str(cm.exception))
if __name__ == '__main__':
  # Run the test cases in this module through absltest's entry point.
  absltest.main()
abseil-py-2.1.0/absl/flags/tests/argparse_flags_test.py 0000664 0000000 0000000 00000044521 14551576331 0023167 0 ustar 00root root 0000000 0000000 # Copyright 2018 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for absl.flags.argparse_flags."""
import io
import os
import subprocess
import sys
import tempfile
from unittest import mock
from absl import flags
from absl import logging
from absl.flags import argparse_flags
from absl.testing import _bazelize_command
from absl.testing import absltest
from absl.testing import parameterized
class ArgparseFlagsTest(parameterized.TestCase):
  """Tests for argparse_flags.ArgumentParser backed by an absl FlagValues."""

  def setUp(self):
    super().setUp()
    self._absl_flags = flags.FlagValues()
    flags.DEFINE_bool(
        'absl_bool', None, 'help for --absl_bool.',
        short_name='b', flag_values=self._absl_flags)
    # Add a boolean flag that starts with "no", to verify it can correctly
    # handle the "no" prefixes in boolean flags.
    flags.DEFINE_bool(
        'notice', None, 'help for --notice.',
        flag_values=self._absl_flags)
    # The literal '%' in this help text is asserted verbatim by
    # test_help_main_module_flags (presumably to check it is not treated as a
    # format specifier when argparse renders help).
    flags.DEFINE_string(
        'absl_string', 'default', 'help for --absl_string=%.',
        short_name='s', flag_values=self._absl_flags)
    flags.DEFINE_integer(
        'absl_integer', 1, 'help for --absl_integer.',
        flag_values=self._absl_flags)
    flags.DEFINE_float(
        'absl_float', 1, 'help for --absl_float.',
        flag_values=self._absl_flags)
    flags.DEFINE_enum(
        'absl_enum', 'apple', ['apple', 'orange'], 'help for --absl_enum.',
        flag_values=self._absl_flags)

  def test_dash_as_prefix_char_only(self):
    with self.assertRaises(ValueError):
      argparse_flags.ArgumentParser(prefix_chars='/')

  def test_default_inherited_absl_flags_value(self):
    parser = argparse_flags.ArgumentParser()
    self.assertIs(parser._inherited_absl_flags, flags.FLAGS)

  def test_parse_absl_flags(self):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    self.assertFalse(self._absl_flags.is_parsed())
    self.assertTrue(self._absl_flags['absl_string'].using_default_value)
    self.assertTrue(self._absl_flags['absl_integer'].using_default_value)
    self.assertTrue(self._absl_flags['absl_float'].using_default_value)
    self.assertTrue(self._absl_flags['absl_enum'].using_default_value)

    parser.parse_args(
        ['--absl_string=new_string', '--absl_integer', '2'])
    self.assertEqual(self._absl_flags.absl_string, 'new_string')
    self.assertEqual(self._absl_flags.absl_integer, 2)
    self.assertTrue(self._absl_flags.is_parsed())
    self.assertFalse(self._absl_flags['absl_string'].using_default_value)
    self.assertFalse(self._absl_flags['absl_integer'].using_default_value)
    self.assertTrue(self._absl_flags['absl_float'].using_default_value)
    self.assertTrue(self._absl_flags['absl_enum'].using_default_value)

  @parameterized.named_parameters(
      ('true', ['--absl_bool'], True),
      ('false', ['--noabsl_bool'], False),
      ('does_not_accept_equal_value', ['--absl_bool=true'], SystemExit),
      ('does_not_accept_space_value', ['--absl_bool', 'true'], SystemExit),
      ('long_name_single_dash', ['-absl_bool'], SystemExit),
      ('short_name', ['-b'], True),
      ('short_name_false', ['-nob'], SystemExit),
      ('short_name_double_dash', ['--b'], SystemExit),
      ('short_name_double_dash_false', ['--nob'], SystemExit),
  )
  def test_parse_boolean_flags(self, args, expected):
    # `expected` is either the resulting bool value or the exception type the
    # parser is expected to raise for the given arguments.
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    self.assertIsNone(self._absl_flags['absl_bool'].value)
    self.assertIsNone(self._absl_flags['b'].value)
    if isinstance(expected, bool):
      parser.parse_args(args)
      self.assertEqual(expected, self._absl_flags.absl_bool)
      self.assertEqual(expected, self._absl_flags.b)
    else:
      with self.assertRaises(expected):
        parser.parse_args(args)

  @parameterized.named_parameters(
      ('true', ['--notice'], True),
      ('false', ['--nonotice'], False),
  )
  def test_parse_boolean_existing_no_prefix(self, args, expected):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    self.assertIsNone(self._absl_flags['notice'].value)
    parser.parse_args(args)
    self.assertEqual(expected, self._absl_flags.notice)

  def test_unrecognized_flag(self):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    with self.assertRaises(SystemExit):
      parser.parse_args(['--unknown_flag=what'])

  def test_absl_validators(self):

    @flags.validator('absl_integer', flag_values=self._absl_flags)
    def ensure_positive(value):
      return value > 0

    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    # The validator rejects -2, so parsing must exit.
    with self.assertRaises(SystemExit):
      parser.parse_args(['--absl_integer', '-2'])

    del ensure_positive

  @parameterized.named_parameters(
      ('regular_name_double_dash', '--absl_string=new_string', 'new_string'),
      ('regular_name_single_dash', '-absl_string=new_string', SystemExit),
      ('short_name_double_dash', '--s=new_string', SystemExit),
      ('short_name_single_dash', '-s=new_string', 'new_string'),
  )
  def test_dashes(self, argument, expected):
    # `expected` is either the parsed string value or an exception type.
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    if isinstance(expected, str):
      parser.parse_args([argument])
      self.assertEqual(self._absl_flags.absl_string, expected)
    else:
      with self.assertRaises(expected):
        parser.parse_args([argument])

  def test_absl_flags_not_added_to_namespace(self):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    args = parser.parse_args(['--absl_string=new_string'])
    self.assertIsNone(getattr(args, 'absl_string', None))

  def test_mixed_flags_and_positional(self):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    parser.add_argument('--header', help='Header message to print.')
    parser.add_argument('integers', metavar='N', type=int, nargs='+',
                        help='an integer for the accumulator')

    args = parser.parse_args(
        ['--absl_string=new_string', '--header=HEADER', '--absl_integer',
         '2', '3', '4'])
    self.assertEqual(self._absl_flags.absl_string, 'new_string')
    self.assertEqual(self._absl_flags.absl_integer, 2)
    self.assertEqual(args.header, 'HEADER')
    self.assertListEqual(args.integers, [3, 4])

  def test_subparsers(self):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    parser.add_argument('--header', help='Header message to print.')
    subparsers = parser.add_subparsers(help='The command to execute.')

    # NOTE: The sub parsers don't work well with typing hence `type: ignore`.
    # See https://github.com/python/typeshed/issues/10082.
    sub_parser = subparsers.add_parser(  # type: ignore
        'sub_cmd', help='Sub command.', inherited_absl_flags=self._absl_flags
    )
    sub_parser.add_argument('--sub_flag', help='Sub command flag.')

    def sub_command_func():
      pass

    sub_parser.set_defaults(command=sub_command_func)

    args = parser.parse_args([
        '--header=HEADER', '--absl_string=new_value', 'sub_cmd',
        '--absl_integer=2', '--sub_flag=new_sub_flag_value'])

    self.assertEqual(args.header, 'HEADER')
    self.assertEqual(self._absl_flags.absl_string, 'new_value')
    self.assertEqual(args.command, sub_command_func)
    self.assertEqual(self._absl_flags.absl_integer, 2)
    self.assertEqual(args.sub_flag, 'new_sub_flag_value')

  def test_subparsers_no_inherit_in_subparser(self):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    subparsers = parser.add_subparsers(help='The command to execute.')

    # NOTE: The sub parsers don't work well with typing hence `type: ignore`.
    # See https://github.com/python/typeshed/issues/10082.
    subparsers.add_parser(  # type: ignore
        'sub_cmd',
        help='Sub command.',
        # Do not inherit absl flags in the subparser.
        # This is the behavior that this test exercises.
        inherited_absl_flags=None,
    )

    with self.assertRaises(SystemExit):
      parser.parse_args(['sub_cmd', '--absl_string=new_value'])

  def test_help_main_module_flags(self):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    help_message = parser.format_help()

    # Only the short name is shown in the usage string.
    self.assertIn('[-s ABSL_STRING]', help_message)
    # Both names are included in the options section.
    self.assertIn('-s ABSL_STRING, --absl_string ABSL_STRING', help_message)
    # Verify help messages.
    self.assertIn('help for --absl_string=%.', help_message)
    self.assertIn(': help for --absl_enum.', help_message)

  def test_help_non_main_module_flags(self):
    flags.DEFINE_string(
        'non_main_module_flag', 'default', 'help',
        module_name='other.module', flag_values=self._absl_flags)
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    help_message = parser.format_help()

    # Non main module key flags are not printed in the help message.
    self.assertNotIn('non_main_module_flag', help_message)

  def test_help_non_main_module_key_flags(self):
    flags.DEFINE_string(
        'non_main_module_flag', 'default', 'help',
        module_name='other.module', flag_values=self._absl_flags)
    flags.declare_key_flag('non_main_module_flag', flag_values=self._absl_flags)
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    help_message = parser.format_help()

    # Main module key flags are printed in the help message, even if the flag
    # is defined in another module.
    self.assertIn('non_main_module_flag', help_message)

  @parameterized.named_parameters(
      ('h', ['-h']),
      ('help', ['--help']),
      ('helpshort', ['--helpshort']),
      ('helpfull', ['--helpfull']),
  )
  def test_help_flags(self, args):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    with self.assertRaises(SystemExit):
      parser.parse_args(args)

  @parameterized.named_parameters(
      ('h', ['-h']),
      ('help', ['--help']),
      ('helpshort', ['--helpshort']),
      ('helpfull', ['--helpfull']),
  )
  def test_no_help_flags(self, args):
    # With add_help=False the help flags are unknown: parsing still exits,
    # but no help text is printed.
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags, add_help=False)
    with mock.patch.object(parser, 'print_help') as print_help_mock:
      with self.assertRaises(SystemExit):
        parser.parse_args(args)
      print_help_mock.assert_not_called()

  def test_helpfull_message(self):
    flags.DEFINE_string(
        'non_main_module_flag', 'default', 'help',
        module_name='other.module', flag_values=self._absl_flags)
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    with self.assertRaises(SystemExit),\
        mock.patch.object(sys, 'stdout', new=io.StringIO()) as mock_stdout:
      parser.parse_args(['--helpfull'])
    stdout_message = mock_stdout.getvalue()
    logging.info('captured stdout message:\n%s', stdout_message)
    self.assertIn('--non_main_module_flag', stdout_message)
    self.assertIn('other.module', stdout_message)
    # Make sure the main module is not included.
    self.assertNotIn(sys.argv[0], stdout_message)
    # Special flags defined in absl.flags.
    self.assertIn('absl.flags:', stdout_message)
    self.assertIn('--flagfile', stdout_message)
    self.assertIn('--undefok', stdout_message)

  @parameterized.named_parameters(
      ('at_end',
       ('1', '--absl_string=value_from_cmd', '--flagfile='),
       'value_from_file'),
      ('at_beginning',
       ('--flagfile=', '1', '--absl_string=value_from_cmd'),
       'value_from_cmd'),
  )
  def test_flagfile(self, cmd_args, expected_absl_string_value):
    # Set gnu_getopt to False, to verify it's ignored by argparse_flags.
    self._absl_flags.set_gnu_getopt(False)

    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    parser.add_argument('--header', help='Header message to print.')
    parser.add_argument('integers', metavar='N', type=int, nargs='+',
                        help='an integer for the accumulator')
    flagfile = tempfile.NamedTemporaryFile(
        dir=absltest.TEST_TMPDIR.value, delete=False)
    self.addCleanup(os.unlink, flagfile.name)
    with flagfile:
      flagfile.write(b'''
# The flag file.
--absl_string=value_from_file
--absl_integer=1
--header=header_from_file
''')

    # Each bare '--flagfile=' placeholder gets the temp file's path appended;
    # whichever of the flagfile and the command-line value comes later wins.
    expand_flagfile = lambda x: x + flagfile.name if x == '--flagfile=' else x
    cmd_args = [expand_flagfile(x) for x in cmd_args]
    args = parser.parse_args(cmd_args)

    self.assertEqual([1], args.integers)
    self.assertEqual('header_from_file', args.header)
    self.assertEqual(expected_absl_string_value, self._absl_flags.absl_string)

  @parameterized.parameters(
      ('positional', {'positional'}, False),
      ('--not_existed', {'existed'}, False),
      ('--empty', set(), False),
      ('-single_dash', {'single_dash'}, True),
      ('--double_dash', {'double_dash'}, True),
      ('--with_value=value', {'with_value'}, True),
  )
  def test_is_undefok(self, arg, undefok_names, is_undefok):
    self.assertEqual(is_undefok, argparse_flags._is_undefok(arg, undefok_names))

  @parameterized.named_parameters(
      ('single', 'single', ['--single'], []),
      ('multiple', 'first,second', ['--first', '--second'], []),
      ('single_dash', 'dash', ['-dash'], []),
      ('mixed_dash', 'mixed', ['-mixed', '--mixed'], []),
      ('value', 'name', ['--name=value'], []),
      ('boolean_positive', 'bool', ['--bool'], []),
      ('boolean_negative', 'bool', ['--nobool'], []),
      ('left_over', 'strip', ['--first', '--strip', '--last'],
       ['--first', '--last']),
  )
  def test_strip_undefok_args(self, undefok, args, expected_args):
    actual_args = argparse_flags._strip_undefok_args(undefok, args)
    self.assertListEqual(expected_args, actual_args)

  @parameterized.named_parameters(
      ('at_end', ['--unknown', '--undefok=unknown']),
      ('at_beginning', ['--undefok=unknown', '--unknown']),
      ('multiple', ['--unknown', '--undefok=unknown,another_unknown']),
      ('with_value', ['--unknown=value', '--undefok=unknown']),
      ('maybe_boolean', ['--nounknown', '--undefok=unknown']),
      ('with_space', ['--unknown', '--undefok', 'unknown']),
  )
  def test_undefok_flag_correct_use(self, cmd_args):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    args = parser.parse_args(cmd_args)  # Make sure it doesn't raise.
    # Make sure `undefok` is not exposed in namespace.
    sentinel = object()
    self.assertIs(sentinel, getattr(args, 'undefok', sentinel))

  def test_undefok_flag_existing(self):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    parser.parse_args(
        ['--absl_string=new_value', '--undefok=absl_string'])
    self.assertEqual('new_value', self._absl_flags.absl_string)

  @parameterized.named_parameters(
      ('no_equal', ['--unknown', 'value', '--undefok=unknown']),
      ('single_dash', ['--unknown', '-undefok=unknown']),
  )
  def test_undefok_flag_incorrect_use(self, cmd_args):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags)
    with self.assertRaises(SystemExit):
      parser.parse_args(cmd_args)

  def test_argument_default(self):
    # Regression test for https://github.com/abseil/abseil-py/issues/171.
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=self._absl_flags, argument_default=23)
    parser.add_argument(
        '--magic_number', type=int, help='The magic number to use.')
    args = parser.parse_args([])
    self.assertEqual(args.magic_number, 23)

  def test_empty_inherited_absl_flags(self):
    parser = argparse_flags.ArgumentParser(
        inherited_absl_flags=flags.FlagValues()
    )
    parser.add_argument('--foo')
    flagfile = self.create_tempfile(content='--foo=bar').full_path
    # Make sure these flags are still available when inheriting an empty
    # FlagValues instance.
    ns = parser.parse_args([
        '--undefok=undefined_flag',
        '--undefined_flag=value',
        '--flagfile=' + flagfile,
    ])
    self.assertEqual(ns.foo, 'bar')
class ArgparseWithAppRunTest(parameterized.TestCase):
  """Runs argparse_flags_test_helper under app.run and inspects its stdout."""

  @parameterized.named_parameters(
      ('simple',
       'main_simple', 'parse_flags_simple',
       ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.'],
       ['I am argparse.', 'I am absl.']),
      ('subcommand_roll_dice',
       'main_subcommands', 'parse_flags_subcommands',
       ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.',
        'roll_dice', '--num_faces=12'],
       ['I am argparse.', 'I am absl.', 'Rolled a dice: ']),
      ('subcommand_shuffle',
       'main_subcommands', 'parse_flags_subcommands',
       ['--argparse_echo=I am argparse.', '--absl_echo=I am absl.',
        'shuffle', 'a', 'b', 'c'],
       ['I am argparse.', 'I am absl.', 'Shuffled: ']),
  )
  def test_argparse_with_app_run(
      self, main_func_name, flags_parser_func_name, args, output_strings):
    # The helper picks its main function and flags parser by name from these
    # two environment variables.
    child_env = dict(os.environ)
    child_env['MAIN_FUNC'] = main_func_name
    child_env['FLAGS_PARSER_FUNC'] = flags_parser_func_name
    executable = _bazelize_command.get_executable_path(
        'absl/flags/tests/argparse_flags_test_helper')
    command = [executable, *args]

    try:
      captured_stdout = subprocess.check_output(
          command, env=child_env, universal_newlines=True)
    except subprocess.CalledProcessError as failure:
      # Dump the child's output to stderr before re-raising, for debugging.
      output_section = failure.output + '\n' if failure.output else ''
      report_template = ('ERROR: argparse_helper failed\n'
                         'Command: {}\n'
                         'Exit code: {}\n'
                         '----- output -----\n{}'
                         '------------------')
      report = report_template.format(
          failure.cmd, failure.returncode, output_section)
      print(report, file=sys.stderr)
      raise

    for expected_text in output_strings:
      self.assertIn(expected_text, captured_stdout)
if __name__ == '__main__':
  # Run the test cases in this module through absltest's entry point.
  absltest.main()
abseil-py-2.1.0/absl/flags/tests/argparse_flags_test_helper.py 0000664 0000000 0000000 00000005060 14551576331 0024521 0 ustar 00root root 0000000 0000000 # Copyright 2018 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test helper for argparse_flags_test."""
import os
import random
from absl import app
from absl import flags
from absl.flags import argparse_flags
FLAGS = flags.FLAGS

# Defined on flags.FLAGS at import time; main_simple prints its value.
flags.DEFINE_string('absl_echo', None, 'The echo message from absl.flags.')
def parse_flags_simple(argv):
  """Simple example for absl.flags + argparse.

  Args:
    argv: The full command line; argv[0] (the program name) is skipped.

  Returns:
    The argparse namespace with an `argparse_echo` attribute.
  """
  program_args = argv[1:]
  simple_parser = argparse_flags.ArgumentParser(
      description='A simple example of argparse_flags.')
  simple_parser.add_argument('--argparse_echo',
                             help='The echo message from argparse_flags')
  return simple_parser.parse_args(program_args)
def main_simple(args):
  """Prints the absl-defined echo value, then the argparse-defined one."""
  absl_value = FLAGS.absl_echo
  print(f'--absl_echo is {absl_value}')
  print(f'--argparse_echo is {args.argparse_echo}')
def roll_dice(args):
  """Prints one random integer in [1, args.num_faces]."""
  outcome = random.randint(1, args.num_faces)
  print(f'Rolled a dice: {outcome}')
def shuffle(args):
  """Prints a randomly reordered copy of args.inputs, space-separated."""
  items = [*args.inputs]
  random.shuffle(items)
  print('Shuffled: ' + ' '.join(items))
def parse_flags_subcommands(argv):
  """Subcommands example for absl.flags + argparse.

  Builds a parser with a top-level --argparse_echo option and two
  subcommands, each of which stores its handler in the `command` attribute.

  Args:
    argv: The full command line; argv[0] (the program name) is skipped.

  Returns:
    The parsed argparse namespace.
  """
  top_parser = argparse_flags.ArgumentParser(
      description='A subcommands example of argparse_flags.')
  top_parser.add_argument(
      '--argparse_echo', help='The echo message from argparse_flags')
  commands = top_parser.add_subparsers(help='The command to execute.')

  dice_parser = commands.add_parser('roll_dice', help='Roll a dice.')
  dice_parser.add_argument('--num_faces', type=int, default=6)
  dice_parser.set_defaults(command=roll_dice)

  shuffler = commands.add_parser('shuffle', help='Shuffle inputs.')
  shuffler.add_argument(
      'inputs', metavar='I', nargs='+', help='Inputs to shuffle.')
  shuffler.set_defaults(command=shuffle)

  program_args = argv[1:]
  return top_parser.parse_args(program_args)
def main_subcommands(args):
  """Echoes both flag values, then runs the handler of the chosen subcommand."""
  main_simple(args)
  handler = args.command
  handler(args)
if __name__ == '__main__':
  # The parent test chooses the entry point and the flags parser by name via
  # environment variables; both names are resolved in this module's globals.
  main_func_name = os.environ['MAIN_FUNC']
  flags_parser_func_name = os.environ['FLAGS_PARSER_FUNC']
  module_namespace = globals()
  selected_main = module_namespace[main_func_name]
  selected_parser = module_namespace[flags_parser_func_name]
  app.run(main=selected_main, flags_parser=selected_parser)
abseil-py-2.1.0/absl/flags/tests/flags_formatting_test.py 0000664 0000000 0000000 00000021601 14551576331 0023527 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl import flags
from absl.flags import _helpers
from absl.testing import absltest
# Module-level alias for absl's shared flags.FLAGS instance.
FLAGS = flags.FLAGS
class FlagsUnitTest(absltest.TestCase):
"""Flags formatting Unit Test."""
def test_get_help_width(self):
  """Verify that get_help_width() reflects _helpers._DEFAULT_HELP_WIDTH."""
  saved_width = _helpers._DEFAULT_HELP_WIDTH
  # Restore the module-level default even if an assertion below fails, so a
  # failure here cannot leak the temporary width of 10 into other tests.
  self.addCleanup(setattr, _helpers, '_DEFAULT_HELP_WIDTH', saved_width)
  self.assertEqual(80, _helpers._DEFAULT_HELP_WIDTH)
  self.assertEqual(_helpers._DEFAULT_HELP_WIDTH, flags.get_help_width())
  _helpers._DEFAULT_HELP_WIDTH = 10
  self.assertEqual(_helpers._DEFAULT_HELP_WIDTH, flags.get_help_width())
def test_text_wrap(self):
  """Test that wrapping works as expected.

  Also tests that, by default, it uses the module-level
  _helpers._DEFAULT_HELP_WIDTH as the line length.
  """
  default_help_width = _helpers._DEFAULT_HELP_WIDTH
  # Restore the global even if an assertion below fails, so the patched
  # width cannot leak into other tests.
  self.addCleanup(setattr, _helpers, '_DEFAULT_HELP_WIDTH', default_help_width)
  _helpers._DEFAULT_HELP_WIDTH = 10

  # Generate a string with length 40, no spaces
  text = ''
  expect = []
  for n in range(4):
    line = str(n)
    line += '123456789'
    text += line
    expect.append(line)

  # Verify we still break
  wrapped = flags.text_wrap(text).split('\n')
  self.assertEqual(4, len(wrapped))
  self.assertEqual(expect, wrapped)

  wrapped = flags.text_wrap(text, 80).split('\n')
  self.assertEqual(1, len(wrapped))
  self.assertEqual([text], wrapped)

  # Normal case, breaking at word boundaries and rewriting new lines
  input_value = 'a b c d e f g h'
  expect = {1: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'],
            2: ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'],
            3: ['a b', 'c d', 'e f', 'g h'],
            4: ['a b', 'c d', 'e f', 'g h'],
            5: ['a b c', 'd e f', 'g h'],
            6: ['a b c', 'd e f', 'g h'],
            7: ['a b c d', 'e f g h'],
            8: ['a b c d', 'e f g h'],
            9: ['a b c d e', 'f g h'],
            10: ['a b c d e', 'f g h'],
            11: ['a b c d e f', 'g h'],
            12: ['a b c d e f', 'g h'],
            13: ['a b c d e f g', 'h'],
            14: ['a b c d e f g', 'h'],
            15: ['a b c d e f g h']}
  for width, exp in expect.items():
    self.assertEqual(exp, flags.text_wrap(input_value, width).split('\n'))

  # We turn lines with only whitespace into empty lines
  # We strip from the right up to the first new line
  self.assertEqual('', flags.text_wrap(' '))
  self.assertEqual('\n', flags.text_wrap(' \n '))
  self.assertEqual('\n', flags.text_wrap('\n\n'))
  self.assertEqual('\n\n', flags.text_wrap('\n\n\n'))
  self.assertEqual('\n', flags.text_wrap('\n '))
  self.assertEqual('a\n\nb', flags.text_wrap('a\n \nb'))
  self.assertEqual('a\n\n\nb', flags.text_wrap('a\n \n \nb'))
  self.assertEqual('a\nb', flags.text_wrap(' a\nb '))
  self.assertEqual('\na\nb', flags.text_wrap('\na\nb\n'))
  self.assertEqual('\na\nb\n', flags.text_wrap(' \na\nb\n '))
  self.assertEqual('\na\nb\n', flags.text_wrap(' \na\nb\n\n'))

  # Double newline.
  self.assertEqual('a\n\nb', flags.text_wrap(' a\n\n b'))

  # We respect prefix
  self.assertEqual(' a\n b\n c', flags.text_wrap('a\nb\nc', 80, ' '))
  self.assertEqual('a\n b\n c', flags.text_wrap('a\nb\nc', 80, ' ', ''))

  # tabs
  self.assertEqual('a\n b c',
                   flags.text_wrap('a\nb\tc', 80, ' ', ''))
  self.assertEqual('a\n bb c',
                   flags.text_wrap('a\nbb\tc', 80, ' ', ''))
  self.assertEqual('a\n bbb c',
                   flags.text_wrap('a\nbbb\tc', 80, ' ', ''))
  self.assertEqual('a\n bbbb c',
                   flags.text_wrap('a\nbbbb\tc', 80, ' ', ''))
  self.assertEqual('a\n b\n c\n d',
                   flags.text_wrap('a\nb\tc\td', 3, ' ', ''))
  self.assertEqual('a\n b\n c\n d',
                   flags.text_wrap('a\nb\tc\td', 4, ' ', ''))
  self.assertEqual('a\n b\n c\n d',
                   flags.text_wrap('a\nb\tc\td', 5, ' ', ''))
  self.assertEqual('a\n b c\n d',
                   flags.text_wrap('a\nb\tc\td', 6, ' ', ''))
  self.assertEqual('a\n b c\n d',
                   flags.text_wrap('a\nb\tc\td', 7, ' ', ''))
  self.assertEqual('a\n b c\n d',
                   flags.text_wrap('a\nb\tc\td', 8, ' ', ''))
  self.assertEqual('a\n b c\n d',
                   flags.text_wrap('a\nb\tc\td', 9, ' ', ''))
  self.assertEqual('a\n b c d',
                   flags.text_wrap('a\nb\tc\td', 10, ' ', ''))

  # multiple tabs
  self.assertEqual('a c',
                   flags.text_wrap('a\t\tc', 80, ' ', ''))
def test_doc_to_help(self):
  """doc_to_help trims whitespace, joins wrapped lines and keeps paragraphs."""
  cases = [
      # (expected help text, raw docstring)
      ('', ' '),
      ('', ' \n '),
      ('a\n\nb', 'a\n \nb'),
      ('a\n\n\nb', 'a\n \n \nb'),
      ('a b', ' a\nb '),
      ('a b', '\na\nb\n'),
      ('a\n\nb', '\na\n\nb\n'),
      ('a b', ' \na\nb\n '),
      # Different first line, one line empty - i.e. a double newline.
      ('a b c\n\nd', 'a\n b\n c\n\n d'),
      ('a b\n c d', 'a\n b\n \tc\n d'),
      ('a b\n c\n d', 'a\n b\n \tc\n \td'),
  ]
  for expected, raw_doc in cases:
    self.assertEqual(expected, flags.doc_to_help(raw_doc))
def test_doc_to_help_flag_values(self):
# !!!!!!!!!!!!!!!!!!!!
# The following doc string is taken as is directly from flags.py:FlagValues
# The intention of this test is to verify 'live' performance
# !!!!!!!!!!!!!!!!!!!!
"""Used as a registry for 'Flag' objects.
A 'FlagValues' can then scan command line arguments, passing flag
arguments through to the 'Flag' objects that it owns. It also
provides easy access to the flag values. Typically only one
'FlagValues' object is needed by an application: flags.FLAGS
This class is heavily overloaded:
'Flag' objects are registered via __setitem__:
FLAGS['longname'] = x # register a new flag
The .value member of the registered 'Flag' objects can be accessed as
members of this 'FlagValues' object, through __getattr__. Both the
long and short name of the original 'Flag' objects can be used to
access its value:
FLAGS.longname # parsed flag value
FLAGS.x # parsed flag value (short name)
Command line arguments are scanned and passed to the registered 'Flag'
objects through the __call__ method. Unparsed arguments, including
argv[0] (e.g. the program name) are returned.
argv = FLAGS(sys.argv) # scan command line arguments
The original registered Flag objects can be retrieved through the use
"""
# The docstring above is this test's input: it is read back through
# __doc__, and the line count and line indices asserted below depend on its
# exact layout. Update both together.
doc = flags.doc_to_help(self.test_doc_to_help_flag_values.__doc__)
# Test the general outline of the converted docs
lines = doc.splitlines()
self.assertEqual(17, len(lines))
# Indices of lines that are empty after conversion (paragraph breaks).
empty_lines = [index for index in range(len(lines)) if not lines[index]]
self.assertEqual([1, 3, 5, 8, 12, 15], empty_lines)
# test that some starting prefix is kept
flags_lines = [index for index in range(len(lines))
if lines[index].startswith(' FLAGS')]
self.assertEqual([7, 10, 11], flags_lines)
# but other, especially common space has been removed
space_lines = [index for index in range(len(lines))
if lines[index] and lines[index][0].isspace()]
self.assertEqual([7, 10, 11, 14], space_lines)
# No right space was kept
rspace_lines = [index for index in range(len(lines))
if lines[index] != lines[index].rstrip()]
self.assertEqual([], rspace_lines)
# test double spaces are kept
# NOTE(review): assertTrue(...) would give a clearer failure message.
self.assertEqual(True, lines[2].endswith('application: flags.FLAGS'))
def test_text_wrap_raises_on_excessive_indent(self):
  """An indent that fills the entire line length raises ValueError."""
  with self.assertRaises(ValueError):
    flags.text_wrap('dummy', length=10, indent=' ' * 10)
def test_text_wrap_raises_on_excessive_first_line(self):
  """A first-line indent that fills the entire line length raises ValueError."""
  with self.assertRaises(ValueError):
    flags.text_wrap('dummy', length=80, firstline_indent=' ' * 80)
# Entry point: run this module's tests with the absltest runner.
if __name__ == '__main__':
absltest.main()
abseil-py-2.1.0/absl/flags/tests/flags_helpxml_test.py 0000664 0000000 0000000 00000061107 14551576331 0023033 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for the XML-format help generated by the flags.py module."""
import enum
import io
import os
import string
import sys
import xml.dom.minidom
import xml.sax.saxutils
from absl import flags
from absl.flags import _helpers
from absl.flags.tests import module_bar
from absl.testing import absltest
class CreateXMLDOMElement(absltest.TestCase):
  """Tests for _helpers.create_xml_dom_element on assorted values."""

  def _check(self, name, value, expected_output):
    """Builds an element for `value` and compares its UTF-8 pretty XML."""
    element = _helpers.create_xml_dom_element(
        xml.dom.minidom.Document(), name, value)
    self.assertEqual(expected_output,
                     element.toprettyxml(' ', encoding='utf-8'))

  def test_create_xml_dom_element(self):
    cases = (
        ('', b'\n'),
        ('plain text', b'plain text\n'),
        ('(x < y) && (a >= b)', b'(x < y) && (a >= b)\n'),
        # Bytes that are not valid unicode: in Python 3 the value's string
        # representation "b'\x81\xff'" is what ends up in the element.
        (b'\x81\xff', b"b'\\x81\\xff'\n"),
        # Some unicode chars are illegal in xml
        # (http://www.w3.org/TR/REC-xml/#charsets) and do not appear:
        (u'\x0b\x02\x08\ufffe', b'\n'),
        # Valid unicode is encoded as UTF-8.
        (u'\xff', b'\xc3\xbf\n'),
    )
    for value, expected in cases:
      self._check('tag', value, expected)
def _list_separators_in_xmlformat(separators, indent=''):
"""Generates XML encoding of a list of list separators.
Args:
separators: A list of list separators. Usually, this should be a
string whose characters are the valid list separators, e.g., ','
means that both comma (',') and space (' ') are valid list
separators.
indent: A string that is added at the beginning of each generated
XML element.
Returns:
A string.
"""
result = ''
separators = list(separators)
separators.sort()
for sep_char in separators:
result += ('%s%s\n' %
(indent, repr(sep_char)))
return result
class FlagCreateXMLDOMElement(absltest.TestCase):
"""Test Flag._create_xml_dom_element for a single flag at a time.
There is one test* method for each kind of DEFINE_* declaration.
"""
def setUp(self):
  """Creates a fresh, private FlagValues registry for each test."""
  super(FlagCreateXMLDOMElement, self).setUp()
  # self.fv is a FlagValues object, just like flags.FLAGS. Each
  # test registers one flag with this FlagValues.
  self.fv = flags.FlagValues()
def _check_flag_help_in_xml(self, flag_name, module_name,
                            expected_output, is_key=False):
  """Asserts the pretty-printed XML for flag `flag_name` equals expected_output.

  Args:
    flag_name: Name of a flag registered in self.fv.
    module_name: Module name passed to the XML builder.
    expected_output: Exact expected pretty-printed XML text.
    is_key: Forwarded to the XML builder as the key-flag marker.
  """
  document = xml.dom.minidom.Document()
  flag_element = self.fv[flag_name]._create_xml_dom_element(
      document, module_name, is_key=is_key)
  self.assertMultiLineEqual(expected_output,
                            flag_element.toprettyxml(indent=' '))
def test_flag_help_in_xml_int(self):
"""DEFINE_integer: the XML reports both the default and the current value."""
flags.DEFINE_integer('index', 17, 'An integer flag', flag_values=self.fv)
# '%d' is filled with the flag's current value for each check below.
expected_output_pattern = (
'\n'
' module.name\n'
' index\n'
' An integer flag\n'
' 17\n'
' %d\n'
' int\n'
'\n')
self._check_flag_help_in_xml('index', 'module.name',
expected_output_pattern % 17)
# Check that the output is correct even when the current value of
# a flag is different from the default one.
self.fv['index'].value = 20
self._check_flag_help_in_xml('index', 'module.name',
expected_output_pattern % 20)
def test_flag_help_in_xml_int_with_bounds(self):
"""DEFINE_integer with bounds, rendered as a key flag (is_key=True)."""
flags.DEFINE_integer('nb_iters', 17, 'An integer flag',
lower_bound=5, upper_bound=27,
flag_values=self.fv)
expected_output = (
'\n'
' yes\n'
' module.name\n'
' nb_iters\n'
' An integer flag\n'
' 17\n'
' 17\n'
' int\n'
' 5\n'
' 27\n'
'\n')
self._check_flag_help_in_xml('nb_iters', 'module.name', expected_output,
is_key=True)
def test_flag_help_in_xml_string(self):
"""DEFINE_string: default and current value both appear in the XML."""
flags.DEFINE_string('file_path', '/path/to/my/dir', 'A test string flag.',
flag_values=self.fv)
expected_output = (
'\n'
' simple_module\n'
' file_path\n'
' A test string flag.\n'
' /path/to/my/dir\n'
' /path/to/my/dir\n'
' string\n'
'\n')
self._check_flag_help_in_xml('file_path', 'simple_module', expected_output)
def test_flag_help_in_xml_string_with_xmlillegal_chars(self):
"""XML-illegal characters in a string default are dropped from the output."""
flags.DEFINE_string('file_path', '/path/to/\x08my/dir',
'A test string flag.', flag_values=self.fv)
# '\x08' is not a legal character in XML 1.0 documents. Our
# current code purges such characters from the generated XML.
expected_output = (
'\n'
' simple_module\n'
' file_path\n'
' A test string flag.\n'
' /path/to/my/dir\n'
' /path/to/my/dir\n'
' string\n'
'\n')
self._check_flag_help_in_xml('file_path', 'simple_module', expected_output)
def test_flag_help_in_xml_boolean(self):
"""DEFINE_boolean rendered as a key flag; values print as 'false'."""
flags.DEFINE_boolean('use_gpu', False, 'Use gpu for performance.',
flag_values=self.fv)
expected_output = (
'\n'
' yes\n'
' a_module\n'
' use_gpu\n'
' Use gpu for performance.\n'
' false\n'
' false\n'
' bool\n'
'\n')
self._check_flag_help_in_xml('use_gpu', 'a_module', expected_output,
is_key=True)
def test_flag_help_in_xml_enum(self):
"""DEFINE_enum: the help text is prefixed with the allowed values."""
flags.DEFINE_enum('cc_version', 'stable', ['stable', 'experimental'],
'Compiler version to use.', flag_values=self.fv)
expected_output = (
'\n'
' tool\n'
' cc_version\n'
' <stable|experimental>: '
'Compiler version to use.\n'
' stable\n'
' stable\n'
' string enum\n'
' stable\n'
' experimental\n'
'\n')
self._check_flag_help_in_xml('cc_version', 'tool', expected_output)
def test_flag_help_in_xml_enum_class(self):
"""DEFINE_enum_class: help lists lowercased names; value is the member repr."""
class Version(enum.Enum):
STABLE = 0
EXPERIMENTAL = 1
flags.DEFINE_enum_class('cc_version', 'STABLE', Version,
'Compiler version to use.', flag_values=self.fv)
expected_output = ('\n'
' tool\n'
' cc_version\n'
' <stable|experimental>: '
'Compiler version to use.\n'
' stable\n'
' Version.STABLE\n'
' enum class\n'
' STABLE\n'
' EXPERIMENTAL\n'
'\n')
self._check_flag_help_in_xml('cc_version', 'tool', expected_output)
def test_flag_help_in_xml_comma_separated_list(self):
"""DEFINE_list with a string default; the parsed value is a list."""
flags.DEFINE_list('files', 'a.cc,a.h,archive/old.zip',
'Files to process.', flag_values=self.fv)
expected_output = (
'\n'
' tool\n'
' files\n'
' Files to process.\n'
' a.cc,a.h,archive/old.zip\n'
' [\'a.cc\', \'a.h\', \'archive/old.zip\']\n'
' comma separated list of strings\n'
' \',\'\n'
'\n')
self._check_flag_help_in_xml('files', 'tool', expected_output)
def test_list_as_default_argument_comma_separated_list(self):
"""DEFINE_list with a list default; the default is rendered comma-joined."""
flags.DEFINE_list('allow_users', ['alice', 'bob'],
'Users with access.', flag_values=self.fv)
expected_output = (
'\n'
' tool\n'
' allow_users\n'
' Users with access.\n'
' alice,bob\n'
' [\'alice\', \'bob\']\n'
' comma separated list of strings\n'
' \',\'\n'
'\n')
self._check_flag_help_in_xml('allow_users', 'tool', expected_output)
def test_none_as_default_arguments_comma_separated_list(self):
"""DEFINE_list with a None default: empty default text, current value None."""
flags.DEFINE_list('allow_users', None,
'Users with access.', flag_values=self.fv)
expected_output = (
'\n'
' tool\n'
' allow_users\n'
' Users with access.\n'
' \n'
' None\n'
' comma separated list of strings\n'
' \',\'\n'
'\n')
self._check_flag_help_in_xml('allow_users', 'tool', expected_output)
def test_flag_help_in_xml_space_separated_list(self):
"""DEFINE_spaceseplist: every string.whitespace char is listed as a separator."""
flags.DEFINE_spaceseplist('dirs', 'src libs bin',
'Directories to search.', flag_values=self.fv)
expected_separators = sorted(string.whitespace)
# LIST_SEPARATORS is a placeholder substituted with the rendered separators.
expected_output = (
'\n'
' tool\n'
' dirs\n'
' Directories to search.\n'
' src libs bin\n'
' [\'src\', \'libs\', \'bin\']\n'
' whitespace separated list of strings\n'
'LIST_SEPARATORS'
'\n').replace('LIST_SEPARATORS',
_list_separators_in_xmlformat(expected_separators,
indent=' '))
self._check_flag_help_in_xml('dirs', 'tool', expected_output)
def test_flag_help_in_xml_space_separated_list_with_comma_compat(self):
"""DEFINE_spaceseplist with comma_compat=True also lists ',' as a separator."""
flags.DEFINE_spaceseplist('dirs', 'src libs,bin',
'Directories to search.', comma_compat=True,
flag_values=self.fv)
expected_separators = sorted(string.whitespace + ',')
# LIST_SEPARATORS is a placeholder substituted with the rendered separators.
expected_output = (
'\n'
' tool\n'
' dirs\n'
' Directories to search.\n'
' src libs bin\n'
' [\'src\', \'libs\', \'bin\']\n'
' whitespace or comma separated list of strings\n'
'LIST_SEPARATORS'
'\n').replace('LIST_SEPARATORS',
_list_separators_in_xmlformat(expected_separators,
indent=' '))
self._check_flag_help_in_xml('dirs', 'tool', expected_output)
def test_flag_help_in_xml_multi_string(self):
"""DEFINE_multi_string: help text gains the 'repeat this option' suffix."""
flags.DEFINE_multi_string('to_delete', ['a.cc', 'b.h'],
'Files to delete', flag_values=self.fv)
expected_output = (
'\n'
' tool\n'
' to_delete\n'
' Files to delete;\n'
' repeat this option to specify a list of values\n'
' [\'a.cc\', \'b.h\']\n'
' [\'a.cc\', \'b.h\']\n'
' multi string\n'
'\n')
self._check_flag_help_in_xml('to_delete', 'tool', expected_output)
def test_flag_help_in_xml_multi_int(self):
"""DEFINE_multi_integer: list default rendered as a Python list repr."""
flags.DEFINE_multi_integer('cols', [5, 7, 23],
'Columns to select', flag_values=self.fv)
expected_output = (
'\n'
' tool\n'
' cols\n'
' Columns to select;\n '
'repeat this option to specify a list of values\n'
' [5, 7, 23]\n'
' [5, 7, 23]\n'
' multi int\n'
'\n')
self._check_flag_help_in_xml('cols', 'tool', expected_output)
def test_flag_help_in_xml_multi_enum(self):
"""DEFINE_multi_enum: allowed values are listed after the flag type."""
flags.DEFINE_multi_enum('flavours', ['APPLE', 'BANANA'],
['APPLE', 'BANANA', 'CHERRY'],
'Compilation flavour.', flag_values=self.fv)
expected_output = (
'\n'
' tool\n'
' flavours\n'
' <APPLE|BANANA|CHERRY>: Compilation flavour.;\n'
' repeat this option to specify a list of values\n'
' [\'APPLE\', \'BANANA\']\n'
' [\'APPLE\', \'BANANA\']\n'
' multi string enum\n'
' APPLE\n'
' BANANA\n'
' CHERRY\n'
'\n')
self._check_flag_help_in_xml('flavours', 'tool', expected_output)
def test_flag_help_in_xml_multi_enum_class_singleton_default(self):
"""DEFINE_multi_enum_class with a one-element default renders it lowercased."""
class Fruit(enum.Enum):
ORANGE = 0
BANANA = 1
flags.DEFINE_multi_enum_class('fruit', ['ORANGE'],
Fruit,
'The fruit flag.', flag_values=self.fv)
expected_output = (
'\n'
' tool\n'
' fruit\n'
' <orange|banana>: The fruit flag.;\n'
' repeat this option to specify a list of values\n'
' orange\n'
' orange\n'
' multi enum class\n'
' ORANGE\n'
' BANANA\n'
'\n')
self._check_flag_help_in_xml('fruit', 'tool', expected_output)
def test_flag_help_in_xml_multi_enum_class_list_default(self):
"""DEFINE_multi_enum_class with a multi-element default renders comma-joined."""
class Fruit(enum.Enum):
ORANGE = 0
BANANA = 1
flags.DEFINE_multi_enum_class('fruit', ['ORANGE', 'BANANA'],
Fruit,
'The fruit flag.', flag_values=self.fv)
expected_output = (
'\n'
' tool\n'
' fruit\n'
' <orange|banana>: The fruit flag.;\n'
' repeat this option to specify a list of values\n'
' orange,banana\n'
' orange,banana\n'
' multi enum class\n'
' ORANGE\n'
' BANANA\n'
'\n')
self._check_flag_help_in_xml('fruit', 'tool', expected_output)
# The next EXPECTED_HELP_XML_* constants are parts of a template for
# the expected XML output from WriteHelpInXMLFormatTest below. When
# we assemble these parts into a single big string, we take into
# account the ordering between the name of the main module and the
# name of module_bar. Next, we fill in the basename of argv[0]
# (%(basename_of_argv0)s), the docstring of the main module
# (%(usage_doc)s), the name of the main module (%(main_module_name)s),
# the name of the module module_bar (%(module_bar_name)s) and the rendered
# whitespace list separators (%(whitespace_separators)s). See
# WriteHelpInXMLFormatTest below.
EXPECTED_HELP_XML_START = """\
%(basename_of_argv0)s%(usage_doc)s
"""
# Expected entries for the flags defined directly by the test (all key flags).
EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE = """\
yes%(main_module_name)sallow_usersUsers with access.alice,bob['alice', 'bob']comma separated list of strings','yes%(main_module_name)scc_version<stable|experimental>: Compiler version to use.stablestablestring enumstableexperimentalyes%(main_module_name)scolsColumns to select;
repeat this option to specify a list of values[5, 7, 23][5, 7, 23]multi intyes%(main_module_name)sdirsDirectories to create.src libs bins['src', 'libs', 'bins']whitespace separated list of strings
%(whitespace_separators)s yes%(main_module_name)sfile_pathA test string flag./path/to/my/dir/path/to/my/dirstringyes%(main_module_name)sfilesFiles to process.a.cc,a.h,archive/old.zip['a.cc', 'a.h', 'archive/old.zip']comma separated list of strings\',\'yes%(main_module_name)sflavours<APPLE|BANANA|CHERRY>: Compilation flavour.;
repeat this option to specify a list of values['APPLE', 'BANANA']['APPLE', 'BANANA']multi string enumAPPLEBANANACHERRYyes%(main_module_name)sindexAn integer flag1717intyes%(main_module_name)snb_itersAn integer flag1717int527yes%(main_module_name)sto_deleteFiles to delete;
repeat this option to specify a list of values['a.cc', 'b.h']['a.cc', 'b.h']multi stringyes%(main_module_name)suse_gpuUse gpu for performance.falsefalsebool
"""
# Expected entries for module_bar's flags; only tmod_bar_u and tmod_bar_z
# carry the key marker, matching the declare_key_flag calls in the test.
EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR = """\
%(module_bar_name)stmod_bar_tSample int flag.44intyes%(module_bar_name)stmod_bar_uSample int flag.55int%(module_bar_name)stmod_bar_vSample int flag.66int%(module_bar_name)stmod_bar_xBoolean flag.truetruebool%(module_bar_name)stmod_bar_yString flag.defaultdefaultstringyes%(module_bar_name)stmod_bar_zAnother boolean flag from module bar.falsefalsebool
"""
EXPECTED_HELP_XML_END = """\
"""
class WriteHelpInXMLFormatTest(absltest.TestCase):
"""Big test of FlagValues.write_help_in_xml_format, with several flags."""
def test_write_help_in_xmlformat(self):
"""Compares the full XML help for a mixed registry against the template."""
fv = flags.FlagValues()
# Since these flags are defined by the top module, they are all key.
flags.DEFINE_integer('index', 17, 'An integer flag', flag_values=fv)
flags.DEFINE_integer('nb_iters', 17, 'An integer flag',
lower_bound=5, upper_bound=27, flag_values=fv)
flags.DEFINE_string('file_path', '/path/to/my/dir', 'A test string flag.',
flag_values=fv)
flags.DEFINE_boolean('use_gpu', False, 'Use gpu for performance.',
flag_values=fv)
flags.DEFINE_enum('cc_version', 'stable', ['stable', 'experimental'],
'Compiler version to use.', flag_values=fv)
flags.DEFINE_list('files', 'a.cc,a.h,archive/old.zip',
'Files to process.', flag_values=fv)
flags.DEFINE_list('allow_users', ['alice', 'bob'],
'Users with access.', flag_values=fv)
flags.DEFINE_spaceseplist('dirs', 'src libs bins',
'Directories to create.', flag_values=fv)
flags.DEFINE_multi_string('to_delete', ['a.cc', 'b.h'],
'Files to delete', flag_values=fv)
flags.DEFINE_multi_integer('cols', [5, 7, 23],
'Columns to select', flag_values=fv)
flags.DEFINE_multi_enum('flavours', ['APPLE', 'BANANA'],
['APPLE', 'BANANA', 'CHERRY'],
'Compilation flavour.', flag_values=fv)
# Define a few flags in a different module.
module_bar.define_flags(flag_values=fv)
# And declare only a few of them to be key. This way, we have
# different kinds of flags, defined in different modules, and not
# all of them are key flags.
flags.declare_key_flag('tmod_bar_z', flag_values=fv)
flags.declare_key_flag('tmod_bar_u', flag_values=fv)
# Generate flag help in XML format in the StringIO sio.
sio = io.StringIO()
fv.write_help_in_xml_format(sio)
# Check that we got the expected result.
expected_output_template = EXPECTED_HELP_XML_START
main_module_name = sys.argv[0]
module_bar_name = module_bar.__name__
# The per-module sections are concatenated in module-name order; this
# presumably mirrors how write_help_in_xml_format orders modules.
if main_module_name < module_bar_name:
expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE
expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR
else:
expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MODULE_BAR
expected_output_template += EXPECTED_HELP_XML_FOR_FLAGS_FROM_MAIN_MODULE
expected_output_template += EXPECTED_HELP_XML_END
# XML representation of the whitespace list separators.
whitespace_separators = _list_separators_in_xmlformat(string.whitespace,
indent=' ')
expected_output = (
expected_output_template %
{'basename_of_argv0': os.path.basename(sys.argv[0]),
'usage_doc': sys.modules['__main__'].__doc__,
'main_module_name': main_module_name,
'module_bar_name': module_bar_name,
'whitespace_separators': whitespace_separators})
actual_output = sio.getvalue()
self.assertMultiLineEqual(expected_output, actual_output)
# Also check that our result is valid XML. minidom.parseString
# throws an xml.parsers.expat.ExpatError in case of an error.
xml.dom.minidom.parseString(actual_output)
# Entry point: run this module's tests with the absltest runner.
if __name__ == '__main__':
absltest.main()
abseil-py-2.1.0/absl/flags/tests/flags_numeric_bounds_test.py 0000664 0000000 0000000 00000007644 14551576331 0024404 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for lower/upper bounds validators for numeric flags."""
from unittest import mock
from absl import flags
from absl.flags import _validators
from absl.testing import absltest
class NumericFlagBoundsTest(absltest.TestCase):
  """Tests for the bounds validators attached by numeric DEFINE_* helpers."""

  def setUp(self):
    super(NumericFlagBoundsTest, self).setUp()
    self.flag_values = flags.FlagValues()

  def test_no_validator_if_no_bounds(self):
    """Validator is registered only when a lower or upper bound is given."""
    with mock.patch.object(_validators, 'register_validator'
                          ) as register_validator:
      flags.DEFINE_integer('positive_flag', None, 'positive int',
                           lower_bound=0, flag_values=self.flag_values)
      register_validator.assert_called_once_with(
          'positive_flag', mock.ANY, flag_values=self.flag_values)

    with mock.patch.object(_validators, 'register_validator'
                          ) as register_validator:
      flags.DEFINE_integer('int_flag', None, 'just int',
                           flag_values=self.flag_values)
      register_validator.assert_not_called()

  def test_success(self):
    flags.DEFINE_integer('int_flag', 5, 'Just integer',
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=13')
    self.flag_values(argv)
    self.assertEqual(13, self.flag_values.int_flag)
    self.flag_values.int_flag = 25
    self.assertEqual(25, self.flag_values.int_flag)

  def test_success_if_none(self):
    flags.DEFINE_integer('int_flag', None, '',
                         lower_bound=0, upper_bound=5,
                         flag_values=self.flag_values)
    argv = ('./program',)
    self.flag_values(argv)
    self.assertIsNone(self.flag_values.int_flag)

  def test_success_if_exactly_equals(self):
    flags.DEFINE_float('float_flag', None, '',
                       lower_bound=1, upper_bound=1,
                       flag_values=self.flag_values)
    argv = ('./program', '--float_flag=1')
    self.flag_values(argv)
    self.assertEqual(1, self.flag_values.float_flag)

  def test_exception_if_smaller(self):
    flags.DEFINE_integer('int_flag', None, '',
                         lower_bound=0, upper_bound=5,
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=-1')
    # assertRaises makes the test fail if parsing does not raise; a bare
    # try/except would let that case pass silently.
    with self.assertRaises(flags.IllegalFlagValueError) as cm:
      self.flag_values(argv)
    text = 'flag --int_flag=-1: -1 is not an integer in the range [0, 5]'
    self.assertEqual(text, str(cm.exception))
class SettingFlagAfterStartTest(absltest.TestCase):
  """Tests assigning numeric flag values after command-line parsing."""

  def setUp(self):
    super(SettingFlagAfterStartTest, self).setUp()
    self.flag_values = flags.FlagValues()

  def test_success(self):
    flags.DEFINE_integer('int_flag', None, 'Just integer',
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=13')
    self.flag_values(argv)
    self.assertEqual(13, self.flag_values.int_flag)
    self.flag_values.int_flag = 25
    self.assertEqual(25, self.flag_values.int_flag)

  def test_exception_if_setting_integer_flag_outside_bounds(self):
    flags.DEFINE_integer('int_flag', None, 'Just integer', lower_bound=0,
                         flag_values=self.flag_values)
    argv = ('./program', '--int_flag=13')
    self.flag_values(argv)
    self.assertEqual(13, self.flag_values.int_flag)
    # Assigning through attribute access after parsing is also validated.
    with self.assertRaises(flags.IllegalFlagValueError):
      self.flag_values.int_flag = -2
# Entry point: run this module's tests with the absltest runner.
if __name__ == '__main__':
absltest.main()
abseil-py-2.1.0/absl/flags/tests/flags_test.py 0000664 0000000 0000000 00000345554 14551576331 0021315 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for absl.flags used as a package."""
import contextlib
import enum
import io
import os
import shutil
import sys
import tempfile
import unittest
from absl import flags
from absl.flags import _exceptions
from absl.flags import _helpers
from absl.flags.tests import module_bar
from absl.flags.tests import module_baz
from absl.flags.tests import module_foo
from absl.testing import absltest
# Module-level alias for the global flag registry.
FLAGS = flags.FLAGS
@contextlib.contextmanager
def _use_gnu_getopt(flag_values, use_gnu_get_opt):
old_use_gnu_get_opt = flag_values.is_gnu_getopt()
flag_values.set_gnu_getopt(use_gnu_get_opt)
yield
flag_values.set_gnu_getopt(old_use_gnu_get_opt)
class FlagDictToArgsTest(absltest.TestCase):
  """Tests for flags.flag_dict_to_args."""

  def test_flatten_google_flag_map(self):
    flag_map = {
        'week-end': None,
        'estudia': False,
        'trabaja': False,
        'party': True,
        'monday': 'party',
        'score': 42,
        'loadthatstuff': [42, 'hello', 'goodbye'],
    }
    expected_args = (
        '--week-end',
        '--noestudia',
        '--notrabaja',
        '--party',
        '--monday=party',
        '--score=42',
        '--loadthatstuff=42,hello,goodbye',
    )
    self.assertSameElements(expected_args, flags.flag_dict_to_args(flag_map))

  def test_flatten_google_flag_map_with_multi_flag(self):
    flag_map = {
        'some_list': ['value1', 'value2'],
        'some_multi_string': ['value3', 'value4'],
    }
    # Multi flags expand into one argument per value; plain lists are joined.
    expected_args = (
        '--some_list=value1,value2',
        '--some_multi_string=value3',
        '--some_multi_string=value4',
    )
    self.assertSameElements(
        expected_args,
        flags.flag_dict_to_args(flag_map, multi_flags={'some_multi_string'}))
class Fruit(enum.Enum):
"""Test enum whose member values are distinct, opaque object() instances."""
APPLE = object()
ORANGE = object()
class CaseSensitiveFruit(enum.Enum):
"""Test enum whose member names differ only by case ('apple' vs 'APPLE')."""
apple = 1
orange = 2
APPLE = 3
class EmptyEnum(enum.Enum):
"""Test enum with no members."""
pass
class AliasFlagsTest(absltest.TestCase):
def setUp(self):
  """Gives every test its own empty FlagValues registry."""
  super().setUp()
  self.flags = flags.FlagValues()
@property
def alias(self):
  """The flag object registered under the name 'alias'."""
  registry = self.flags
  return registry['alias']
@property
def aliased(self):
  """The flag object registered under the name 'aliased'."""
  registry = self.flags
  return registry['aliased']
def define_alias(self, *args, **kwargs):
  """Calls flags.DEFINE_alias against this test's private registry."""
  target_registry = self.flags
  flags.DEFINE_alias(*args, flag_values=target_registry, **kwargs)
def define_integer(self, *args, **kwargs):
  """Calls flags.DEFINE_integer against this test's private registry."""
  target_registry = self.flags
  flags.DEFINE_integer(*args, flag_values=target_registry, **kwargs)
def define_multi_integer(self, *args, **kwargs):
  """Calls flags.DEFINE_multi_integer against this test's private registry."""
  target_registry = self.flags
  flags.DEFINE_multi_integer(*args, flag_values=target_registry, **kwargs)
def define_string(self, *args, **kwargs):
  """Calls flags.DEFINE_string against this test's private registry."""
  target_registry = self.flags
  flags.DEFINE_string(*args, flag_values=target_registry, **kwargs)
def assert_alias_mirrors_aliased(self, alias, aliased, ignore_due_to_bug=()):
  """Asserts that `alias` reports the same state as the flag it aliases.

  Args:
    alias: The alias flag object.
    aliased: The flag object that `alias` points to.
    ignore_due_to_bug: Attribute names to leave out of the comparison.
  """
  # A few sanity checks to avoid false success
  self.assertIn('FlagAlias', alias.__class__.__qualname__)
  self.assertIsNot(alias, aliased)
  self.assertNotEqual(aliased.name, alias.name)

  compared_attrs = {
      'allow_hide_cpp',
      'allow_override',
      'allow_override_cpp',
      'allow_overwrite',
      'allow_using_method_names',
      'boolean',
      'default',
      'default_as_str',
      'default_unparsed',
      # TODO(rlevasseur): This should match, but a bug prevents it from being
      # in sync.
      # 'using_default_value',
      'value',
  }
  compared_attrs.difference_update(ignore_due_to_bug)

  alias_state, aliased_state = {}, {}
  for attr_name in compared_attrs:
    alias_state[attr_name], aliased_state[attr_name] = (
        getattr(alias, attr_name), getattr(aliased, attr_name))
  self.assertEqual(aliased_state, alias_state, 'LHS is aliased; RHS is alias')
def test_serialize_multi(self):
  """Serializing an alias of a multi flag (currently yields the list repr)."""
  self.define_multi_integer('aliased', [0, 1], '')
  self.define_alias('alias', 'aliased')
  serialized = self.alias.serialize()
  # TODO(rlevasseur): This should check for --alias=0\n--alias=1, but
  # a bug causes it to serialize incorrectly.
  self.assertEqual('--alias=[0, 1]', serialized)
def test_allow_overwrite_false(self):
  """With allow_overwrite=False, setting via alias then aliased is rejected."""
  self.define_integer('aliased', None, 'help', allow_overwrite=False)
  self.define_alias('alias', 'aliased')
  argv = ['./program', '--alias=1', '--aliased=2']
  with self.assertRaisesRegex(flags.IllegalFlagValueError, 'already defined'):
    self.flags(argv)
  # The first assignment (made through the alias) is the one that sticks.
  for flag in (self.alias, self.aliased):
    self.assertEqual(1, flag.value)
def test_aliasing_multi_no_default(self):
"""Alias of a multi-integer flag with no default, set via either name."""
# Rebuilds a fresh registry with 'aliased' and its alias; each subTest
# calls it so the cases do not share parsed state.
def define_flags():
self.flags = flags.FlagValues()
self.define_multi_integer('aliased', None, 'help')
self.define_alias('alias', 'aliased')
with self.subTest('after defining'):
define_flags()
self.assert_alias_mirrors_aliased(self.alias, self.aliased)
self.assertIsNone(self.alias.value)
with self.subTest('set alias'):
define_flags()
self.flags(['./program', '--alias=1', '--alias=2'])
self.assertEqual([1, 2], self.alias.value)
self.assert_alias_mirrors_aliased(self.alias, self.aliased)
with self.subTest('set aliased'):
define_flags()
self.flags(['./program', '--aliased=1', '--aliased=2'])
self.assertEqual([1, 2], self.alias.value)
self.assert_alias_mirrors_aliased(self.alias, self.aliased)
with self.subTest('not setting anything'):
define_flags()
self.flags(['./program'])
self.assertEqual(None, self.alias.value)
self.assert_alias_mirrors_aliased(self.alias, self.aliased)
def test_aliasing_multi_with_default(self):
"""Alias of a multi-integer flag with default [0]; also checks .present."""
# Rebuilds a fresh registry with 'aliased' and its alias; each subTest
# calls it so the cases do not share parsed state.
def define_flags():
self.flags = flags.FlagValues()
self.define_multi_integer('aliased', [0], 'help')
self.define_alias('alias', 'aliased')
with self.subTest('after defining'):
define_flags()
self.assertEqual([0], self.alias.default)
self.assert_alias_mirrors_aliased(self.alias, self.aliased)
with self.subTest('set alias'):
define_flags()
self.flags(['./program', '--alias=1', '--alias=2'])
self.assertEqual([1, 2], self.alias.value)
self.assert_alias_mirrors_aliased(self.alias, self.aliased)
self.assertEqual(2, self.alias.present)
# TODO(rlevasseur): This should assert 0, but a bug with aliases and
# MultiFlag causes the alias to increment aliased's present counter.
self.assertEqual(2, self.aliased.present)
with self.subTest('set aliased'):
define_flags()
self.flags(['./program', '--aliased=1', '--aliased=2'])
self.assertEqual([1, 2], self.alias.value)
self.assert_alias_mirrors_aliased(self.alias, self.aliased)
self.assertEqual(0, self.alias.present)
# TODO(rlevasseur): This should assert 0, but a bug with aliases and
# MultiFlag causes the alias to increment aliased present counter.
self.assertEqual(2, self.aliased.present)
with self.subTest('not setting anything'):
define_flags()
self.flags(['./program'])
self.assertEqual([0], self.alias.value)
self.assert_alias_mirrors_aliased(self.alias, self.aliased)
self.assertEqual(0, self.alias.present)
self.assertEqual(0, self.aliased.present)
def test_aliasing_regular(self):
  """An alias of a single-valued string flag mirrors the aliased flag.

  Structured as subtests, like the multi-flag aliasing tests. Each subtest
  redefines both flags on a fresh FlagValues, then sets one of the two
  names and checks values, `present` counters and serialization.
  """

  def define_flags():
    # A fresh FlagValues per subtest lets the same flag names be defined
    # again and starts each subtest from an unparsed state.
    self.flags = flags.FlagValues()
    self.define_string('aliased', '', 'help')
    self.define_alias('alias', 'aliased')

  with self.subTest('after defining'):
    define_flags()
    self.assert_alias_mirrors_aliased(self.alias, self.aliased)

  with self.subTest('set alias'):
    define_flags()
    self.flags(['./program', '--alias=1'])
    self.assertEqual('1', self.alias.value)
    self.assert_alias_mirrors_aliased(self.alias, self.aliased)
    self.assertEqual(1, self.alias.present)
    self.assertEqual('--alias=1', self.alias.serialize())
    # Parsing through the alias also marks the aliased flag as present.
    self.assertEqual(1, self.aliased.present)

  with self.subTest('set aliased'):
    define_flags()
    self.flags(['./program', '--aliased=2'])
    self.assertEqual('2', self.alias.value)
    self.assert_alias_mirrors_aliased(self.alias, self.aliased)
    self.assertEqual(0, self.alias.present)
    # The alias serializes the shared value even though it was not present.
    self.assertEqual('--alias=2', self.alias.serialize())
    self.assertEqual(1, self.aliased.present)
def test_defining_alias_doesnt_affect_aliased_state_regular(self):
self.define_string('aliased', 'default', 'help')
self.define_alias('alias', 'aliased')
self.assertEqual(0, self.aliased.present)
self.assertEqual(0, self.alias.present)
def test_defining_alias_doesnt_affect_aliased_state_multi(self):
self.define_multi_integer('aliased', [0], 'help')
self.define_alias('alias', 'aliased')
self.assertEqual([0], self.aliased.value)
self.assertEqual([0], self.aliased.default)
self.assertEqual(0, self.aliased.present)
self.assertEqual([0], self.aliased.value)
self.assertEqual([0], self.aliased.default)
self.assertEqual(0, self.alias.present)
class FlagsUnitTest(absltest.TestCase):
"""Flags Unit Test."""
maxDiff = None
def test_flags(self):
"""Test normal usage with no (expected) errors."""
# Define flags
number_test_framework_flags = len(FLAGS)
repeat_help = 'how many times to repeat (0-5)'
flags.DEFINE_integer(
'repeat', 4, repeat_help, lower_bound=0, short_name='r')
flags.DEFINE_string('name', 'Bob', 'namehelp')
flags.DEFINE_boolean('debug', 0, 'debughelp')
flags.DEFINE_boolean('q', 1, 'quiet mode')
flags.DEFINE_boolean('quack', 0, "superstring of 'q'")
flags.DEFINE_boolean('noexec', 1, 'boolean flag with no as prefix')
flags.DEFINE_float('float', 3.14, 'using floats')
flags.DEFINE_integer('octal', '0o666', 'using octals')
flags.DEFINE_integer('decimal', '666', 'using decimals')
flags.DEFINE_integer('hexadecimal', '0x666', 'using hexadecimals')
flags.DEFINE_integer('x', 3, 'how eXtreme to be')
flags.DEFINE_integer('l', 0x7fffffff00000000, 'how long to be')
flags.DEFINE_list('args', 'v=1,"vmodule=a=0,b=2"', 'a list of arguments')
flags.DEFINE_list('letters', 'a,b,c', 'a list of letters')
flags.DEFINE_list(
'list_default_list',
['a', 'b', 'c'],
'with default being a list of strings',
)
flags.DEFINE_enum('kwery', None, ['who', 'what', 'Why', 'where', 'when'],
'?')
flags.DEFINE_enum(
'sense', None, ['Case', 'case', 'CASE'], '?', case_sensitive=True)
flags.DEFINE_enum(
'cases',
None, ['UPPER', 'lower', 'Initial', 'Ot_HeR'],
'?',
case_sensitive=False)
flags.DEFINE_enum(
'funny',
None, ['Joke', 'ha', 'ha', 'ha', 'ha'],
'?',
case_sensitive=True)
flags.DEFINE_enum(
'blah',
None, ['bla', 'Blah', 'BLAH', 'blah'],
'?',
case_sensitive=False)
flags.DEFINE_string(
'only_once', None, 'test only sets this once', allow_overwrite=False)
flags.DEFINE_string(
'universe',
None,
'test tries to set this three times',
allow_overwrite=False)
# Specify number of flags defined above. The short_name defined
# for 'repeat' counts as an extra flag.
number_defined_flags = 22 + 1
self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags)
self.assertEqual(FLAGS.repeat, 4)
self.assertEqual(FLAGS.name, 'Bob')
self.assertEqual(FLAGS.debug, 0)
self.assertEqual(FLAGS.q, 1)
self.assertEqual(FLAGS.octal, 0o666)
self.assertEqual(FLAGS.decimal, 666)
self.assertEqual(FLAGS.hexadecimal, 0x666)
self.assertEqual(FLAGS.x, 3)
self.assertEqual(FLAGS.l, 0x7fffffff00000000)
self.assertEqual(FLAGS.args, ['v=1', 'vmodule=a=0,b=2'])
self.assertEqual(FLAGS.letters, ['a', 'b', 'c'])
self.assertEqual(FLAGS.list_default_list, ['a', 'b', 'c'])
self.assertIsNone(FLAGS.kwery)
self.assertIsNone(FLAGS.sense)
self.assertIsNone(FLAGS.cases)
self.assertIsNone(FLAGS.funny)
self.assertIsNone(FLAGS.blah)
flag_values = FLAGS.flag_values_dict()
self.assertEqual(flag_values['repeat'], 4)
self.assertEqual(flag_values['name'], 'Bob')
self.assertEqual(flag_values['debug'], 0)
self.assertEqual(flag_values['r'], 4) # Short for repeat.
self.assertEqual(flag_values['q'], 1)
self.assertEqual(flag_values['quack'], 0)
self.assertEqual(flag_values['x'], 3)
self.assertEqual(flag_values['l'], 0x7fffffff00000000)
self.assertEqual(flag_values['args'], ['v=1', 'vmodule=a=0,b=2'])
self.assertEqual(flag_values['letters'], ['a', 'b', 'c'])
self.assertEqual(flag_values['list_default_list'], ['a', 'b', 'c'])
self.assertIsNone(flag_values['kwery'])
self.assertIsNone(flag_values['sense'])
self.assertIsNone(flag_values['cases'])
self.assertIsNone(flag_values['funny'])
self.assertIsNone(flag_values['blah'])
# Verify string form of defaults
self.assertEqual(FLAGS['repeat'].default_as_str, "'4'")
self.assertEqual(FLAGS['name'].default_as_str, "'Bob'")
self.assertEqual(FLAGS['debug'].default_as_str, "'false'")
self.assertEqual(FLAGS['q'].default_as_str, "'true'")
self.assertEqual(FLAGS['quack'].default_as_str, "'false'")
self.assertEqual(FLAGS['noexec'].default_as_str, "'true'")
self.assertEqual(FLAGS['x'].default_as_str, "'3'")
self.assertEqual(FLAGS['l'].default_as_str, "'9223372032559808512'")
self.assertEqual(FLAGS['args'].default_as_str, '\'v=1,"vmodule=a=0,b=2"\'')
self.assertEqual(FLAGS['letters'].default_as_str, "'a,b,c'")
self.assertEqual(FLAGS['list_default_list'].default_as_str, "'a,b,c'")
# Verify that the iterator for flags yields all the keys
keys = list(FLAGS)
keys.sort()
reg_flags = list(FLAGS._flags())
reg_flags.sort()
self.assertEqual(keys, reg_flags)
# Parse flags
# .. empty command line
argv = ('./program',)
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
# .. non-empty command line
argv = ('./program', '--debug', '--name=Bob', '-q', '--x=8')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['debug'].present, 1)
FLAGS['debug'].present = 0 # Reset
self.assertEqual(FLAGS['name'].present, 1)
FLAGS['name'].present = 0 # Reset
self.assertEqual(FLAGS['q'].present, 1)
FLAGS['q'].present = 0 # Reset
self.assertEqual(FLAGS['x'].present, 1)
FLAGS['x'].present = 0 # Reset
# Flags list.
self.assertLen(FLAGS, number_defined_flags + number_test_framework_flags)
self.assertIn('name', FLAGS)
self.assertIn('debug', FLAGS)
self.assertIn('repeat', FLAGS)
self.assertIn('r', FLAGS)
self.assertIn('q', FLAGS)
self.assertIn('quack', FLAGS)
self.assertIn('x', FLAGS)
self.assertIn('l', FLAGS)
self.assertIn('args', FLAGS)
self.assertIn('letters', FLAGS)
self.assertIn('list_default_list', FLAGS)
# __contains__
self.assertIn('name', FLAGS)
self.assertNotIn('name2', FLAGS)
# try deleting a flag
del FLAGS.r
self.assertLen(FLAGS,
number_defined_flags - 1 + number_test_framework_flags)
self.assertNotIn('r', FLAGS)
# .. command line with extra stuff
argv = ('./program', '--debug', '--name=Bob', 'extra')
argv = FLAGS(argv)
self.assertLen(argv, 2, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
self.assertEqual(FLAGS['debug'].present, 1)
FLAGS['debug'].present = 0 # Reset
self.assertEqual(FLAGS['name'].present, 1)
FLAGS['name'].present = 0 # Reset
# Test reset
argv = ('./program', '--debug')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['debug'].present, 1)
self.assertTrue(FLAGS['debug'].value)
FLAGS.unparse_flags()
self.assertEqual(FLAGS['debug'].present, 0)
self.assertFalse(FLAGS['debug'].value)
# Test that reset restores default value when default value is None.
argv = ('./program', '--kwery=who')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['kwery'].present, 1)
self.assertEqual(FLAGS['kwery'].value, 'who')
FLAGS.unparse_flags()
argv = ('./program', '--kwery=Why')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['kwery'].present, 1)
self.assertEqual(FLAGS['kwery'].value, 'Why')
FLAGS.unparse_flags()
self.assertEqual(FLAGS['kwery'].present, 0)
self.assertIsNone(FLAGS['kwery'].value)
# Test case sensitive enum.
argv = ('./program', '--sense=CASE')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['sense'].present, 1)
self.assertEqual(FLAGS['sense'].value, 'CASE')
FLAGS.unparse_flags()
argv = ('./program', '--sense=Case')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['sense'].present, 1)
self.assertEqual(FLAGS['sense'].value, 'Case')
FLAGS.unparse_flags()
# Test case insensitive enum.
argv = ('./program', '--cases=upper')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['cases'].present, 1)
self.assertEqual(FLAGS['cases'].value, 'UPPER')
FLAGS.unparse_flags()
# Test case sensitive enum with duplicates.
argv = ('./program', '--funny=ha')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['funny'].present, 1)
self.assertEqual(FLAGS['funny'].value, 'ha')
FLAGS.unparse_flags()
# Test case insensitive enum with duplicates.
argv = ('./program', '--blah=bLah')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['blah'].present, 1)
self.assertEqual(FLAGS['blah'].value, 'Blah')
FLAGS.unparse_flags()
argv = ('./program', '--blah=BLAH')
argv = FLAGS(argv)
self.assertLen(argv, 1, 'wrong number of arguments pulled')
self.assertEqual(argv[0], './program', 'program name not preserved')
self.assertEqual(FLAGS['blah'].present, 1)
self.assertEqual(FLAGS['blah'].value, 'Blah')
FLAGS.unparse_flags()
# Test integer argument passing
argv = ('./program', '--x', '0x12345')
argv = FLAGS(argv)
self.assertEqual(FLAGS.x, 0x12345)
self.assertEqual(type(FLAGS.x), int)
argv = ('./program', '--x', '0x1234567890ABCDEF1234567890ABCDEF')
argv = FLAGS(argv)
self.assertEqual(FLAGS.x, 0x1234567890ABCDEF1234567890ABCDEF)
self.assertIsInstance(FLAGS.x, int)
argv = ('./program', '--x', '0o12345')
argv = FLAGS(argv)
self.assertEqual(FLAGS.x, 0o12345)
self.assertEqual(type(FLAGS.x), int)
# Treat 0-prefixed parameters as base-10, not base-8
argv = ('./program', '--x', '012345')
argv = FLAGS(argv)
self.assertEqual(FLAGS.x, 12345)
self.assertEqual(type(FLAGS.x), int)
argv = ('./program', '--x', '0123459')
argv = FLAGS(argv)
self.assertEqual(FLAGS.x, 123459)
self.assertEqual(type(FLAGS.x), int)
argv = ('./program', '--x', '0x123efg')
with self.assertRaises(flags.IllegalFlagValueError):
argv = FLAGS(argv)
# Test boolean argument parsing
flags.DEFINE_boolean('test0', None, 'test boolean parsing')
argv = ('./program', '--notest0')
argv = FLAGS(argv)
self.assertEqual(FLAGS.test0, 0)
flags.DEFINE_boolean('test1', None, 'test boolean parsing')
argv = ('./program', '--test1')
argv = FLAGS(argv)
self.assertEqual(FLAGS.test1, 1)
FLAGS.test0 = None
argv = ('./program', '--test0=false')
argv = FLAGS(argv)
self.assertEqual(FLAGS.test0, 0)
FLAGS.test1 = None
argv = ('./program', '--test1=true')
argv = FLAGS(argv)
self.assertEqual(FLAGS.test1, 1)
FLAGS.test0 = None
argv = ('./program', '--test0=0')
argv = FLAGS(argv)
self.assertEqual(FLAGS.test0, 0)
FLAGS.test1 = None
argv = ('./program', '--test1=1')
argv = FLAGS(argv)
self.assertEqual(FLAGS.test1, 1)
# Test booleans that already have 'no' as a prefix
FLAGS.noexec = None
argv = ('./program', '--nonoexec', '--name', 'Bob')
argv = FLAGS(argv)
self.assertEqual(FLAGS.noexec, 0)
FLAGS.noexec = None
argv = ('./program', '--name', 'Bob', '--noexec')
argv = FLAGS(argv)
self.assertEqual(FLAGS.noexec, 1)
# Test unassigned booleans
flags.DEFINE_boolean('testnone', None, 'test boolean parsing')
argv = ('./program',)
argv = FLAGS(argv)
self.assertIsNone(FLAGS.testnone)
# Test get with default
flags.DEFINE_boolean('testget1', None, 'test parsing with defaults')
flags.DEFINE_boolean('testget2', None, 'test parsing with defaults')
flags.DEFINE_boolean('testget3', None, 'test parsing with defaults')
flags.DEFINE_integer('testget4', None, 'test parsing with defaults')
argv = ('./program', '--testget1', '--notestget2')
argv = FLAGS(argv)
self.assertEqual(FLAGS.get_flag_value('testget1', 'foo'), 1)
self.assertEqual(FLAGS.get_flag_value('testget2', 'foo'), 0)
self.assertEqual(FLAGS.get_flag_value('testget3', 'foo'), 'foo')
self.assertEqual(FLAGS.get_flag_value('testget4', 'foo'), 'foo')
# test list code
lists = [['hello', 'moo', 'boo', '1'], []]
flags.DEFINE_list('testcomma_list', '', 'test comma list parsing')
flags.DEFINE_spaceseplist('testspace_list', '', 'tests space list parsing')
flags.DEFINE_spaceseplist(
'testspace_or_comma_list',
'',
'tests space list parsing with comma compatibility',
comma_compat=True)
for name, sep in (('testcomma_list', ','), ('testspace_list',
' '), ('testspace_list', '\n'),
('testspace_or_comma_list',
' '), ('testspace_or_comma_list',
'\n'), ('testspace_or_comma_list', ',')):
for lst in lists:
argv = ('./program', '--%s=%s' % (name, sep.join(lst)))
argv = FLAGS(argv)
self.assertEqual(getattr(FLAGS, name), lst)
# Test help text
flags_help = str(FLAGS)
self.assertNotEqual(
flags_help.find('repeat'), -1, 'cannot find flag in help')
self.assertNotEqual(
flags_help.find(repeat_help), -1, 'cannot find help string in help')
# Test flag specified twice
argv = ('./program', '--repeat=4', '--repeat=2', '--debug', '--nodebug')
argv = FLAGS(argv)
self.assertEqual(FLAGS.get_flag_value('repeat', None), 2)
self.assertEqual(FLAGS.get_flag_value('debug', None), 0)
# Test MultiFlag with single default value
flags.DEFINE_multi_string(
's_str',
'sing1',
'string option that can occur multiple times',
short_name='s')
self.assertEqual(FLAGS.get_flag_value('s_str', None), ['sing1'])
# Test MultiFlag with list of default values
multi_string_defs = ['def1', 'def2']
flags.DEFINE_multi_string(
'm_str',
multi_string_defs,
'string option that can occur multiple times',
short_name='m')
self.assertEqual(FLAGS.get_flag_value('m_str', None), multi_string_defs)
# Test flag specified multiple times with a MultiFlag
argv = ('./program', '--m_str=str1', '-m', 'str2')
argv = FLAGS(argv)
self.assertEqual(FLAGS.get_flag_value('m_str', None), ['str1', 'str2'])
# A flag with allow_overwrite set to False should behave normally when it
# is only specified once
argv = ('./program', '--only_once=singlevalue')
argv = FLAGS(argv)
self.assertEqual(FLAGS.get_flag_value('only_once', None), 'singlevalue')
# A flag with allow_overwrite set to False should complain when it is
# specified more than once
argv = ('./program', '--universe=ptolemaic', '--universe=copernicean',
'--universe=euclidean')
self.assertRaisesWithLiteralMatch(
flags.IllegalFlagValueError,
'flag --universe=copernicean: already defined as ptolemaic', FLAGS,
argv)
# A flag value error shouldn't modify the value:
flags.DEFINE_integer('smol', 1, 'smol flag', upper_bound=5)
with self.assertRaises(flags.IllegalFlagValueError):
FLAGS.smol = 6
self.assertEqual(FLAGS.smol, 1)
self.assertTrue(FLAGS['smol'].using_default_value)
# Test single-letter flags; should support both single and double dash
argv = ('./program', '-q')
argv = FLAGS(argv)
self.assertEqual(FLAGS.get_flag_value('q', None), 1)
argv = ('./program', '--q', '--x', '9', '--noquack')
argv = FLAGS(argv)
self.assertEqual(FLAGS.get_flag_value('q', None), 1)
self.assertEqual(FLAGS.get_flag_value('x', None), 9)
self.assertEqual(FLAGS.get_flag_value('quack', None), 0)
argv = ('./program', '--noq', '--x=10', '--quack')
argv = FLAGS(argv)
self.assertEqual(FLAGS.get_flag_value('q', None), 0)
self.assertEqual(FLAGS.get_flag_value('x', None), 10)
self.assertEqual(FLAGS.get_flag_value('quack', None), 1)
####################################
# Test flag serialization code:
old_testcomma_list = FLAGS.testcomma_list
old_testspace_list = FLAGS.testspace_list
old_testspace_or_comma_list = FLAGS.testspace_or_comma_list
argv = ('./program', FLAGS['test0'].serialize(), FLAGS['test1'].serialize(),
FLAGS['s_str'].serialize())
argv = FLAGS(argv)
self.assertEqual(FLAGS['test0'].serialize(), '--notest0')
self.assertEqual(FLAGS['test1'].serialize(), '--test1')
self.assertEqual(FLAGS['s_str'].serialize(), '--s_str=sing1')
self.assertEqual(FLAGS['testnone'].serialize(), '')
testcomma_list1 = ['aa', 'bb']
testspace_list1 = ['aa', 'bb', 'cc']
testspace_or_comma_list1 = ['aa', 'bb', 'cc', 'dd']
FLAGS.testcomma_list = list(testcomma_list1)
FLAGS.testspace_list = list(testspace_list1)
FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1)
argv = ('./program', FLAGS['testcomma_list'].serialize(),
FLAGS['testspace_list'].serialize(),
FLAGS['testspace_or_comma_list'].serialize())
argv = FLAGS(argv)
self.assertEqual(FLAGS.testcomma_list, testcomma_list1)
self.assertEqual(FLAGS.testspace_list, testspace_list1)
self.assertEqual(FLAGS.testspace_or_comma_list, testspace_or_comma_list1)
testcomma_list1 = ['aa some spaces', 'bb']
testspace_list1 = ['aa', 'bb,some,commas,', 'cc']
testspace_or_comma_list1 = ['aa', 'bb,some,commas,', 'cc']
FLAGS.testcomma_list = list(testcomma_list1)
FLAGS.testspace_list = list(testspace_list1)
FLAGS.testspace_or_comma_list = list(testspace_or_comma_list1)
argv = ('./program', FLAGS['testcomma_list'].serialize(),
FLAGS['testspace_list'].serialize(),
FLAGS['testspace_or_comma_list'].serialize())
argv = FLAGS(argv)
self.assertEqual(FLAGS.testcomma_list, testcomma_list1)
self.assertEqual(FLAGS.testspace_list, testspace_list1)
# We don't expect idempotency when commas are placed in an item value and
# comma_compat is enabled.
self.assertEqual(FLAGS.testspace_or_comma_list,
['aa', 'bb', 'some', 'commas', 'cc'])
FLAGS.testcomma_list = old_testcomma_list
FLAGS.testspace_list = old_testspace_list
FLAGS.testspace_or_comma_list = old_testspace_or_comma_list
####################################
# Test flag-update:
def args_list():
# Exclude flags that have different default values based on the
# environment.
flags_to_exclude = {'log_dir', 'test_srcdir', 'test_tmpdir'}
flagnames = set(FLAGS) - flags_to_exclude
nonbool_flags = []
truebool_flags = []
falsebool_flags = []
for name in flagnames:
flag_value = FLAGS.get_flag_value(name, None)
if not isinstance(FLAGS[name], flags.BooleanFlag):
nonbool_flags.append('--%s %s' % (name, flag_value))
elif flag_value:
truebool_flags.append('--%s' % name)
else:
falsebool_flags.append('--no%s' % name)
all_flags = nonbool_flags + truebool_flags + falsebool_flags
all_flags.sort()
return all_flags
argv = ('./program', '--repeat=3', '--name=giants', '--nodebug')
FLAGS(argv)
self.assertEqual(FLAGS.get_flag_value('repeat', None), 3)
self.assertEqual(FLAGS.get_flag_value('name', None), 'giants')
self.assertEqual(FLAGS.get_flag_value('debug', None), 0)
self.assertListEqual(
[
'--alsologtostderr',
"--args ['v=1', 'vmodule=a=0,b=2']",
'--blah None',
'--cases None',
'--decimal 666',
'--float 3.14',
'--funny None',
'--hexadecimal 1638',
'--kwery None',
'--l 9223372032559808512',
"--letters ['a', 'b', 'c']",
"--list_default_list ['a', 'b', 'c']",
'--logger_levels {}',
"--m ['str1', 'str2']",
"--m_str ['str1', 'str2']",
'--name giants',
'--no?',
'--nodebug',
'--noexec',
'--nohelp',
'--nohelpfull',
'--nohelpshort',
'--nohelpxml',
'--nologtostderr',
'--noonly_check_args',
'--nopdb_post_mortem',
'--noq',
'--norun_with_pdb',
'--norun_with_profiling',
'--notest0',
'--notestget2',
'--notestget3',
'--notestnone',
'--octal 438',
'--only_once singlevalue',
'--pdb False',
'--profile_file None',
'--quack',
'--repeat 3',
"--s ['sing1']",
"--s_str ['sing1']",
'--sense None',
'--showprefixforinfo',
'--smol 1',
'--stderrthreshold fatal',
'--test1',
'--test_random_seed 301',
'--test_randomize_ordering_seed ',
'--testcomma_list []',
'--testget1',
'--testget4 None',
'--testspace_list []',
'--testspace_or_comma_list []',
'--tmod_baz_x',
'--universe ptolemaic',
'--use_cprofile_for_profiling',
'--v -1',
'--verbosity -1',
'--x 10',
'--xml_output_file ',
],
args_list(),
)
argv = ('./program', '--debug', '--m_str=upd1', '-s', 'upd2')
FLAGS(argv)
self.assertEqual(FLAGS.get_flag_value('repeat', None), 3)
self.assertEqual(FLAGS.get_flag_value('name', None), 'giants')
self.assertEqual(FLAGS.get_flag_value('debug', None), 1)
# items appended to existing non-default value lists for --m/--m_str
# new value overwrites default value (not appended to it) for --s/--s_str
self.assertListEqual(
[
'--alsologtostderr',
"--args ['v=1', 'vmodule=a=0,b=2']",
'--blah None',
'--cases None',
'--debug',
'--decimal 666',
'--float 3.14',
'--funny None',
'--hexadecimal 1638',
'--kwery None',
'--l 9223372032559808512',
"--letters ['a', 'b', 'c']",
"--list_default_list ['a', 'b', 'c']",
'--logger_levels {}',
"--m ['str1', 'str2', 'upd1']",
"--m_str ['str1', 'str2', 'upd1']",
'--name giants',
'--no?',
'--noexec',
'--nohelp',
'--nohelpfull',
'--nohelpshort',
'--nohelpxml',
'--nologtostderr',
'--noonly_check_args',
'--nopdb_post_mortem',
'--noq',
'--norun_with_pdb',
'--norun_with_profiling',
'--notest0',
'--notestget2',
'--notestget3',
'--notestnone',
'--octal 438',
'--only_once singlevalue',
'--pdb False',
'--profile_file None',
'--quack',
'--repeat 3',
"--s ['sing1', 'upd2']",
"--s_str ['sing1', 'upd2']",
'--sense None',
'--showprefixforinfo',
'--smol 1',
'--stderrthreshold fatal',
'--test1',
'--test_random_seed 301',
'--test_randomize_ordering_seed ',
'--testcomma_list []',
'--testget1',
'--testget4 None',
'--testspace_list []',
'--testspace_or_comma_list []',
'--tmod_baz_x',
'--universe ptolemaic',
'--use_cprofile_for_profiling',
'--v -1',
'--verbosity -1',
'--x 10',
'--xml_output_file ',
],
args_list(),
)
####################################
# Test all kind of error conditions.
# Argument not in enum exception
argv = ('./program', '--kwery=WHEN')
self.assertRaises(flags.IllegalFlagValueError, FLAGS, argv)
argv = ('./program', '--kwery=why')
self.assertRaises(flags.IllegalFlagValueError, FLAGS, argv)
# Duplicate flag detection
with self.assertRaises(flags.DuplicateFlagError):
flags.DEFINE_boolean('run', 0, 'runhelp', short_name='q')
# Duplicate short flag detection
with self.assertRaisesRegex(
flags.DuplicateFlagError,
r"The flag 'z' is defined twice\. .*First from.*, Second from"):
flags.DEFINE_boolean('zoom1', 0, 'runhelp z1', short_name='z')
flags.DEFINE_boolean('zoom2', 0, 'runhelp z2', short_name='z')
raise AssertionError('duplicate short flag detection failed')
# Duplicate mixed flag detection
with self.assertRaisesRegex(
flags.DuplicateFlagError,
r"The flag 's' is defined twice\. .*First from.*, Second from"):
flags.DEFINE_boolean('short1', 0, 'runhelp s1', short_name='s')
flags.DEFINE_boolean('s', 0, 'runhelp s2')
# Check that duplicate flag detection detects definition sites
# correctly.
flagnames = ['repeated']
original_flags = flags.FlagValues()
flags.DEFINE_boolean(
flagnames[0],
False,
'Flag about to be repeated.',
flag_values=original_flags)
duplicate_flags = module_foo.duplicate_flags(flagnames)
with self.assertRaisesRegex(flags.DuplicateFlagError,
'flags_test.*module_foo'):
original_flags.append_flag_values(duplicate_flags)
# Make sure allow_override works
try:
flags.DEFINE_boolean(
'dup1', 0, 'runhelp d11', short_name='u', allow_override=0)
flag = FLAGS._flags()['dup1']
self.assertEqual(flag.default, 0)
flags.DEFINE_boolean(
'dup1', 1, 'runhelp d12', short_name='u', allow_override=1)
flag = FLAGS._flags()['dup1']
self.assertEqual(flag.default, 1)
except flags.DuplicateFlagError:
raise AssertionError('allow_override did not permit a flag duplication')
# Make sure allow_override works
try:
flags.DEFINE_boolean(
'dup2', 0, 'runhelp d21', short_name='u', allow_override=1)
flag = FLAGS._flags()['dup2']
self.assertEqual(flag.default, 0)
flags.DEFINE_boolean(
'dup2', 1, 'runhelp d22', short_name='u', allow_override=0)
flag = FLAGS._flags()['dup2']
self.assertEqual(flag.default, 1)
except flags.DuplicateFlagError:
raise AssertionError('allow_override did not permit a flag duplication')
# Make sure that re-importing a module does not cause a DuplicateFlagError
# to be raised.
try:
sys.modules.pop('absl.flags.tests.module_baz')
import absl.flags.tests.module_baz # pylint: disable=g-import-not-at-top
del absl
except flags.DuplicateFlagError:
raise AssertionError('Module reimport caused flag duplication error')
# Make sure that when we override, the help string gets updated correctly
flags.DEFINE_boolean(
'dup3', 0, 'runhelp d31', short_name='u', allow_override=1)
flags.DEFINE_boolean(
'dup3', 1, 'runhelp d32', short_name='u', allow_override=1)
self.assertEqual(str(FLAGS).find('runhelp d31'), -1)
self.assertNotEqual(str(FLAGS).find('runhelp d32'), -1)
# Make sure append_flag_values works
new_flags = flags.FlagValues()
flags.DEFINE_boolean('new1', 0, 'runhelp n1', flag_values=new_flags)
flags.DEFINE_boolean('new2', 0, 'runhelp n2', flag_values=new_flags)
self.assertEqual(len(new_flags._flags()), 2)
old_len = len(FLAGS._flags())
FLAGS.append_flag_values(new_flags)
self.assertEqual(len(FLAGS._flags()) - old_len, 2)
self.assertEqual('new1' in FLAGS._flags(), True)
self.assertEqual('new2' in FLAGS._flags(), True)
# Then test that removing those flags works
FLAGS.remove_flag_values(new_flags)
self.assertEqual(len(FLAGS._flags()), old_len)
self.assertFalse('new1' in FLAGS._flags())
self.assertFalse('new2' in FLAGS._flags())
# Make sure append_flag_values works with flags with shortnames.
new_flags = flags.FlagValues()
flags.DEFINE_boolean('new3', 0, 'runhelp n3', flag_values=new_flags)
flags.DEFINE_boolean(
'new4', 0, 'runhelp n4', flag_values=new_flags, short_name='n4')
self.assertEqual(len(new_flags._flags()), 3)
old_len = len(FLAGS._flags())
FLAGS.append_flag_values(new_flags)
self.assertEqual(len(FLAGS._flags()) - old_len, 3)
self.assertIn('new3', FLAGS._flags())
self.assertIn('new4', FLAGS._flags())
self.assertIn('n4', FLAGS._flags())
self.assertEqual(FLAGS._flags()['n4'], FLAGS._flags()['new4'])
# Then test removing them
FLAGS.remove_flag_values(new_flags)
self.assertEqual(len(FLAGS._flags()), old_len)
self.assertFalse('new3' in FLAGS._flags())
self.assertFalse('new4' in FLAGS._flags())
self.assertFalse('n4' in FLAGS._flags())
# Make sure append_flag_values fails on duplicates
flags.DEFINE_boolean('dup4', 0, 'runhelp d41')
new_flags = flags.FlagValues()
flags.DEFINE_boolean('dup4', 0, 'runhelp d42', flag_values=new_flags)
with self.assertRaises(flags.DuplicateFlagError):
FLAGS.append_flag_values(new_flags)
# Integer out of bounds
with self.assertRaises(flags.IllegalFlagValueError):
argv = ('./program', '--repeat=-4')
FLAGS(argv)
# Non-integer
with self.assertRaises(flags.IllegalFlagValueError):
argv = ('./program', '--repeat=2.5')
FLAGS(argv)
# Missing required argument
with self.assertRaises(flags.Error):
argv = ('./program', '--name')
FLAGS(argv)
# Non-boolean arguments for boolean
with self.assertRaises(flags.IllegalFlagValueError):
argv = ('./program', '--debug=goofup')
FLAGS(argv)
with self.assertRaises(flags.IllegalFlagValueError):
argv = ('./program', '--debug=42')
FLAGS(argv)
# Non-numeric argument for integer flag --repeat
with self.assertRaises(flags.IllegalFlagValueError):
argv = ('./program', '--repeat', 'Bob', 'extra')
FLAGS(argv)
# Aliases of existing flags
with self.assertRaises(flags.UnrecognizedFlagError):
flags.DEFINE_alias('alias_not_a_flag', 'not_a_flag')
# Programmtically modify alias and aliased flag
flags.DEFINE_alias('alias_octal', 'octal')
FLAGS.octal = 0o2222
self.assertEqual(0o2222, FLAGS.octal)
self.assertEqual(0o2222, FLAGS.alias_octal)
FLAGS.alias_octal = 0o4444
self.assertEqual(0o4444, FLAGS.octal)
self.assertEqual(0o4444, FLAGS.alias_octal)
# Setting alias preserves the default of the original
flags.DEFINE_alias('alias_name', 'name')
flags.DEFINE_alias('alias_debug', 'debug')
flags.DEFINE_alias('alias_decimal', 'decimal')
flags.DEFINE_alias('alias_float', 'float')
flags.DEFINE_alias('alias_letters', 'letters')
self.assertEqual(FLAGS['name'].default, FLAGS.alias_name)
self.assertEqual(FLAGS['debug'].default, FLAGS.alias_debug)
self.assertEqual(int(FLAGS['decimal'].default), FLAGS.alias_decimal)
self.assertEqual(float(FLAGS['float'].default), FLAGS.alias_float)
self.assertSameElements(FLAGS['letters'].default, FLAGS.alias_letters)
# Original flags set on command line
argv = ('./program', '--name=Martin', '--debug=True', '--decimal=777',
'--letters=x,y,z')
FLAGS(argv)
self.assertEqual('Martin', FLAGS.name)
self.assertEqual('Martin', FLAGS.alias_name)
self.assertTrue(FLAGS.debug)
self.assertTrue(FLAGS.alias_debug)
self.assertEqual(777, FLAGS.decimal)
self.assertEqual(777, FLAGS.alias_decimal)
self.assertSameElements(['x', 'y', 'z'], FLAGS.letters)
self.assertSameElements(['x', 'y', 'z'], FLAGS.alias_letters)
# Alias flags set on command line
argv = ('./program', '--alias_name=Auston', '--alias_debug=False',
'--alias_decimal=888', '--alias_letters=l,m,n')
FLAGS(argv)
self.assertEqual('Auston', FLAGS.name)
self.assertEqual('Auston', FLAGS.alias_name)
self.assertFalse(FLAGS.debug)
self.assertFalse(FLAGS.alias_debug)
self.assertEqual(888, FLAGS.decimal)
self.assertEqual(888, FLAGS.alias_decimal)
self.assertSameElements(['l', 'm', 'n'], FLAGS.letters)
self.assertSameElements(['l', 'm', 'n'], FLAGS.alias_letters)
# Make sure importing a module does not change flag value parsed
# from commandline.
flags.DEFINE_integer(
'dup5', 1, 'runhelp d51', short_name='u5', allow_override=0)
self.assertEqual(FLAGS.dup5, 1)
self.assertEqual(FLAGS.dup5, 1)
argv = ('./program', '--dup5=3')
FLAGS(argv)
self.assertEqual(FLAGS.dup5, 3)
flags.DEFINE_integer(
'dup5', 2, 'runhelp d52', short_name='u5', allow_override=1)
self.assertEqual(FLAGS.dup5, 3)
# Make sure importing a module does not change user defined flag value.
flags.DEFINE_integer(
'dup6', 1, 'runhelp d61', short_name='u6', allow_override=0)
self.assertEqual(FLAGS.dup6, 1)
FLAGS.dup6 = 3
self.assertEqual(FLAGS.dup6, 3)
flags.DEFINE_integer(
'dup6', 2, 'runhelp d62', short_name='u6', allow_override=1)
self.assertEqual(FLAGS.dup6, 3)
# Make sure importing a module does not change user defined flag value
# even if it is the 'default' value.
flags.DEFINE_integer(
'dup7', 1, 'runhelp d71', short_name='u7', allow_override=0)
self.assertEqual(FLAGS.dup7, 1)
FLAGS.dup7 = 1
self.assertEqual(FLAGS.dup7, 1)
flags.DEFINE_integer(
'dup7', 2, 'runhelp d72', short_name='u7', allow_override=1)
self.assertEqual(FLAGS.dup7, 1)
# Test module_help().
helpstr = FLAGS.module_help(module_baz)
expected_help = '\n' + module_baz.__name__ + ':' + """
--[no]tmod_baz_x: Boolean flag.
(default: 'true')"""
self.assertMultiLineEqual(expected_help, helpstr)
# Test main_module_help(). This must be part of test_flags because
# it depends on dup1/2/3/etc being introduced first.
helpstr = FLAGS.main_module_help()
expected_help = '\n' + sys.argv[0] + ':' + """
--[no]alias_debug: Alias for --debug.
(default: 'false')
--alias_decimal: Alias for --decimal.
(default: '666')
(an integer)
--alias_float: Alias for --float.
(default: '3.14')
(a number)
--alias_letters: Alias for --letters.
(default: 'a,b,c')
(a comma separated list)
--alias_name: Alias for --name.
(default: 'Bob')
--alias_octal: Alias for --octal.
(default: '438')
(an integer)
--args: a list of arguments
(default: 'v=1,"vmodule=a=0,b=2"')
(a comma separated list)
--blah: : ?
--cases: : ?
--[no]debug: debughelp
(default: 'false')
--decimal: using decimals
(default: '666')
(an integer)
-u,--[no]dup1: runhelp d12
(default: 'true')
-u,--[no]dup2: runhelp d22
(default: 'true')
-u,--[no]dup3: runhelp d32
(default: 'true')
--[no]dup4: runhelp d41
(default: 'false')
-u5,--dup5: runhelp d51
(default: '1')
(an integer)
-u6,--dup6: runhelp d61
(default: '1')
(an integer)
-u7,--dup7: runhelp d71
(default: '1')
(an integer)
--float: using floats
(default: '3.14')
(a number)
--funny: : ?
--hexadecimal: using hexadecimals
(default: '1638')
(an integer)
--kwery: : ?
--l: how long to be
(default: '9223372032559808512')
(an integer)
--letters: a list of letters
(default: 'a,b,c')
(a comma separated list)
--list_default_list: with default being a list of strings
(default: 'a,b,c')
(a comma separated list)
-m,--m_str: string option that can occur multiple times;
repeat this option to specify a list of values
(default: "['def1', 'def2']")
--name: namehelp
(default: 'Bob')
--[no]noexec: boolean flag with no as prefix
(default: 'true')
--octal: using octals
(default: '438')
(an integer)
--only_once: test only sets this once
--[no]q: quiet mode
(default: 'true')
--[no]quack: superstring of 'q'
(default: 'false')
-r,--repeat: how many times to repeat (0-5)
(default: '4')
(a non-negative integer)
-s,--s_str: string option that can occur multiple times;
repeat this option to specify a list of values
(default: "['sing1']")
--sense: : ?
--smol: smol flag
(default: '1')
(integer <= 5)
--[no]test0: test boolean parsing
--[no]test1: test boolean parsing
--testcomma_list: test comma list parsing
(default: '')
(a comma separated list)
--[no]testget1: test parsing with defaults
--[no]testget2: test parsing with defaults
--[no]testget3: test parsing with defaults
--testget4: test parsing with defaults
(an integer)
--[no]testnone: test boolean parsing
--testspace_list: tests space list parsing
(default: '')
(a whitespace separated list)
--testspace_or_comma_list: tests space list parsing with comma compatibility
(default: '')
(a whitespace or comma separated list)
--universe: test tries to set this three times
--x: how eXtreme to be
(default: '3')
(an integer)
-z,--[no]zoom1: runhelp z1
(default: 'false')"""
self.assertMultiLineEqual(expected_help, helpstr)
def test_string_flag_with_wrong_type(self):
  """DEFINE_string rejects non-string defaults (False and 0)."""
  registry = flags.FlagValues()
  # Each rejected default is tried under its own flag name.
  for flag_name, bad_default in (('name', False), ('name2', 0)):
    with self.assertRaises(flags.IllegalFlagValueError):
      flags.DEFINE_string(flag_name, bad_default, 'help', flag_values=registry)  # type: ignore
def test_integer_flag_with_wrong_type(self):
  """DEFINE_integer rejects float, list and bool defaults."""
  registry = flags.FlagValues()
  # Reusing the name 'name' relies on a rejected definition not being
  # registered in `registry`.
  for bad_default in (1e2, [], False):
    with self.assertRaises(flags.IllegalFlagValueError):
      flags.DEFINE_integer('name', bad_default, 'help', flag_values=registry)  # type: ignore
def test_float_flag_with_wrong_type(self):
  """DEFINE_float rejects a boolean default."""
  registry = flags.FlagValues()
  expected_error = flags.IllegalFlagValueError
  with self.assertRaises(expected_error):
    flags.DEFINE_float('name', False, 'help', flag_values=registry)
def test_enum_flag_with_empty_values(self):
  """DEFINE_enum raises ValueError when the permitted-values list is empty."""
  registry = flags.FlagValues()
  no_values = []
  with self.assertRaises(ValueError):
    flags.DEFINE_enum('fruit', None, no_values, 'help', flag_values=registry)
def test_enum_flag_with_str_values(self):
  """DEFINE_enum raises ValueError when given a bare string as enum values."""
  registry = flags.FlagValues()
  with self.assertRaises(ValueError):
    flags.DEFINE_enum(  # type: ignore
        'fruit', None, 'option', 'help', flag_values=registry)
def test_multi_enum_flag_with_str_values(self):
  """DEFINE_multi_enum raises ValueError for a bare string of enum values."""
  registry = flags.FlagValues()
  with self.assertRaises(ValueError):
    flags.DEFINE_multi_enum(  # type: ignore
        'fruit', None, 'option', 'help', flag_values=registry)
def test_define_enum_class_flag(self):
  """An enum-class flag defined with a None default reads back as None."""
  registry = flags.FlagValues()
  flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=registry)
  # Values can only be read after the FlagValues is marked parsed.
  registry.mark_as_parsed()
  self.assertIsNone(registry.fruit)
def test_parse_enum_class_flag(self):
  """--fruit=orange and --fruit=APPLE both parse to the matching members."""
  registry = flags.FlagValues()
  flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=registry)
  scenarios = (
      (('./program', '--fruit=orange'), Fruit.ORANGE),
      (('./program', '--fruit=APPLE'), Fruit.APPLE),
  )
  for command_line, expected_member in scenarios:
    leftover = registry(command_line)
    self.assertEqual(len(leftover), 1, 'wrong number of arguments pulled')
    self.assertEqual(leftover[0], './program', 'program name not preserved')
    fruit_flag = registry['fruit']
    self.assertEqual(fruit_flag.present, 1)
    self.assertEqual(fruit_flag.value, expected_member)
    # Reset so the next scenario starts from an unparsed flag.
    registry.unparse_flags()
def test_enum_class_flag_help_message(self):
  """main_module_help renders the enum-class flag with its allowed values.

  The help line is indented by two spaces and prefixed with the
  `<member|member>` list, as in the other expected help texts in this file.
  NOTE(review): assumes `Fruit` has exactly the members APPLE and ORANGE,
  rendered lower-case because the flag is defined case-insensitively --
  confirm against the Fruit definition.
  """
  fv = flags.FlagValues()
  flags.DEFINE_enum_class('fruit', None, Fruit, '?', flag_values=fv)
  helpstr = fv.main_module_help()
  expected_help = '\n%s:\n  --fruit: <apple|orange>: ?' % sys.argv[0]
  self.assertEqual(helpstr, expected_help)
def test_enum_class_flag_with_wrong_default_value_type(self):
  """An int default (neither a member nor a name) is rejected.

  Uses `flags.IllegalFlagValueError`, as the neighbouring tests do, rather
  than reaching into the private `_exceptions` module.
  """
  fv = flags.FlagValues()
  with self.assertRaises(flags.IllegalFlagValueError):
    flags.DEFINE_enum_class('fruit', 1, Fruit, 'help', flag_values=fv)  # type: ignore
def test_enum_class_flag_requires_enum_class(self):
  """A plain list in place of an enum class raises TypeError."""
  registry = flags.FlagValues()
  not_an_enum = ['apple', 'orange']
  with self.assertRaises(TypeError):
    flags.DEFINE_enum_class(  # type: ignore
        'fruit', None, not_an_enum, 'help', flag_values=registry)
def test_enum_class_flag_requires_non_empty_enum_class(self):
  """Defining with EmptyEnum (treated by this test as empty) raises ValueError."""
  registry = flags.FlagValues()
  with self.assertRaises(ValueError):
    flags.DEFINE_enum_class(
        'empty', None, EmptyEnum, 'help', flag_values=registry)
def test_required_flag(self):
  """A required integer flag yields a holder that enforces non-None values."""
  registry = flags.FlagValues()
  holder = flags.DEFINE_integer(
      name='int_flag', default=None, help='help', required=True,
      flag_values=registry)
  # Because the flag is required, the FlagHolder should refuse to hand back
  # None.
  self.assertTrue(holder._ensure_non_none_value)
def test_illegal_required_flag(self):
  """required=True combined with a non-None default raises ValueError."""
  registry = flags.FlagValues()
  with self.assertRaises(ValueError):
    flags.DEFINE_integer(
        name='int_flag', default=3, help='help', required=True,
        flag_values=registry)
class MultiNumericalFlagsTest(absltest.TestCase):
  """Tests for DEFINE_multi_integer and DEFINE_multi_float."""

  def _assert_sequence_almost_equal(self, expected, actual):
    """Asserts two numeric sequences match in length and element-wise.

    A bare zip() stops at the shorter input, so an empty or truncated flag
    value would otherwise pass vacuously.

    Args:
      expected: iterable of expected numbers (any iterable, e.g. dict keys).
      actual: iterable of numbers read back from a flag.
    """
    expected = list(expected)
    actual = list(actual)
    self.assertEqual(len(expected), len(actual))
    for want, got in zip(expected, actual):
      self.assertAlmostEqual(want, got)

  def test_multi_numerical_flags(self):
    """Test multi_int and multi_float flags."""
    fv = flags.FlagValues()
    int_defaults = [77, 88]
    flags.DEFINE_multi_integer(
        'm_int',
        int_defaults,
        'integer option that can occur multiple times',
        short_name='mi',
        flag_values=fv)
    self.assertListEqual(fv['m_int'].default, int_defaults)
    argv = ('./program', '--m_int=-99', '--mi=101')
    fv(argv)
    self.assertListEqual(fv.get_flag_value('m_int', None), [-99, 101])

    float_defaults = [2.2, 3]
    flags.DEFINE_multi_float(
        'm_float',
        float_defaults,
        'float option that can occur multiple times',
        short_name='mf',
        flag_values=fv)
    self._assert_sequence_almost_equal(
        float_defaults, fv.get_flag_value('m_float', None))
    argv = ('./program', '--m_float=-17', '--mf=2.78e9')
    fv(argv)
    expected_floats = [-17.0, 2.78e9]
    self._assert_sequence_almost_equal(
        expected_floats, fv.get_flag_value('m_float', None))

  def test_multi_numerical_with_tuples(self):
    """Verify multi_int/float accept tuples and other iterables as defaults."""
    flags.DEFINE_multi_integer(
        'm_int_tuple', (77, 88),
        'integer option that can occur multiple times',
        short_name='mi_tuple')
    self.assertListEqual(FLAGS.get_flag_value('m_int_tuple', None), [77, 88])

    # A dict keys view exercises a non-tuple, non-list iterable default.
    dict_with_float_keys = {2.2: 'hello', 3: 'happy'}
    float_defaults = dict_with_float_keys.keys()
    flags.DEFINE_multi_float(
        'm_float_tuple',
        float_defaults,
        'float option that can occur multiple times',
        short_name='mf_tuple')
    self._assert_sequence_almost_equal(
        float_defaults, FLAGS.get_flag_value('m_float_tuple', None))

  def test_single_value_default(self):
    """Test multi_int and multi_float flags with a single default value."""
    int_default = 77
    flags.DEFINE_multi_integer('m_int1', int_default,
                               'integer option that can occur multiple times')
    self.assertListEqual(FLAGS.get_flag_value('m_int1', None), [int_default])

    float_default = 2.2
    flags.DEFINE_multi_float('m_float1', float_default,
                             'float option that can occur multiple times')
    actual = FLAGS.get_flag_value('m_float1', None)
    self.assertEqual(1, len(actual))
    self.assertAlmostEqual(actual[0], float_default)

  def test_bad_multi_numerical_flags(self):
    """Test multi_int and multi_float flags with non-parseable values."""
    # Test non-parseable defaults.
    self.assertRaisesRegex(
        flags.IllegalFlagValueError,
        r"flag --m_int2=abc: invalid literal for int\(\) with base 10: 'abc'",
        flags.DEFINE_multi_integer, 'm_int2', ['abc'], 'desc')

    self.assertRaisesRegex(
        flags.IllegalFlagValueError, r'flag --m_float2=abc: '
        r'(invalid literal for float\(\)|could not convert string to float): '
        r"'?abc'?", flags.DEFINE_multi_float, 'm_float2', ['abc'], 'desc')

    # Test non-parseable command line values.
    fv = flags.FlagValues()
    flags.DEFINE_multi_integer(
        'm_int2',
        '77',
        'integer option that can occur multiple times',
        flag_values=fv)
    argv = ('./program', '--m_int2=def')
    self.assertRaisesRegex(
        flags.IllegalFlagValueError,
        r"flag --m_int2=def: invalid literal for int\(\) with base 10: 'def'",
        fv, argv)

    flags.DEFINE_multi_float(
        'm_float2',
        2.2,
        'float option that can occur multiple times',
        flag_values=fv)
    argv = ('./program', '--m_float2=def')
    self.assertRaisesRegex(
        flags.IllegalFlagValueError, r'flag --m_float2=def: '
        r'(invalid literal for float\(\)|could not convert string to float): '
        r"'?def'?", fv, argv)
class MultiEnumFlagsTest(absltest.TestCase):
  """Tests for DEFINE_multi_enum."""

  def test_multi_enum_flags(self):
    """Test multi_enum flags."""
    fv = flags.FlagValues()

    enum_defaults = ['FOO', 'BAZ']
    flags.DEFINE_multi_enum(
        'm_enum',
        enum_defaults, ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
        'Enum option that can occur multiple times',
        short_name='me',
        flag_values=fv)
    self.assertListEqual(fv['m_enum'].default, enum_defaults)

    argv = ('./program', '--m_enum=WHOOSH', '--me=FOO')
    fv(argv)
    self.assertListEqual(fv.get_flag_value('m_enum', None), ['WHOOSH', 'FOO'])

  def test_help_text(self):
    """Test multi_enum flag's help text."""
    fv = flags.FlagValues()

    flags.DEFINE_multi_enum(
        'm_enum',
        None, ['FOO', 'BAR'],
        'Enum option that can occur multiple times',
        flag_values=fv)
    self.assertRegex(
        fv['m_enum'].help,
        r': Enum option that can occur multiple times;\s+'
        'repeat this option to specify a list of values')

  def test_single_value_default(self):
    """Test multi_enum flags with a single default value."""
    fv = flags.FlagValues()
    enum_default = 'FOO'
    flags.DEFINE_multi_enum(
        'm_enum1',
        enum_default, ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
        'enum option that can occur multiple times',
        flag_values=fv)
    self.assertListEqual(fv['m_enum1'].default, [enum_default])

  def test_case_sensitivity(self):
    """Test case sensitivity of multi_enum flag."""
    fv = flags.FlagValues()
    # Test case insensitive enum.
    flags.DEFINE_multi_enum(
        'm_enum2', ['whoosh'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
        'Enum option that can occur multiple times',
        short_name='me2',
        case_sensitive=False,
        flag_values=fv)
    argv = ('./program', '--m_enum2=bar', '--me2=fOo')
    fv(argv)
    self.assertListEqual(fv.get_flag_value('m_enum2', None), ['BAR', 'FOO'])

    # Test case sensitive enum.
    flags.DEFINE_multi_enum(
        'm_enum3', ['BAR'], ['FOO', 'BAR', 'BAZ', 'WHOOSH'],
        'Enum option that can occur multiple times',
        short_name='me3',
        case_sensitive=True,
        flag_values=fv)
    argv = ('./program', '--m_enum3=bar', '--me3=fOo')
    # 'bar' comes first on the command line and is not an exact-case match,
    # so the error names it (the old pattern said 'invalid', which never
    # appears in the message).
    self.assertRaisesRegex(
        flags.IllegalFlagValueError,
        r'flag --m_enum3=bar: value should be one of ',
        fv, argv)

  def test_bad_multi_enum_flags(self):
    """Test multi_enum with invalid values."""
    # Test defaults that are not in the permitted list of enums.
    self.assertRaisesRegex(
        flags.IllegalFlagValueError,
        r'flag --m_enum=INVALID: value should be one of ',
        flags.DEFINE_multi_enum, 'm_enum', ['INVALID'], ['FOO', 'BAR', 'BAZ'],
        'desc')

    # The message names the flag being defined, 'm_enum2'.
    self.assertRaisesRegex(
        flags.IllegalFlagValueError,
        r'flag --m_enum2=1234: value should be one of ',
        flags.DEFINE_multi_enum, 'm_enum2', [1234], ['FOO', 'BAR', 'BAZ'],
        'desc')

    # Test command-line values that are not in the permitted list of enums.
    flags.DEFINE_multi_enum('m_enum4', 'FOO', ['FOO', 'BAR', 'BAZ'],
                            'enum option that can occur multiple times')
    argv = ('./program', '--m_enum4=INVALID')
    # The offending argument is echoed verbatim, i.e. upper-case.
    self.assertRaisesRegex(
        flags.IllegalFlagValueError,
        r'flag --m_enum4=INVALID: value should be one of ', FLAGS,
        argv)
class MultiEnumClassFlagsTest(absltest.TestCase):
  """Tests for flags defined with DEFINE_multi_enum_class."""

  def _define_fruit(self, default, enum_class, **extra_kwargs):
    """Defines a multi-enum-class --fruit flag in a fresh FlagValues.

    Args:
      default: default value handed to DEFINE_multi_enum_class.
      enum_class: the enum class backing the flag.
      **extra_kwargs: additional keyword arguments for the definition.

    Returns:
      The FlagValues instance that owns the new flag.
    """
    registry = flags.FlagValues()
    flags.DEFINE_multi_enum_class(
        'fruit',
        default,
        enum_class,
        'Enum option that can occur multiple times',
        flag_values=registry,
        **extra_kwargs)
    return registry

  def test_short_name(self):
    registry = self._define_fruit(None, Fruit, short_name='me')
    self.assertEqual(registry['fruit'].short_name, 'me')

  def test_define_results_in_registered_flag_with_none(self):
    registry = self._define_fruit(None, Fruit)
    registry.mark_as_parsed()
    self.assertIsNone(registry.fruit)

  def test_help_text(self):
    registry = self._define_fruit(None, Fruit)
    self.assertRegex(
        registry['fruit'].help,
        r': Enum option that can occur multiple times;\s+'
        'repeat this option to specify a list of values')

  def test_define_results_in_registered_flag_with_string(self):
    registry = self._define_fruit('apple', Fruit)
    registry.mark_as_parsed()
    self.assertListEqual(registry.fruit, [Fruit.APPLE])

  def test_define_results_in_registered_flag_with_enum(self):
    registry = self._define_fruit(Fruit.APPLE, Fruit)
    registry.mark_as_parsed()
    self.assertListEqual(registry.fruit, [Fruit.APPLE])

  def test_define_results_in_registered_flag_with_string_list(self):
    registry = self._define_fruit(
        ['apple', 'APPLE'], CaseSensitiveFruit, case_sensitive=True)
    registry.mark_as_parsed()
    expected = [CaseSensitiveFruit.apple, CaseSensitiveFruit.APPLE]
    self.assertListEqual(registry.fruit, expected)

  def test_define_results_in_registered_flag_with_enum_list(self):
    registry = self._define_fruit([Fruit.APPLE, Fruit.ORANGE], Fruit)
    registry.mark_as_parsed()
    self.assertListEqual(registry.fruit, [Fruit.APPLE, Fruit.ORANGE])

  def test_from_command_line_returns_multiple(self):
    registry = self._define_fruit([Fruit.APPLE], Fruit)
    registry(('./program', '--fruit=Apple', '--fruit=orange'))
    self.assertListEqual(registry.fruit, [Fruit.APPLE, Fruit.ORANGE])

  def test_bad_multi_enum_class_flags_from_definition(self):
    expected_message = 'flag --fruit=INVALID: value should be one of '
    with self.assertRaisesRegex(flags.IllegalFlagValueError, expected_message):
      flags.DEFINE_multi_enum_class('fruit', ['INVALID'], Fruit, 'desc')

  def test_bad_multi_enum_class_flags_from_commandline(self):
    registry = flags.FlagValues()
    flags.DEFINE_multi_enum_class(
        'fruit', [Fruit.APPLE], Fruit, 'desc', flag_values=registry)
    expected_message = 'flag --fruit=INVALID: value should be one of '
    with self.assertRaisesRegex(flags.IllegalFlagValueError, expected_message):
      registry(('./program', '--fruit=INVALID'))
class UnicodeFlagsTest(absltest.TestCase):
  """Testing proper unicode support for flags."""

  def test_unicode_default_and_helpstring(self):
    fv = flags.FlagValues()
    flags.DEFINE_string(
        'unicode_str',
        b'\xC3\x80\xC3\xBD'.decode('utf-8'),
        b'help:\xC3\xAA'.decode('utf-8'),
        flag_values=fv)
    argv = ('./program',)
    fv(argv)  # should not raise any exceptions

    argv = ('./program', '--unicode_str=foo')
    fv(argv)  # should not raise any exceptions

  def test_unicode_in_list(self):
    fv = flags.FlagValues()
    flags.DEFINE_list(
        'unicode_list',
        ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')],
        b'help:\xC3\xAB'.decode('utf-8'),
        flag_values=fv)
    argv = ('./program',)
    fv(argv)  # should not raise any exceptions

    argv = ('./program', '--unicode_list=hello,there')
    fv(argv)  # should not raise any exceptions

  def test_xmloutput(self):
    """Unicode help/default/current values appear in the XML help output.

    Each fragment must include the element tags: without them, the text can
    never occur in the XML document and assertIn cannot pass.
    NOTE(review): the four-space indentation mirrors the pretty-printed
    layout of write_help_in_xml_format; update it if that layout changes.
    """
    fv = flags.FlagValues()
    flags.DEFINE_string(
        'unicode1',
        b'\xC3\x80\xC3\xBD'.decode('utf-8'),
        b'help:\xC3\xAC'.decode('utf-8'),
        flag_values=fv)
    flags.DEFINE_list(
        'unicode2',
        ['abc', b'\xC3\x80'.decode('utf-8'), b'\xC3\xBD'.decode('utf-8')],
        b'help:\xC3\xAD'.decode('utf-8'),
        flag_values=fv)
    flags.DEFINE_list(
        'non_unicode', ['abc', 'def', 'ghi'],
        b'help:\xC3\xAD'.decode('utf-8'),
        flag_values=fv)

    outfile = io.StringIO()
    fv.write_help_in_xml_format(outfile)
    actual_output = outfile.getvalue()

    # The xml output is large, so we just check parts of it.
    self.assertIn(
        b'<name>unicode1</name>\n'
        b'    <meaning>help:\xc3\xac</meaning>\n'
        b'    <default>\xc3\x80\xc3\xbd</default>\n'
        b'    <current>\xc3\x80\xc3\xbd</current>'.decode('utf-8'),
        actual_output)
    self.assertIn(
        b'<name>unicode2</name>\n'
        b'    <meaning>help:\xc3\xad</meaning>\n'
        b'    <default>abc,\xc3\x80,\xc3\xbd</default>\n'
        b"    <current>['abc', '\xc3\x80', '\xc3\xbd']</current>"
        b''.decode('utf-8'), actual_output)
    self.assertIn(
        b'<name>non_unicode</name>\n'
        b'    <meaning>help:\xc3\xad</meaning>\n'
        b'    <default>abc,def,ghi</default>\n'
        b"    <current>['abc', 'def', 'ghi']</current>"
        b''.decode('utf-8'), actual_output)
class LoadFromFlagFileTest(absltest.TestCase):
"""Testing loading flags from a file and parsing them."""
def setUp(self):
  """Builds a private FlagValues with the flags the test flagfiles mention."""
  registry = flags.FlagValues()
  self.flag_values = registry
  flags.DEFINE_string(
      'unittest_message1', 'Foo!', 'You Add Here.', flag_values=registry)
  flags.DEFINE_string(
      'unittest_message2', 'Bar!', 'Hello, Sailor!', flag_values=registry)
  flags.DEFINE_boolean(
      'unittest_boolflag', 0, 'Some Boolean thing', flag_values=registry)
  flags.DEFINE_integer(
      'unittest_number', 12345, 'Some integer', lower_bound=0,
      flag_values=registry)
  flags.DEFINE_list(
      'UnitTestList', '1,2,3', 'Some list', flag_values=registry)
  # No temporary directory yet; _setup_test_files creates one on demand.
  self.tmp_path = None
  registry.mark_as_parsed()
def tearDown(self):
  """Deletes whatever temporary flagfiles the test created."""
  self._remove_test_files()
def _setup_test_files(self):
  """Creates and sets up some dummy flagfile files with bogus flags.

  Four files are written into a fresh temporary directory (self.tmp_path):
    1. plain flags plus a comment and a blank line;
    2. includes file 1 via --flagfile, then more flags;
    3. includes itself via --flagfile;
    4. includes file 3 and is then chmod-ed to 0 so it is unreadable.

  Returns:
    The list of the four file paths, in the order above.
  """
  # Only one temporary directory per test; _remove_test_files resets it.
  self.assertFalse(self.tmp_path)
  self.tmp_path = tempfile.mkdtemp()
  file_list = [
      self.tmp_path + '/UnitTestFile%d.tst' % index for index in range(1, 5)
  ]
  file_1, file_2, file_3, file_4 = file_list

  def write_flagfile(path, lines):
    # `with` closes the handle even if a write raises (the previous version
    # leaked all four handles in that case).
    with open(path, 'w') as flag_file:
      for line in lines:
        flag_file.write(line + '\n')

  # Put some dummy flags in our test files.
  write_flagfile(file_1, [
      '#A Fake Comment',
      '--unittest_message1=tempFile1!',
      '',
      '--unittest_number=54321',
      '--nounittest_boolflag',
  ])
  # This one includes test file 1.
  write_flagfile(file_2, [
      '//A Different Fake Comment',
      '--flagfile=%s' % file_1,
      '--unittest_message2=setFromTempFile2',
      '\t\t',
      '--unittest_number=6789a',
  ])
  # This file points to itself.
  write_flagfile(file_3, [
      '--flagfile=%s' % file_3,
      '--unittest_message1=setFromTempFile3',
      '#YAFC',
      '--unittest_boolflag',
  ])
  # This file is made unreadable once its contents are written.
  write_flagfile(file_4, [
      '--flagfile=%s' % file_3,
      '--unittest_message1=setFromTempFile4',
      '--unittest_message1=setFromTempFile4',
  ])
  os.chmod(file_4, 0)
  return file_list  # these are just the file names
def _remove_test_files(self):
  """Deletes the temporary flagfile directory, if one was created."""
  tmp_dir = self.tmp_path
  if not tmp_dir:
    return
  shutil.rmtree(tmp_dir, ignore_errors=True)
  self.tmp_path = None
def _read_flags_from_files(self, argv, force_gnu):
  """Expands --flagfile arguments in argv[1:], keeping argv[0] up front."""
  program_part = argv[:1]
  expanded = self.flag_values.read_flags_from_files(
      argv[1:], force_gnu=force_gnu)
  return program_part + expanded
#### Flagfile Unit Tests ####
def test_method_flagfiles_1(self):
  """Without any --flagfile, argv passes through read_flags_from_files as-is."""
  fake_argv = ['fooScript', '--unittest_boolflag']
  self.flag_values(fake_argv)
  self.assertEqual(self.flag_values.unittest_boolflag, 1)
  expanded = self._read_flags_from_files(fake_argv, False)
  self.assertListEqual(fake_argv, expanded)
def test_method_flagfiles_2(self):
  """Tests parsing one file + arguments off simulated argv."""
  tmp_files = self._setup_test_files()
  # Build argv as a list rather than splitting a formatted command line: the
  # temporary directory path may contain spaces.
  fake_argv = ['fooScript', '--q', '--flagfile=%s' % tmp_files[0]]
  # We should see the original cmd line with the file's contents spliced in.
  # Flags from the file will appear in the order they are specified in the
  # file, in the same position as the flagfile argument.
  expected_results = [
      'fooScript', '--q', '--unittest_message1=tempFile1!',
      '--unittest_number=54321', '--nounittest_boolflag'
  ]
  test_results = self._read_flags_from_files(fake_argv, False)
  self.assertListEqual(expected_results, test_results)
def test_method_flagfiles_3(self):
  """Tests parsing nested files + arguments of simulated argv."""
  tmp_files = self._setup_test_files()
  # Build argv as a list rather than splitting a formatted command line: the
  # temporary directory path may contain spaces.
  fake_argv = [
      'fooScript', '--unittest_number=77', '--flagfile=%s' % tmp_files[1]
  ]
  expected_results = [
      'fooScript', '--unittest_number=77', '--unittest_message1=tempFile1!',
      '--unittest_number=54321', '--nounittest_boolflag',
      '--unittest_message2=setFromTempFile2', '--unittest_number=6789a'
  ]
  test_results = self._read_flags_from_files(fake_argv, False)
  self.assertListEqual(expected_results, test_results)
def test_method_flagfiles_3_spaces(self):
  """Tests parsing nested files + arguments of simulated argv.

  The arguments include a pair that is actually an arg with a value, so it
  doesn't stop processing.
  """
  tmp_files = self._setup_test_files()
  # Build argv as a list rather than splitting a formatted command line: the
  # temporary directory path may contain spaces.
  fake_argv = [
      'fooScript', '--unittest_number', '77', '--flagfile=%s' % tmp_files[1]
  ]
  expected_results = [
      'fooScript', '--unittest_number', '77',
      '--unittest_message1=tempFile1!', '--unittest_number=54321',
      '--nounittest_boolflag', '--unittest_message2=setFromTempFile2',
      '--unittest_number=6789a'
  ]
  test_results = self._read_flags_from_files(fake_argv, False)
  self.assertListEqual(expected_results, test_results)
def test_method_flagfiles_3_spaces_boolean(self):
  """Tests parsing nested files + arguments of simulated argv.

  The arguments include a pair that looks like a --x y arg with value, but
  since the flag is a boolean it's actually not.
  """
  tmp_files = self._setup_test_files()
  # Build argv as a list rather than splitting a formatted command line: the
  # temporary directory path may contain spaces.
  fake_argv = [
      'fooScript', '--unittest_boolflag', '77', '--flagfile=%s' % tmp_files[1]
  ]
  expected_results = [
      'fooScript', '--unittest_boolflag', '77',
      '--flagfile=%s' % tmp_files[1]
  ]
  with _use_gnu_getopt(self.flag_values, False):
    test_results = self._read_flags_from_files(fake_argv, False)
    self.assertListEqual(expected_results, test_results)
def test_method_flagfiles_4(self):
  """Tests parsing self-referential files + arguments of simulated argv.

  This test should print a warning to stderr of some sort.
  """
  tmp_files = self._setup_test_files()
  # Build argv as a list rather than splitting a formatted command line: the
  # temporary directory path may contain spaces.
  fake_argv = [
      'fooScript', '--flagfile=%s' % tmp_files[2], '--nounittest_boolflag'
  ]
  expected_results = [
      'fooScript', '--unittest_message1=setFromTempFile3',
      '--unittest_boolflag', '--nounittest_boolflag'
  ]
  test_results = self._read_flags_from_files(fake_argv, False)
  self.assertListEqual(expected_results, test_results)
def test_method_flagfiles_5(self):
  """Test that --flagfile parsing respects the '--' end-of-options marker."""
  tmp_files = self._setup_test_files()
  # Build argv as a list rather than splitting a formatted command line: the
  # temporary directory path may contain spaces.
  fake_argv = [
      'fooScript', '--some_flag', '--', '--flagfile=%s' % tmp_files[0]
  ]
  expected_results = [
      'fooScript', '--some_flag', '--',
      '--flagfile=%s' % tmp_files[0]
  ]
  test_results = self._read_flags_from_files(fake_argv, False)
  self.assertListEqual(expected_results, test_results)
def test_method_flagfiles_6(self):
  """Test that --flagfile parsing stops at non-options (non-GNU behavior)."""
  tmp_files = self._setup_test_files()
  # Build argv as a list rather than splitting a formatted command line: the
  # temporary directory path may contain spaces.
  fake_argv = [
      'fooScript', '--some_flag', 'some_arg', '--flagfile=%s' % tmp_files[0]
  ]
  expected_results = [
      'fooScript', '--some_flag', 'some_arg',
      '--flagfile=%s' % tmp_files[0]
  ]
  with _use_gnu_getopt(self.flag_values, False):
    test_results = self._read_flags_from_files(fake_argv, False)
    self.assertListEqual(expected_results, test_results)
def test_method_flagfiles_7(self):
  """With GNU getopt, a --flagfile after a positional arg is still expanded."""
  self.flag_values.set_gnu_getopt()
  flagfile = self._setup_test_files()[0]
  argv = f'fooScript --some_flag some_arg --flagfile={flagfile}'.split(' ')
  # The --flagfile argument is replaced in place by the file's flags.
  expected = ['fooScript', '--some_flag', 'some_arg']
  expected += [
      '--unittest_message1=tempFile1!',
      '--unittest_number=54321',
      '--nounittest_boolflag',
  ]
  self.assertListEqual(expected, self._read_flags_from_files(argv, False))
def test_method_flagfiles_8(self):
  """Passing force_gnu=True expands a --flagfile after a positional arg."""
  flagfile = self._setup_test_files()[0]
  argv = f'fooScript --some_flag some_arg --flagfile={flagfile}'.split(' ')
  file_flags = [
      '--unittest_message1=tempFile1!',
      '--unittest_number=54321',
      '--nounittest_boolflag',
  ]
  actual = self._read_flags_from_files(argv, True)
  self.assertListEqual(['fooScript', '--some_flag', 'some_arg'] + file_flags,
                       actual)
def test_method_flagfiles_repeated_non_circular(self):
  """Two --flagfile args whose contents overlap are each expanded in order."""
  tmp_files = self._setup_test_files()
  second, first = tmp_files[1], tmp_files[0]
  argv = f'fooScript --flagfile={second} --flagfile={first}'.split(' ')
  file0_flags = [
      '--unittest_message1=tempFile1!',
      '--unittest_number=54321',
      '--nounittest_boolflag',
  ]
  # Expanding tmp_files[1] yields tmp_files[0]'s flags followed by its own
  # two flags; the later --flagfile=tmp_files[0] then yields file0_flags again.
  expected = (['fooScript'] + file0_flags +
              ['--unittest_message2=setFromTempFile2',
               '--unittest_number=6789a'] +
              file0_flags)
  self.assertListEqual(expected, self._read_flags_from_files(argv, False))
@unittest.skipIf(
    os.name == 'nt',
    'There is no good way to create an unreadable file on Windows.')
def test_method_flagfiles_no_permissions(self):
  """Naming an unreadable file with --flagfile raises CantOpenFlagFileError."""
  unreadable = self._setup_test_files()[3]
  argv = f'fooScript --some_flag some_arg --flagfile={unreadable}'.split(' ')
  with self.assertRaises(flags.CantOpenFlagFileError):
    self._read_flags_from_files(argv, True)
def test_method_flagfiles_not_found(self):
  """Naming a nonexistent path with --flagfile raises CantOpenFlagFileError."""
  missing = f'{self._setup_test_files()[3]}NOTEXIST'
  argv = f'fooScript --some_flag some_arg --flagfile={missing}'.split(' ')
  with self.assertRaises(flags.CantOpenFlagFileError):
    self._read_flags_from_files(argv, True)
def test_flagfiles_user_path_expansion(self):
  """User-directory paths such as ~/foo are expanded for --flagfile.

  Both the double-dash and single-dash spellings of the flag must yield
  the same result as os.path.expanduser('~/foo.file') in the current
  environment.
  """
  expected = os.path.expanduser('~/foo.file')
  for flagfile_item in ('--flagfile=~/foo.file', '-flagfile=~/foo.file'):
    actual = self.flag_values._extract_filename(flagfile_item)
    self.assertEqual(expected, actual)
def test_no_touchy_non_flags(self):
  """Arguments that are not meant to be flags pass through unmodified.

  With GNU getopt disabled, only the leading --unittest_boolflag is
  consumed; 'command' and every argument after it, flag-like or not, is
  returned verbatim.
  """
  fake_argv = [
      'fooScript',
      '--unittest_boolflag',
      'command',
      '--command_arg1',
      '--UnitTestBoom',
      '--UnitTestB',
  ]
  expected = [fake_argv[0]] + fake_argv[2:]
  with _use_gnu_getopt(self.flag_values, False):
    remaining = self.flag_values(fake_argv)
  self.assertListEqual(remaining, expected)
def test_parse_flags_after_args_if_using_gnugetopt(self):
  """With GNU getopt, flags after a positional argument are still parsed."""
  self.flag_values.set_gnu_getopt()
  remaining = self.flag_values([
      'fooScript',
      '--unittest_boolflag',
      'command',
      '--unittest_number=54321',
  ])
  self.assertListEqual(remaining, ['fooScript', 'command'])
def test_set_default(self):
  """Test changing flag defaults with FlagValues.set_default.

  Exercises string, integer (including None and 0), boolean and list flags,
  and checks that defaults the flag's parser rejects raise
  IllegalFlagValueError.
  """
  # Test that set_default changes both the default and the value,
  # and that the value is changed when one is given as an option.
  self.flag_values.set_default('unittest_message1', 'New value')
  self.assertEqual(self.flag_values.unittest_message1, 'New value')
  self.assertEqual(self.flag_values['unittest_message1'].default_as_str,
                   "'New value'")
  self.flag_values(['dummyscript', '--unittest_message1=Newer value'])
  self.assertEqual(self.flag_values.unittest_message1, 'Newer value')

  # Test that setting the default to None works correctly.
  self.flag_values.set_default('unittest_number', None)
  self.assertIsNone(self.flag_values.unittest_number)
  self.assertIsNone(self.flag_values['unittest_number'].default_as_str)
  self.flag_values(['dummyscript', '--unittest_number=56'])
  self.assertEqual(self.flag_values.unittest_number, 56)

  # Test that setting the default to zero works correctly. The current value
  # stays 56 because it was given explicitly on the command line above; only
  # the default changes.
  self.flag_values.set_default('unittest_number', 0)
  self.assertEqual(self.flag_values['unittest_number'].default, 0)
  self.assertEqual(self.flag_values.unittest_number, 56)
  self.assertEqual(self.flag_values['unittest_number'].default_as_str, "'0'")
  self.flag_values(['dummyscript', '--unittest_number=56'])
  self.assertEqual(self.flag_values.unittest_number, 56)

  # Test that setting the default to '' works correctly. As above, the value
  # explicitly parsed earlier ('Newer value') is kept.
  self.flag_values.set_default('unittest_message1', '')
  self.assertEqual(self.flag_values['unittest_message1'].default, '')
  self.assertEqual(self.flag_values.unittest_message1, 'Newer value')
  self.assertEqual(self.flag_values['unittest_message1'].default_as_str, "''")
  self.flag_values(['dummyscript', '--unittest_message1=fifty-six'])
  self.assertEqual(self.flag_values.unittest_message1, 'fifty-six')

  # Test that setting the default to false works correctly.
  self.flag_values.set_default('unittest_boolflag', False)
  self.assertEqual(self.flag_values.unittest_boolflag, False)
  self.assertEqual(self.flag_values['unittest_boolflag'].default_as_str,
                   "'false'")
  self.flag_values(['dummyscript', '--unittest_boolflag=true'])
  self.assertEqual(self.flag_values.unittest_boolflag, True)

  # Test that setting a list default works correctly; the string default is
  # run through the list parser.
  self.flag_values.set_default('UnitTestList', '4,5,6')
  self.assertListEqual(self.flag_values.UnitTestList, ['4', '5', '6'])
  self.assertEqual(self.flag_values['UnitTestList'].default_as_str, "'4,5,6'")
  self.flag_values(['dummyscript', '--UnitTestList=7,8,9'])
  self.assertListEqual(self.flag_values.UnitTestList, ['7', '8', '9'])

  # Test that setting invalid defaults raises exceptions.
  with self.assertRaises(flags.IllegalFlagValueError):
    self.flag_values.set_default('unittest_number', 'oops')
  with self.assertRaises(flags.IllegalFlagValueError):
    self.flag_values.set_default('unittest_number', -1)
class FlagsParsingTest(absltest.TestCase):
  """Testing different aspects of parsing: '-f' vs '--flag', etc."""

  def setUp(self):
    self.flag_values = flags.FlagValues()

  def _assert_only_program_and_extra(self, argv):
    """Asserts that parsing left exactly ('./program', 'extra') in argv."""
    self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
    self.assertEqual(argv[0], './program', 'program name not preserved')
    self.assertEqual(argv[1], 'extra', 'extra argument not preserved')

  def _assert_unrecognized_flag(self, argv, flagname, flagvalue=None):
    """Asserts that parsing argv raises UnrecognizedFlagError.

    Args:
      argv: The command line to parse with self.flag_values.
      flagname: Expected `flagname` attribute of the raised error.
      flagvalue: When not None, the expected `flagvalue` attribute too.
    """
    with self.assertRaises(
        flags.UnrecognizedFlagError,
        msg='Unknown flag exception not raised') as cm:
      self.flag_values(argv)
    self.assertEqual(cm.exception.flagname, flagname)
    if flagvalue is not None:
      self.assertEqual(cm.exception.flagvalue, flagvalue)

  def test_two_dash_arg_first(self):
    """Everything after a leading '--' is left as a positional argument."""
    flags.DEFINE_string(
        'twodash_name', 'Bob', 'namehelp', flag_values=self.flag_values)
    flags.DEFINE_string(
        'twodash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values)
    argv = ('./program', '--', '--twodash_name=Harry')
    argv = self.flag_values(argv)
    self.assertEqual('Bob', self.flag_values.twodash_name)
    self.assertEqual(argv[1], '--twodash_name=Harry')

  def test_two_dash_arg_middle(self):
    """Flags before '--' are parsed; flag-like args after it are not."""
    flags.DEFINE_string(
        'twodash2_name', 'Bob', 'namehelp', flag_values=self.flag_values)
    flags.DEFINE_string(
        'twodash2_blame', 'Rob', 'blamehelp', flag_values=self.flag_values)
    argv = ('./program', '--twodash2_blame=Larry', '--',
            '--twodash2_name=Harry')
    argv = self.flag_values(argv)
    self.assertEqual('Bob', self.flag_values.twodash2_name)
    self.assertEqual('Larry', self.flag_values.twodash2_blame)
    self.assertEqual(argv[1], '--twodash2_name=Harry')

  def test_one_dash_arg_first(self):
    """A lone '-' is a positional argument and (non-GNU) stops parsing."""
    flags.DEFINE_string(
        'onedash_name', 'Bob', 'namehelp', flag_values=self.flag_values)
    flags.DEFINE_string(
        'onedash_blame', 'Rob', 'blamehelp', flag_values=self.flag_values)
    argv = ('./program', '-', '--onedash_name=Harry')
    with _use_gnu_getopt(self.flag_values, False):
      argv = self.flag_values(argv)
      self.assertEqual(len(argv), 3)
      self.assertEqual(argv[1], '-')
      self.assertEqual(argv[2], '--onedash_name=Harry')

  def test_required_flag_not_specified(self):
    flags.DEFINE_string(
        'str_flag',
        default=None,
        help='help',
        required=True,
        flag_values=self.flag_values)
    argv = ('./program',)
    with _use_gnu_getopt(self.flag_values, False):
      with self.assertRaises(flags.IllegalFlagValueError):
        self.flag_values(argv)

  def test_required_arg_works_with_other_validators(self):
    flags.DEFINE_integer(
        'int_flag',
        default=None,
        help='help',
        required=True,
        lower_bound=4,
        flag_values=self.flag_values)
    argv = ('./program', '--int_flag=2')
    with _use_gnu_getopt(self.flag_values, False):
      with self.assertRaises(flags.IllegalFlagValueError):
        self.flag_values(argv)

  def test_unrecognized_flags(self):
    """Unknown flags raise unless they are named in --undefok."""
    flags.DEFINE_string('name', 'Bob', 'namehelp', flag_values=self.flag_values)

    # Unknown flag --nosuchflag
    self._assert_unrecognized_flag(
        ('./program', '--nosuchflag', '--name=Bob', 'extra'),
        'nosuchflag', '--nosuchflag')
    # Unknown flag -w (short option)
    self._assert_unrecognized_flag(
        ('./program', '-w', '--name=Bob', 'extra'), 'w', '-w')
    # Unknown flag --nosuchflagwithparam=foo
    self._assert_unrecognized_flag(
        ('./program', '--nosuchflagwithparam=foo', '--name=Bob', 'extra'),
        'nosuchflagwithparam', '--nosuchflagwithparam=foo')

    # Allow unknown flag --nosuchflag if specified with undefok
    argv = self.flag_values(('./program', '--nosuchflag', '--name=Bob',
                             '--undefok=nosuchflag', 'extra'))
    self._assert_only_program_and_extra(argv)
    # Allow unknown flag --noboolflag if undefok=boolflag is specified
    argv = self.flag_values(('./program', '--noboolflag', '--name=Bob',
                             '--undefok=boolflag', 'extra'))
    self._assert_only_program_and_extra(argv)

    # But not if the flagname in undefok is misspelled.
    self._assert_unrecognized_flag(
        ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchfla',
         'extra'), 'nosuchflag')
    self._assert_unrecognized_flag(
        ('./program', '--nosuchflag', '--name=Bob', '--undefok=nosuchflagg',
         'extra'), 'nosuchflag')

    # Allow unknown short flag -w if specified with undefok
    argv = self.flag_values(
        ('./program', '-w', '--name=Bob', '--undefok=w', 'extra'))
    self._assert_only_program_and_extra(argv)
    # Allow unknown flag --nosuchflagwithparam=foo if specified with undefok
    argv = self.flag_values(('./program', '--nosuchflagwithparam=foo',
                             '--name=Bob', '--undefok=nosuchflagwithparam',
                             'extra'))
    self._assert_only_program_and_extra(argv)
    # Even if undefok specifies multiple flags
    argv = self.flag_values(
        ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo',
         '--name=Bob', '--undefok=nosuchflag,w,nosuchflagwithparam', 'extra'))
    self._assert_only_program_and_extra(argv)

    # However, not if undefok doesn't specify the flag
    self._assert_unrecognized_flag(
        ('./program', '--nosuchflag', '--name=Bob', '--undefok=another_such',
         'extra'), 'nosuchflag')

    # Make sure --undefok doesn't mask other option errors: '--name' requires
    # a parameter but none is given, which must surface as a flags.Error
    # other than UnrecognizedFlagError.
    with self.assertRaises(
        flags.Error, msg='Missing option parameter exception not raised') as cm:
      self.flag_values(('./program', '--undefok=name', '--name'))
    self.assertNotIsInstance(cm.exception, flags.UnrecognizedFlagError,
                             'Wrong kind of error exception raised')

    # Test --undefok with its value given as a separate argument.
    argv = self.flag_values(
        ('./program', '--nosuchflag', '-w', '--nosuchflagwithparam=foo',
         '--name=Bob', '--undefok', 'nosuchflag,w,nosuchflagwithparam',
         'extra'))
    self._assert_only_program_and_extra(argv)

    # Test incorrect --undefok with no value.
    with self.assertRaises(flags.Error):
      self.flag_values(('./program', '--name=Bob', '--undefok'))
class NonGlobalFlagsTest(absltest.TestCase):
  """Tests for FlagValues objects other than the global flags.FLAGS."""

  def test_nonglobal_flags(self):
    """Test use of non-global FlagValues."""
    nonglobal_flags = flags.FlagValues()
    flags.DEFINE_string('nonglobal_flag', 'Bob', 'flaghelp', nonglobal_flags)
    argv = ('./program', '--nonglobal_flag=Mary', 'extra')
    argv = nonglobal_flags(argv)
    self.assertEqual(len(argv), 2, 'wrong number of arguments pulled')
    self.assertEqual(argv[0], './program', 'program name not preserved')
    self.assertEqual(argv[1], 'extra', 'extra argument not preserved')
    self.assertEqual(nonglobal_flags['nonglobal_flag'].value, 'Mary')

  def test_unrecognized_nonglobal_flags(self):
    """Test unrecognized non-global flags."""
    nonglobal_flags = flags.FlagValues()
    with self.assertRaises(flags.UnrecognizedFlagError,
                           msg='Unknown flag exception not raised') as cm:
      nonglobal_flags(('./program', '--nosuchflag'))
    self.assertEqual(cm.exception.flagname, 'nosuchflag')

    # The same unknown flag is accepted once it is listed in --undefok.
    argv = ('./program', '--nosuchflag', '--undefok=nosuchflag')
    argv = nonglobal_flags(argv)
    self.assertEqual(len(argv), 1, 'wrong number of arguments pulled')
    self.assertEqual(argv[0], './program', 'program name not preserved')

  def test_create_flag_errors(self):
    # Since the exception classes are exposed, nothing stops users
    # from creating their own instances. This test makes sure that
    # people modifying the flags module understand that the external
    # mechanisms for creating the exceptions should continue to work.
    _ = flags.Error()
    _ = flags.Error('message')
    _ = flags.DuplicateFlagError()
    _ = flags.DuplicateFlagError('message')
    _ = flags.IllegalFlagValueError()
    _ = flags.IllegalFlagValueError('message')

  def test_flag_values_del_attr(self):
    """Checks that del self.flag_values.flag_id works."""
    default_value = 'default value for test_flag_values_del_attr'
    # 1. Declare and delete a flag with no short name.
    flag_values = flags.FlagValues()
    flags.DEFINE_string(
        'delattr_foo', default_value, 'A simple flag.', flag_values=flag_values)

    flag_values.mark_as_parsed()
    self.assertEqual(flag_values.delattr_foo, default_value)
    flag_obj = flag_values['delattr_foo']
    # We also check that _flag_is_registered works as expected :)
    self.assertTrue(flag_values._flag_is_registered(flag_obj))
    del flag_values.delattr_foo
    self.assertNotIn('delattr_foo', flag_values._flags())
    self.assertFalse(flag_values._flag_is_registered(flag_obj))
    # If the previous del flag_values.delattr_foo did not work properly, the
    # next definition will trigger a redefinition error.
    flags.DEFINE_integer(
        'delattr_foo', 3, 'A simple flag.', flag_values=flag_values)
    del flag_values.delattr_foo

    self.assertNotIn('delattr_foo', flag_values)

    # 2. Declare and delete a flag with a short name. The flag stays
    # registered until both its long and short names have been deleted.
    flags.DEFINE_string(
        'delattr_bar',
        default_value,
        'flag with short name',
        short_name='x5',
        flag_values=flag_values)
    flag_obj = flag_values['delattr_bar']
    self.assertTrue(flag_values._flag_is_registered(flag_obj))
    del flag_values.x5
    self.assertTrue(flag_values._flag_is_registered(flag_obj))
    del flag_values.delattr_bar
    self.assertFalse(flag_values._flag_is_registered(flag_obj))

    # 3. Just like 2, but del flag_values.name last
    flags.DEFINE_string(
        'delattr_bar',
        default_value,
        'flag with short name',
        short_name='x5',
        flag_values=flag_values)
    flag_obj = flag_values['delattr_bar']
    self.assertTrue(flag_values._flag_is_registered(flag_obj))
    del flag_values.delattr_bar
    self.assertTrue(flag_values._flag_is_registered(flag_obj))
    del flag_values.x5
    self.assertFalse(flag_values._flag_is_registered(flag_obj))

    self.assertNotIn('delattr_bar', flag_values)
    self.assertNotIn('x5', flag_values)

  def test_list_flag_format(self):
    """Tests for correctly-formatted list flags."""
    fv = flags.FlagValues()
    flags.DEFINE_list('listflag', '', 'A list of arguments', flag_values=fv)

    def _check_parsing(listval):
      """Parse a particular value for our test flag, --listflag."""
      argv = fv(['./program', '--listflag=' + listval, 'plain-arg'])
      self.assertEqual(['./program', 'plain-arg'], argv)
      return fv.listflag

    # Basic success case
    self.assertEqual(_check_parsing('foo,bar'), ['foo', 'bar'])
    # Success case: newline in argument is quoted.
    self.assertEqual(_check_parsing('"foo","bar\nbar"'), ['foo', 'bar\nbar'])
    # Failure case: newline in argument is unquoted.
    self.assertRaises(flags.IllegalFlagValueError, _check_parsing,
                      '"foo",bar\nbar')
    # Failure case: unmatched ".
    self.assertRaises(flags.IllegalFlagValueError, _check_parsing,
                      '"foo,barbar')

  def test_flag_definition_via_setitem(self):
    """Assigning a plain string via [] raises IllegalFlagValueError."""
    # Construct outside the assertRaises block so that only the item
    # assignment can satisfy the expected exception.
    flag_values = flags.FlagValues()
    with self.assertRaises(flags.IllegalFlagValueError):
      flag_values['flag_name'] = 'flag_value'  # type: ignore
class SetDefaultTest(absltest.TestCase):
  """Tests for flags.set_default() applied to a flag holder."""

  def setUp(self):
    super().setUp()
    self.flag_values = flags.FlagValues()

  def _define_int(self, default=1):
    """Defines --an_int on self.flag_values and returns its holder."""
    return flags.DEFINE_integer(
        'an_int', default, 'an int', flag_values=self.flag_values)

  def test_success(self):
    holder = self._define_int()
    flags.set_default(holder, 2)
    self.flag_values.mark_as_parsed()
    self.assertEqual(holder.value, 2)

  def test_update_after_parse(self):
    holder = self._define_int()
    self.flag_values.mark_as_parsed()
    flags.set_default(holder, 2)
    self.assertEqual(holder.value, 2)

  def test_overridden_by_explicit_assignment(self):
    holder = self._define_int()
    self.flag_values.mark_as_parsed()
    self.flag_values.an_int = 3
    # An explicitly assigned value wins over a later default.
    flags.set_default(holder, 2)
    self.assertEqual(holder.value, 3)

  def test_restores_back_to_none(self):
    holder = self._define_int(default=None)
    self.flag_values.mark_as_parsed()
    flags.set_default(holder, 3)
    flags.set_default(holder, None)
    self.assertIsNone(holder.value)

  def test_failure_on_invalid_type(self):
    holder = self._define_int()
    self.flag_values.mark_as_parsed()
    with self.assertRaises(flags.IllegalFlagValueError):
      flags.set_default(holder, 'a')  # type: ignore

  def test_failure_on_type_protected_none_default(self):
    holder = self._define_int()
    self.flag_values.mark_as_parsed()
    flags.set_default(holder, None)  # type: ignore
    with self.assertRaises(flags.IllegalFlagValueError):
      _ = holder.value  # Will also fail on later access.
class OverrideValueTest(absltest.TestCase):
  """Tests for flags.override_value() applied to a flag holder."""

  def setUp(self):
    super().setUp()
    self.flag_values = flags.FlagValues()

  def _define_int(self, **kwargs):
    """Defines --an_int (default 1) on self.flag_values; returns its holder.

    Args:
      **kwargs: Extra keyword arguments forwarded to DEFINE_integer
        (e.g. upper_bound).
    """
    return flags.DEFINE_integer(
        'an_int', 1, 'an int', flag_values=self.flag_values, **kwargs
    )

  def test_success(self):
    int_holder = self._define_int()
    flags.override_value(int_holder, 2)
    self.flag_values.mark_as_parsed()
    self.assertEqual(int_holder.value, 2)

  def test_update_after_parse(self):
    int_holder = self._define_int()
    self.flag_values.mark_as_parsed()
    flags.override_value(int_holder, 2)
    self.assertEqual(int_holder.value, 2)

  def test_overrides_explicit_assignment(self):
    int_holder = self._define_int()
    self.flag_values.mark_as_parsed()
    self.flag_values.an_int = 3
    flags.override_value(int_holder, 2)
    self.assertEqual(int_holder.value, 2)

  def test_overriden_by_explicit_assignment(self):
    int_holder = self._define_int()
    self.flag_values.mark_as_parsed()
    flags.override_value(int_holder, 2)
    self.flag_values.an_int = 3
    self.assertEqual(int_holder.value, 3)

  def test_multi_flag(self):
    multi_holder = flags.DEFINE_multi_string(
        'strs', [], 'some strs', flag_values=self.flag_values
    )
    flags.override_value(multi_holder, ['a', 'b'])
    self.flag_values.mark_as_parsed()
    self.assertEqual(multi_holder.value, ['a', 'b'])

  def test_failure_on_invalid_type(self):
    int_holder = self._define_int()
    self.flag_values.mark_as_parsed()
    with self.assertRaises(flags.IllegalFlagValueError):
      flags.override_value(int_holder, 'a')  # pytype: disable=wrong-arg-types
    self.assertEqual(int_holder.value, 1)

  def test_failure_on_unparsed_value(self):
    int_holder = self._define_int()
    self.flag_values.mark_as_parsed()
    # '2' parses to 2, but override_value expects an already-parsed value.
    with self.assertRaises(flags.IllegalFlagValueError):
      flags.override_value(int_holder, '2')  # pytype: disable=wrong-arg-types
    # Like the other failure cases, a rejected override must leave the
    # previous value in place.
    self.assertEqual(int_holder.value, 1)

  def test_failure_on_parser_rejection(self):
    int_holder = self._define_int(upper_bound=5)
    self.flag_values.mark_as_parsed()
    with self.assertRaises(flags.IllegalFlagValueError):
      flags.override_value(int_holder, 6)
    self.assertEqual(int_holder.value, 1)

  def test_failure_on_validator_rejection(self):
    int_holder = self._define_int()
    flags.register_validator(
        int_holder.name, lambda x: x < 5, flag_values=self.flag_values
    )
    self.flag_values.mark_as_parsed()
    with self.assertRaises(flags.IllegalFlagValueError):
      flags.override_value(int_holder, 6)
    self.assertEqual(int_holder.value, 1)
class KeyFlagsTest(absltest.TestCase):
def setUp(self):
self.flag_values = flags.FlagValues()
def _get_names_of_defined_flags(self, module, flag_values):
"""Returns the list of names of flags defined by a module.
Auxiliary for the test_key_flags* methods.
Args:
module: A module object or a string module name.
flag_values: A FlagValues object.
Returns:
A list of strings.
"""
return [f.name for f in flag_values.get_flags_for_module(module)]
def _get_names_of_key_flags(self, module, flag_values):
"""Returns the list of names of key flags for a module.
Auxiliary for the test_key_flags* methods.
Args:
module: A module object or a string module name.
flag_values: A FlagValues object.
Returns:
A list of strings.
"""
return [f.name for f in flag_values.get_key_flags_for_module(module)]
def _assert_lists_have_same_elements(self, list_1, list_2):
# Checks that two lists have the same elements with the same
# multiplicity, in possibly different order.
list_1 = list(list_1)
list_1.sort()
list_2 = list(list_2)
list_2.sort()
self.assertListEqual(list_1, list_2)
def test_key_flags(self):
flag_values = flags.FlagValues()
# Before starting any testing, make sure no flags are already
# defined for module_foo and module_bar.
self.assertListEqual(
self._get_names_of_key_flags(module_foo, flag_values), [])
self.assertListEqual(
self._get_names_of_key_flags(module_bar, flag_values), [])
self.assertListEqual(
self._get_names_of_defined_flags(module_foo, flag_values), [])
self.assertListEqual(
self._get_names_of_defined_flags(module_bar, flag_values), [])
# Defines a few flags in module_foo and module_bar.
module_foo.define_flags(flag_values=flag_values)
try:
# Part 1. Check that all flags defined by module_foo are key for
# that module, and similarly for module_bar.
for module in [module_foo, module_bar]:
self._assert_lists_have_same_elements(
flag_values.get_flags_for_module(module),
flag_values.get_key_flags_for_module(module))
# Also check that each module defined the expected flags.
self._assert_lists_have_same_elements(
self._get_names_of_defined_flags(module, flag_values),
module.names_of_defined_flags())
# Part 2. Check that flags.declare_key_flag works fine.
# Declare that some flags from module_bar are key for
# module_foo.
module_foo.declare_key_flags(flag_values=flag_values)
# Check that module_foo has the expected list of defined flags.
self._assert_lists_have_same_elements(
self._get_names_of_defined_flags(module_foo, flag_values),
module_foo.names_of_defined_flags())
# Check that module_foo has the expected list of key flags.
self._assert_lists_have_same_elements(
self._get_names_of_key_flags(module_foo, flag_values),
module_foo.names_of_declared_key_flags())
# Part 3. Check that flags.adopt_module_key_flags works fine.
# Trigger a call to flags.adopt_module_key_flags(module_bar)
# inside module_foo. This should declare a few more key
# flags in module_foo.
module_foo.declare_extra_key_flags(flag_values=flag_values)
# Check that module_foo has the expected list of key flags.
self._assert_lists_have_same_elements(
self._get_names_of_key_flags(module_foo, flag_values),
module_foo.names_of_declared_key_flags() +
module_foo.names_of_declared_extra_key_flags())
finally:
module_foo.remove_flags(flag_values=flag_values)
def test_key_flags_with_non_default_flag_values_object(self):
# Check that key flags work even when we use a FlagValues object
# that is not the default flags.self.flag_values object. Otherwise, this
# test is similar to test_key_flags, but it uses only module_bar.
# The other test module (module_foo) uses only the default values
# for the flag_values keyword arguments. This way, test_key_flags
# and this method test both the default FlagValues, the explicitly
# specified one, and a mixed usage of the two.
# A brand-new FlagValues object, to use instead of flags.self.flag_values.
fv = flags.FlagValues()
# Before starting any testing, make sure no flags are already
# defined for module_foo and module_bar.
self.assertListEqual(self._get_names_of_key_flags(module_bar, fv), [])
self.assertListEqual(self._get_names_of_defined_flags(module_bar, fv), [])
module_bar.define_flags(flag_values=fv)
# Check that all flags defined by module_bar are key for that
# module, and that module_bar defined the expected flags.
self._assert_lists_have_same_elements(
fv.get_flags_for_module(module_bar),
fv.get_key_flags_for_module(module_bar))
self._assert_lists_have_same_elements(
self._get_names_of_defined_flags(module_bar, fv),
module_bar.names_of_defined_flags())
# Pick two flags from module_bar, declare them as key for the
# current (i.e., main) module (via flags.declare_key_flag), and
# check that we get the expected effect. The important thing is
# that we always use flags_values=fv (instead of the default
# self.flag_values).
main_module = sys.argv[0]
names_of_flags_defined_by_bar = module_bar.names_of_defined_flags()
flag_name_0 = names_of_flags_defined_by_bar[0]
flag_name_2 = names_of_flags_defined_by_bar[2]
flags.declare_key_flag(flag_name_0, flag_values=fv)
self._assert_lists_have_same_elements(
self._get_names_of_key_flags(main_module, fv), [flag_name_0])
flags.declare_key_flag(flag_name_2, flag_values=fv)
self._assert_lists_have_same_elements(
self._get_names_of_key_flags(main_module, fv),
[flag_name_0, flag_name_2])
# Try with a special (not user-defined) flag too:
flags.declare_key_flag('undefok', flag_values=fv)
self._assert_lists_have_same_elements(
self._get_names_of_key_flags(main_module, fv),
[flag_name_0, flag_name_2, 'undefok'])
flags.adopt_module_key_flags(module_bar, fv)
self._assert_lists_have_same_elements(
self._get_names_of_key_flags(main_module, fv),
names_of_flags_defined_by_bar + ['undefok'])
# Adopt key flags from the flags module itself.
flags.adopt_module_key_flags(flags, flag_values=fv)
self._assert_lists_have_same_elements(
self._get_names_of_key_flags(main_module, fv),
names_of_flags_defined_by_bar + ['flagfile', 'undefok'])
def test_key_flags_with_flagholders(self):
main_module = sys.argv[0]
self.assertListEqual(
self._get_names_of_key_flags(main_module, self.flag_values), [])
self.assertListEqual(
self._get_names_of_defined_flags(main_module, self.flag_values), [])
int_holder = flags.DEFINE_integer(
'main_module_int_fg',
1,
'Integer flag in the main module.',
flag_values=self.flag_values)
flags.declare_key_flag(int_holder, self.flag_values)
self.assertCountEqual(
self.flag_values.get_flags_for_module(main_module),
self.flag_values.get_key_flags_for_module(main_module))
bool_holder = flags.DEFINE_boolean(
'main_module_bool_fg',
False,
'Boolean flag in the main module.',
flag_values=self.flag_values)
flags.declare_key_flag(bool_holder) # omitted flag_values
self.assertCountEqual(
self.flag_values.get_flags_for_module(main_module),
self.flag_values.get_key_flags_for_module(main_module))
self.assertLen(self.flag_values.get_flags_for_module(main_module), 2)
def test_main_module_help_with_key_flags(self):
# Similar to test_main_module_help, but this time we make sure to
# declare some key flags.
# Safety check that the main module does not declare any flags
# at the beginning of this test.
expected_help = ''
self.assertMultiLineEqual(expected_help,
self.flag_values.main_module_help())
# Define one flag in this main module and some flags in modules
# a and b. Also declare one flag from module a and one flag
# from module b as key flags for the main module.
flags.DEFINE_integer(
'main_module_int_fg',
1,
'Integer flag in the main module.',
flag_values=self.flag_values)
try:
main_module_int_fg_help = (
' --main_module_int_fg: Integer flag in the main module.\n'
" (default: '1')\n"
' (an integer)')
expected_help += '\n%s:\n%s' % (sys.argv[0], main_module_int_fg_help)
self.assertMultiLineEqual(expected_help,
self.flag_values.main_module_help())
# The following call should be a no-op: any flag declared by a
# module is automatically key for that module.
flags.declare_key_flag('main_module_int_fg', flag_values=self.flag_values)
self.assertMultiLineEqual(expected_help,
self.flag_values.main_module_help())
# The definition of a few flags in an imported module should not
# change the main module help.
module_foo.define_flags(flag_values=self.flag_values)
self.assertMultiLineEqual(expected_help,
self.flag_values.main_module_help())
flags.declare_key_flag('tmod_foo_bool', flag_values=self.flag_values)
tmod_foo_bool_help = (
' --[no]tmod_foo_bool: Boolean flag from module foo.\n'
" (default: 'true')")
expected_help += '\n' + tmod_foo_bool_help
self.assertMultiLineEqual(expected_help,
self.flag_values.main_module_help())
flags.declare_key_flag('tmod_bar_z', flag_values=self.flag_values)
tmod_bar_z_help = (
' --[no]tmod_bar_z: Another boolean flag from module bar.\n'
" (default: 'false')")
# Unfortunately, there is some flag sorting inside
# main_module_help, so we can't keep incrementally extending
# the expected_help string ...
expected_help = ('\n%s:\n%s\n%s\n%s' %
(sys.argv[0], main_module_int_fg_help, tmod_bar_z_help,
tmod_foo_bool_help))
self.assertMultiLineEqual(self.flag_values.main_module_help(),
expected_help)
finally:
# At the end, delete all the flag information we created.
self.flag_values.__delattr__('main_module_int_fg')
module_foo.remove_flags(flag_values=self.flag_values)
def test_adoptmodule_key_flags(self):
  # adopt_module_key_flags wants a module object; handing it a module *name*
  # (a plain string) must be rejected with flags.Error.
  with self.assertRaises(flags.Error):
    flags.adopt_module_key_flags('pyglib.app')
def test_disclaimkey_flags(self):
  # Swap in a private copy of the global disclaim set; the finally clause
  # puts the original object back no matter how the test ends.
  saved_ids = _helpers.disclaim_module_ids
  _helpers.disclaim_module_ids = set(saved_ids)
  try:
    module_bar.disclaim_key_flags()
    # module_bar's code now runs on behalf of module_foo, so the flags it
    # defines are expected to be attributed to module_foo.
    module_foo.define_bar_flags(flag_values=self.flag_values)
    owner = self.flag_values.find_module_defining_flag('tmod_bar_x')
    self.assertEqual(module_foo.__name__, owner)
  finally:
    _helpers.disclaim_module_ids = saved_ids
class FindModuleTest(absltest.TestCase):
  """Testing methods that find a module that defines a given flag."""

  def _flag_values_with_sys_owned_flag(self):
    """Returns a fresh FlagValues whose 'flag_name' is attributed to `sys`."""
    registry = flags.FlagValues()
    # module_name must name a module that actually exists, hence `sys`.
    flags.DEFINE_boolean(
        'flag_name',
        True,
        'Flag with a different module name.',
        flag_values=registry,
        module_name=sys.__name__)
    return registry

  def test_find_module_defining_flag(self):
    fallback = FLAGS.find_module_defining_flag(
        '__NON_EXISTENT_FLAG__', 'default')
    self.assertEqual('default', fallback)
    owner = FLAGS.find_module_defining_flag('tmod_baz_x')
    self.assertEqual(module_baz.__name__, owner)

  def test_find_module_id_defining_flag(self):
    fallback = FLAGS.find_module_id_defining_flag(
        '__NON_EXISTENT_FLAG__', 'default')
    self.assertEqual('default', fallback)
    owner_id = FLAGS.find_module_id_defining_flag('tmod_baz_x')
    self.assertEqual(id(module_baz), owner_id)

  def test_find_module_defining_flag_passing_module_name(self):
    registry = self._flag_values_with_sys_owned_flag()
    self.assertEqual(sys.__name__,
                     registry.find_module_defining_flag('flag_name'))

  def test_find_module_id_defining_flag_passing_module_name(self):
    registry = self._flag_values_with_sys_owned_flag()
    self.assertEqual(
        id(sys), registry.find_module_id_defining_flag('flag_name'))
class FlagsErrorMessagesTest(absltest.TestCase):
  """Testing special cases for integer and float flags error messages."""

  def setUp(self):
    self.flag_values = flags.FlagValues()

  def _define_bounded_flags(self, define):
    """Registers the shared table of bounded flags on self.flag_values.

    Args:
      define: The flag-definition function to use, i.e.
        flags.DEFINE_integer or flags.DEFINE_float.
    """
    define('positive', 4, 'positive flag', lower_bound=1,
           flag_values=self.flag_values)
    define('non_negative', 4, 'non-negative flag', lower_bound=0,
           flag_values=self.flag_values)
    define('negative', -4, 'negative flag', upper_bound=-1,
           flag_values=self.flag_values)
    define('non_positive', -4, 'non-positive flag', upper_bound=0,
           flag_values=self.flag_values)
    define('greater', 19, 'greater-than flag', lower_bound=4,
           flag_values=self.flag_values)
    define('smaller', -19, 'smaller-than flag', upper_bound=4,
           flag_values=self.flag_values)
    define('usual', 4, 'usual flag', lower_bound=0, upper_bound=10000,
           flag_values=self.flag_values)
    define('another_usual', 0, 'usual flag', lower_bound=-1, upper_bound=1,
           flag_values=self.flag_values)

  def test_integer_error_text(self):
    # Make sure we get proper error text
    self._define_bounded_flags(flags.DEFINE_integer)
    self._check_error_message('positive', -4, 'a positive integer')
    self._check_error_message('non_negative', -4, 'a non-negative integer')
    self._check_error_message('negative', 0, 'a negative integer')
    self._check_error_message('non_positive', 4, 'a non-positive integer')
    self._check_error_message('usual', -4, 'an integer in the range [0, 10000]')
    self._check_error_message('another_usual', 4,
                              'an integer in the range [-1, 1]')
    self._check_error_message('greater', -5, 'integer >= 4')
    self._check_error_message('smaller', 5, 'integer <= 4')

  def test_float_error_text(self):
    self._define_bounded_flags(flags.DEFINE_float)
    self._check_error_message('positive', 0.5, 'number >= 1')
    self._check_error_message('non_negative', -4.0, 'a non-negative number')
    self._check_error_message('negative', 0.5, 'number <= -1')
    self._check_error_message('non_positive', 4.0, 'a non-positive number')
    self._check_error_message('usual', -4.0, 'a number in the range [0, 10000]')
    self._check_error_message('another_usual', 4.0,
                              'a number in the range [-1, 1]')
    # 'greater' is defined above; check it too, mirroring the integer test.
    self._check_error_message('greater', -5.0, 'number >= 4')
    self._check_error_message('smaller', 5.0, 'number <= 4')

  def _check_error_message(self, flag_name, flag_value,
                           expected_message_suffix):
    """Sets a flag to an out-of-bounds value and checks the error text.

    Args:
      flag_name: Name of a flag registered on self.flag_values.
      flag_value: A value that violates the flag's bounds.
      expected_message_suffix: The expected text after "is not " in the
        IllegalFlagValueError message.
    """
    with self.assertRaises(flags.IllegalFlagValueError,
                           msg='Bounds exception not raised!') as raised:
      setattr(self.flag_values, flag_name, flag_value)
    expected = ('flag --%(name)s=%(value)s: %(value)s is not %(suffix)s' % {
        'name': flag_name,
        'value': flag_value,
        'suffix': expected_message_suffix
    })
    self.assertEqual(str(raised.exception), expected)
if __name__ == '__main__':
  # Run this file's test cases through absltest when executed directly.
  absltest.main()
abseil-py-2.1.0/absl/flags/tests/flags_unicode_literals_test.py 0000664 0000000 0000000 00000002467 14551576331 0024713 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the use of flags when from __future__ import unicode_literals is on."""
from absl import flags
from absl.testing import absltest
# Module-level string flag read back by the test case below; its name,
# default and help are all ordinary (unicode) str literals.
flags.DEFINE_string(
    name='seen_in_crittenden',
    default='alleged mountain lion',
    help='This tests if unicode input to these functions works.',
)
class FlagsUnicodeLiteralsTest(absltest.TestCase):
  """Checks that a flag defined with str literals yields a str value."""

  def testUnicodeFlagNameAndValueAreGood(self):
    alleged_mountain_lion = flags.FLAGS.seen_in_crittenden
    # In Python 3, u'' literals are plain `str`; assertIsInstance is the
    # dedicated assertion for this check.
    self.assertIsInstance(
        alleged_mountain_lion, str,
        msg='expected flag value to be a {} not {}'.format(
            str, type(alleged_mountain_lion)))
    self.assertEqual(alleged_mountain_lion, u'alleged mountain lion')
if __name__ == '__main__':
  # Run this file's test cases through absltest when executed directly.
  absltest.main()
abseil-py-2.1.0/absl/flags/tests/module_bar.py 0000664 0000000 0000000 00000007332 14551576331 0021260 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auxiliary module for testing flags.py.
The purpose of this module is to define a few flags. We want to make
sure the unit tests for flags.py involve more than one module.
"""
from absl import flags
from absl.flags import _helpers
FLAGS = flags.FLAGS


def define_flags(flag_values=FLAGS):
  """Registers this module's six 'tmod_bar_' test flags.

  Args:
    flag_values: The FlagValues object the flags are registered with.
  """
  # The 'tmod_bar_' prefix (short for 'test_module_bar') keeps these names
  # from clashing with flags defined elsewhere.
  flag_specs = (
      (flags.DEFINE_boolean, 'tmod_bar_x', True, 'Boolean flag.'),
      (flags.DEFINE_string, 'tmod_bar_y', 'default', 'String flag.'),
      (flags.DEFINE_boolean, 'tmod_bar_z', False,
       'Another boolean flag from module bar.'),
      (flags.DEFINE_integer, 'tmod_bar_t', 4, 'Sample int flag.'),
      (flags.DEFINE_integer, 'tmod_bar_u', 5, 'Sample int flag.'),
      (flags.DEFINE_integer, 'tmod_bar_v', 6, 'Sample int flag.'),
  )
  for define, name, default, help_text in flag_specs:
    define(name, default, help_text, flag_values=flag_values)
def remove_one_flag(flag_name, flag_values=FLAGS):
  """Removes one flag definition from `flag_values`, if it is there.

  A flag that is not registered is skipped without raising. That makes this
  safe for cleanup *after* a test: if the test failed before all of its
  flags were defined, the cleanup code should not crash.

  Args:
    flag_name: A string, the name of the flag to delete.
    flag_values: The FlagValues object we remove the flag from.
  """
  if flag_name not in flag_values:
    return
  delattr(flag_values, flag_name)
def names_of_defined_flags():
  """Returns a new list with the names of the flags this module defines.

  The order matches the order in which define_flags() registers them.
  """
  return ['tmod_bar_' + suffix for suffix in ('x', 'y', 'z', 't', 'u', 'v')]
def remove_flags(flag_values=FLAGS):
  """Deletes whichever define_flags() flags are still registered.

  Args:
    flag_values: The FlagValues object we remove the flags from.
  """
  own_names = names_of_defined_flags()
  for name in own_names:
    remove_one_flag(name, flag_values=flag_values)
def get_module_name():
  """Uses get_calling_module() to return the name of this module.

  For checking that get_calling_module works as expected. The call is made
  directly from this function so that the "calling module" seen by
  _helpers.get_calling_module() is this module; keep it that way.

  Returns:
    A string, the name of this module.
  """
  return _helpers.get_calling_module()
def execute_code(code, global_dict):
  """Executes some code in a given global environment.

  For testing of get_calling_module: the executed code sees `global_dict`
  as its module globals instead of this module's globals.

  Args:
    code: A string, the code to be executed.
    global_dict: A dictionary, the global environment that code should
      be executed in. It also serves as the locals mapping, and exec() may
      insert a ``__builtins__`` entry into it.
  """
  # Indeed, using exec generates a lint warning. But some user code
  # actually uses exec, and we have to test for it ...
  exec(code, global_dict)  # pylint: disable=exec-used
def disclaim_key_flags():
  """Disclaims flags declared in this module.

  flags.disclaim_key_flags() is called directly from here so that it is
  this module that gets disclaimed. Afterwards, flags this module defines on
  behalf of another module are attributed to that caller (exercised by
  test_disclaimkey_flags).
  """
  flags.disclaim_key_flags()
abseil-py-2.1.0/absl/flags/tests/module_baz.py 0000664 0000000 0000000 00000001505 14551576331 0021264 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auxiliary module for testing flags.py.
The purpose of this module is to test the behavior of flags that are defined
before main() executes.
"""
from absl import flags
FLAGS = flags.FLAGS

# Registered at import time, i.e. before main() runs.
flags.DEFINE_boolean(name='tmod_baz_x', default=True, help='Boolean flag.')
abseil-py-2.1.0/absl/flags/tests/module_foo.py 0000664 0000000 0000000 00000010041 14551576331 0021266 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auxiliary module for testing flags.py.
The purpose of this module is to define a few flags, and declare some
other flags as being important. We want to make sure the unit tests
for flags.py involve more than one module.
"""
from absl import flags
from absl.flags import _helpers
from absl.flags.tests import module_bar
FLAGS = flags.FLAGS

# Flags that declare_key_flags() marks as key flags of this module.
DECLARED_KEY_FLAGS = [
    'tmod_bar_x',
    'tmod_bar_z',
    'tmod_bar_t',
    # Special (not user-defined) flag:
    'flagfile',
]
def define_flags(flag_values=FLAGS):
  """Defines module_bar's flags plus three flags owned by this module.

  Args:
    flag_values: The FlagValues object the flags are registered with.
  """
  module_bar.define_flags(flag_values=flag_values)
  # The 'tmod_foo_' prefix (short for 'test_module_foo') keeps these names
  # from clashing with existing flags.
  own_specs = (
      (flags.DEFINE_boolean, 'tmod_foo_bool', True,
       'Boolean flag from module foo.'),
      (flags.DEFINE_string, 'tmod_foo_str', 'default', 'String flag.'),
      (flags.DEFINE_integer, 'tmod_foo_int', 3, 'Sample int flag.'),
  )
  for define, name, default, help_text in own_specs:
    define(name, default, help_text, flag_values=flag_values)
def declare_key_flags(flag_values=FLAGS):
  """Declares a few key flags.

  Marks every name in DECLARED_KEY_FLAGS as a key flag of this module. The
  flags.declare_key_flag() calls are made directly from this module's code;
  presumably declare_key_flag identifies the declaring module from its
  caller (confirm in absl.flags), so don't route them through helpers
  elsewhere.

  Args:
    flag_values: The FlagValues object that holds the flags.
  """
  for flag_name in DECLARED_KEY_FLAGS:
    flags.declare_key_flag(flag_name, flag_values=flag_values)
def declare_extra_key_flags(flag_values=FLAGS):
  """Declares some extra key flags.

  Adopts module_bar's key flags as key flags of this module. The names that
  become key only through this call are listed by
  names_of_declared_extra_key_flags().

  Args:
    flag_values: The FlagValues object that holds the flags.
  """
  flags.adopt_module_key_flags(module_bar, flag_values=flag_values)
def names_of_defined_flags():
  """Returns a new list of the flag names that define_flags() adds here."""
  return ['tmod_foo_' + kind for kind in ('bool', 'str', 'int')]
def names_of_declared_key_flags():
  """Returns a new list of this module's key flags.

  The list holds the flags it defines, followed by DECLARED_KEY_FLAGS.
  """
  return [*names_of_defined_flags(), *DECLARED_KEY_FLAGS]
def names_of_declared_extra_key_flags():
  """Returns the list of names of additional key flags for this module.

  These are the flags that became key for this module only as a result
  of a call to declare_extra_key_flags() above. I.e., the flags declared
  by module_bar, that were not already declared as key for this
  module.

  Returns:
    A new list with the names of the additional key flags for this module,
    in the order module_bar.names_of_defined_flags() lists them.
  """
  already_key = set(names_of_declared_key_flags())
  # One filtering pass with set membership; every occurrence of an
  # already-key name is dropped.
  return [name for name in module_bar.names_of_defined_flags()
          if name not in already_key]
def remove_flags(flag_values=FLAGS):
  """Deletes the flag definitions done by define_flags() above.

  This module's own flags are removed first, then module_bar's. Flags that
  are already gone are skipped.

  Args:
    flag_values: The FlagValues object we remove the flags from.
  """
  own_names = names_of_defined_flags()
  for name in own_names:
    module_bar.remove_one_flag(name, flag_values=flag_values)
  module_bar.remove_flags(flag_values=flag_values)
def get_module_name():
  """Uses get_calling_module() to return the name of this module.

  For checking that _get_calling_module works as expected. The call is made
  directly from this function so that the "calling module" seen by
  _helpers.get_calling_module() is this module; keep it that way.

  Returns:
    A string, the name of this module.
  """
  return _helpers.get_calling_module()
def duplicate_flags(flagnames=None):
  """Returns a new FlagValues object with the requested flagnames.

  Used to test DuplicateFlagError detection.

  Args:
    flagnames: An iterable of flag-name strings to create. The default,
      None, creates no flags (iterating None used to raise TypeError).

  Returns:
    A FlagValues object with one boolean flag (default False) for each name
    in flagnames.
  """
  flag_values = flags.FlagValues()
  for name in flagnames or ():
    flags.DEFINE_boolean(name, False, 'Flag named %s' % (name,),
                         flag_values=flag_values)
  return flag_values
def define_bar_flags(flag_values=FLAGS):
  """Defines module_bar's flags, invoked from this module.

  Args:
    flag_values: The FlagValues object the flags are registered with.
  """
  module_bar.define_flags(flag_values=flag_values)
abseil-py-2.1.0/absl/logging/ 0000775 0000000 0000000 00000000000 14551576331 0015760 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/absl/logging/BUILD 0000664 0000000 0000000 00000004541 14551576331 0016546 0 ustar 00root root 0000000 0000000 load("@rules_python//python:py_library.bzl", "py_library")
load("@rules_python//python:py_test.bzl", "py_test")
load("@rules_python//python:py_binary.bzl", "py_binary")
package(default_visibility = ["//visibility:private"])

licenses(["notice"])

# The public absl.logging package (absl/logging/__init__.py).
py_library(
    name = "logging",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":converter",
        "//absl/flags",
    ],
)

# Level-conversion helpers (converter.py); a dependency of :logging.
py_library(
    name = "converter",
    srcs = ["converter.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
)

# Unit tests for converter.py.
py_test(
    name = "tests/converter_test",
    size = "small",
    srcs = ["tests/converter_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":converter",
        ":logging",
        "//absl/testing:absltest",
    ],
)

# Unit tests for the logging package.
py_test(
    name = "tests/logging_test",
    size = "small",
    srcs = ["tests/logging_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":logging",
        "//absl/flags",
        "//absl/testing:absltest",
        "//absl/testing:flagsaver",
        "//absl/testing:parameterized",
    ],
)

py_test(
    name = "tests/log_before_import_test",
    srcs = ["tests/log_before_import_test.py"],
    main = "tests/log_before_import_test.py",
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":logging",
        "//absl/testing:absltest",
    ],
)

py_test(
    name = "tests/verbosity_flag_test",
    srcs = ["tests/verbosity_flag_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":logging",
        "//absl/flags",
        "//absl/testing:absltest",
    ],
)

# Test-only helper program; shipped as `data` of the functional test below.
py_binary(
    name = "tests/logging_functional_test_helper",
    testonly = 1,
    srcs = ["tests/logging_functional_test_helper.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":logging",
        "//absl:app",
        "//absl/flags",
    ],
)

# Large functional test using the helper binary, split into 50 shards.
py_test(
    name = "tests/logging_functional_test",
    size = "large",
    srcs = ["tests/logging_functional_test.py"],
    data = [":tests/logging_functional_test_helper"],
    python_version = "PY3",
    shard_count = 50,
    srcs_version = "PY3",
    deps = [
        ":logging",
        "//absl/testing:_bazelize_command",
        "//absl/testing:absltest",
        "//absl/testing:parameterized",
    ],
)
abseil-py-2.1.0/absl/logging/__init__.py 0000664 0000000 0000000 00000122142 14551576331 0020073 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abseil Python logging module implemented on top of standard logging.
Simple usage::
from absl import logging
logging.info('Interesting Stuff')
logging.info('Interesting Stuff with Arguments: %d', 42)
logging.set_verbosity(logging.INFO)
logging.log(logging.DEBUG, 'This will *not* be printed')
logging.set_verbosity(logging.DEBUG)
logging.log(logging.DEBUG, 'This will be printed')
logging.warning('Worrying Stuff')
logging.error('Alarming Stuff')
logging.fatal('AAAAHHHHH!!!!') # Process exits.
Usage note: Do not pre-format the strings in your program code.
Instead, let the logging module perform argument interpolation.
This saves cycles because strings that don't need to be printed
are never formatted. Note that this module does not attempt to
interpolate arguments when no arguments are given. In other words::
logging.info('Interesting Stuff: %s')
does not raise an exception because logging.info() has only one
argument, the message string.
"Lazy" evaluation for debugging
-------------------------------
If you do something like this::
logging.debug('Thing: %s', thing.ExpensiveOp())
then the ExpensiveOp will be evaluated even if nothing
is printed to the log. To avoid this, use the level_debug() function::
if logging.level_debug():
logging.debug('Thing: %s', thing.ExpensiveOp())
Per file level logging is supported by logging.vlog() and
logging.vlog_is_on(). For example::
if logging.vlog_is_on(2):
logging.vlog(2, very_expensive_debug_message())
Notes on Unicode
----------------
The log output is encoded as UTF-8. Don't pass data in other encodings in
bytes() instances -- instead pass unicode string instances when you need to
(for both the format string and arguments).
Note on critical and fatal:
Standard logging module defines fatal as an alias to critical, but it's not
documented, and it does NOT actually terminate the program.
This module only defines fatal but not critical, and it DOES terminate the
program.
The differences in behavior are historical and unfortunate.
"""
import collections
from collections import abc
import getpass
import io
import itertools
import logging
import os
import socket
import struct
import sys
import tempfile
import threading
import time
import timeit
import traceback
import types
import warnings
from absl import flags
from absl.logging import converter
# pylint: disable=g-import-not-at-top
try:
from typing import NoReturn
except ImportError:
pass
# pylint: enable=g-import-not-at-top
FLAGS = flags.FLAGS

# Logging levels, re-exported from absl.logging.converter.
FATAL = converter.ABSL_FATAL
ERROR = converter.ABSL_ERROR
WARNING = converter.ABSL_WARNING
WARN = WARNING  # Deprecated alias of WARNING.
INFO = converter.ABSL_INFO
DEBUG = converter.ABSL_DEBUG
# Regex to match/parse log line prefixes, for example
#   I0102 03:04:05.678901 1234 foo.py:42
# Each field is captured in a named group: severity, month, day, hour,
# minute, second, microsecond, thread_id, filename and line. (Without the
# group names, `(?P[...])` is not valid regex syntax.)
ABSL_LOGGING_PREFIX_REGEX = (
    r'^(?P<severity>[IWEF])'
    r'(?P<month>\d\d)(?P<day>\d\d) '
    r'(?P<hour>\d\d):(?P<minute>\d\d):(?P<second>\d\d)'
    r'\.(?P<microsecond>\d\d\d\d\d\d) +'
    r'(?P<thread_id>-?\d+) '
    r'(?P<filename>[a-zA-Z<][\w._<>-]+):(?P<line>\d+)')
# Mask to convert integer thread ids to unsigned quantities for logging purposes
_THREAD_ID_MASK = 2 ** (struct.calcsize('L') * 8) - 1
# Extra property set on the LogRecord created by ABSLLogger when its level is
# CRITICAL/FATAL.
_ABSL_LOG_FATAL = '_absl_log_fatal'
# Extra prefix added to the log message when a non-absl logger logs a
# CRITICAL/FATAL message.
_CRITICAL_PREFIX = 'CRITICAL - '
# Used by findCaller to skip callers from */logging/__init__.py.
_LOGGING_FILE_PREFIX = os.path.join('logging', '__init__.')
# The ABSL logger instance, initialized in _initialize().
_absl_logger = None
# The ABSL handler instance, initialized in _initialize().
_absl_handler = None
_CPP_NAME_TO_LEVELS = {
'debug': '0', # Abseil C++ has no DEBUG level, mapping it to INFO here.
'info': '0',
'warning': '1',
'warn': '1',
'error': '2',
'fatal': '3'
}
_CPP_LEVEL_TO_NAMES = {
'0': 'info',
'1': 'warning',
'2': 'error',
'3': 'fatal',
}
class _VerbosityFlag(flags.Flag):
  """Flag class for -v/--verbosity.

  The value is parsed as an integer; every assignment immediately pushes the
  new verbosity into the Python logger levels.
  """

  def __init__(self, *args, **kwargs):
    parser = flags.IntegerParser()
    serializer = flags.ArgumentSerializer()
    super().__init__(parser, serializer, *args, **kwargs)

  @property
  def value(self):
    return self._value

  @value.setter
  def value(self, v):
    self._value = v
    self._update_logging_levels()

  def _update_logging_levels(self):
    """Updates absl logging levels to the current verbosity.

    Visibility: module-private
    """
    if not _absl_logger:
      # Nothing to configure until the absl logger exists.
      return

    verbosity = self._value
    if verbosity <= converter.ABSL_DEBUG:
      standard_verbosity = converter.absl_to_standard(verbosity)
    else:
      # Above 1 the value is a vlog level: each step goes one standard
      # level below DEBUG.
      standard_verbosity = logging.DEBUG - (verbosity - 1)

    if _absl_handler not in logging.root.handlers:
      _absl_logger.setLevel(standard_verbosity)
      return
    # The absl handler lives on the root logger, so the root level decides.
    # Reset the absl logger to NOTSET so it inherits from root: it might
    # still carry a level from a logging.set_verbosity() call at import time.
    _absl_logger.setLevel(logging.NOTSET)
    logging.root.setLevel(standard_verbosity)
class _LoggerLevelsFlag(flags.Flag):
  """Flag class for --logger_levels."""

  def __init__(self, *args, **kwargs):
    super().__init__(
        _LoggerLevelsParser(), _LoggerLevelsSerializer(), *args, **kwargs)

  @property
  def value(self):
    # Hand out a copy. Editing the returned dict is unsupported and has no
    # effect on logger levels. MappingProxyType would be read-only, but it
    # isn't deepcopy friendly, so a plain copy is used instead.
    return self._value.copy()

  @value.setter
  def value(self, v):
    self._value = v if v is not None else {}
    self._update_logger_levels()

  def _update_logger_levels(self):
    """Applies each configured `name -> level` to its Python logger.

    Visibility: module-private. absl.app.run() calls this during
    initialization.
    """
    for logger_name, level in self._value.items():
      logging.getLogger(logger_name).setLevel(level)
class _LoggerLevelsParser(flags.ArgumentParser):
  """Parser for --logger_levels flag."""

  def parse(self, value):
    """Parses `name:level[,name:level...]` into an ordered mapping.

    Args:
      value: Either a mapping, which is returned unchanged, or a
        comma-separated string of `name:level` pairs. Whitespace around
        names and levels is stripped and empty entries are ignored.

    Returns:
      The mapping itself when `value` is a mapping. Otherwise an OrderedDict
      from logger name to level string, in input order.

    Raises:
      ValueError: If a non-empty entry has no ':' separator.
    """
    if isinstance(value, abc.Mapping):
      return value

    # Preserve the order so that serialization is deterministic.
    levels = collections.OrderedDict()
    for raw_entry in value.split(','):
      entry = raw_entry.strip()
      if not entry:
        continue
      name, separator, level = entry.partition(':')
      if not separator:
        # Previously an opaque "not enough values to unpack" error.
        raise ValueError(
            'Invalid --logger_levels entry {!r}: expected `name:level`'
            .format(entry))
      levels[name.strip()] = level.strip()
    return levels
class _LoggerLevelsSerializer(object):
"""Serializer for --logger_levels flag."""
def serialize(self, value):
if isinstance(value, str):
return value
return ','.join(
'{}:{}'.format(name, level) for name, level in value.items())
class _StderrthresholdFlag(flags.Flag):
  """Flag class for --stderrthreshold.

  The stored value is always one of the lowercase level names 'debug',
  'info', 'warning', 'error' or 'fatal'.
  """

  def __init__(self, *args, **kwargs):
    super(_StderrthresholdFlag, self).__init__(
        flags.ArgumentParser(),
        flags.ArgumentSerializer(),
        *args, **kwargs)

  @property
  def value(self):
    return self._value

  @value.setter
  def value(self, v):
    """Validates and normalizes a new threshold.

    Args:
      v: A case-insensitive level name ('debug', 'info', 'warning' or
        'warn', 'error', 'fatal'), or an Abseil C++ log level given as a
        string ('0'-'3').

    Raises:
      ValueError: If v is none of the accepted spellings, including when it
        is not a string at all.
    """
    name = None
    if isinstance(v, str):
      if v in _CPP_LEVEL_TO_NAMES:
        # --stderrthreshold also accepts numeric strings whose values are
        # Abseil C++ log levels; store the matching level name.
        name = _CPP_LEVEL_TO_NAMES[v]
      elif v.lower() in _CPP_NAME_TO_LEVELS:
        name = v.lower()
        if name == 'warn':
          name = 'warning'  # Use 'warning' as the canonical name.
    if name is None:
      # `(v,)` keeps the message intact even if v is itself a tuple.
      raise ValueError(
          '--stderrthreshold must be one of (case-insensitive) '
          "'debug', 'info', 'warning', 'error', 'fatal', "
          "or '0', '1', '2', '3', not '%s'" % (v,))
    self._value = name
# Flags registered with the global flag registry when this module is
# imported. Each DEFINE_* result is kept in an upper-case module constant.

# Per its help text: log only to stderr.
LOGTOSTDERR = flags.DEFINE_boolean(
    'logtostderr',
    False,
    'Should only log to stderr?',
    allow_override_cpp=True,
)
# Per its help text: also log to stderr.
ALSOLOGTOSTDERR = flags.DEFINE_boolean(
    'alsologtostderr',
    False,
    'also log to stderr?',
    allow_override_cpp=True,
)
# Defaults to $TEST_TMPDIR when set, otherwise to the empty string.
LOG_DIR = flags.DEFINE_string(
    'log_dir',
    os.getenv('TEST_TMPDIR', ''),
    'directory to write logfiles into',
    allow_override_cpp=True,
)
# -v/--verbosity: integer parsed by _VerbosityFlag, whose setter updates the
# logger levels on every assignment.
VERBOSITY = flags.DEFINE_flag(
    _VerbosityFlag(
        'verbosity',
        -1,
        (
            'Logging verbosity level. Messages logged at this level or lower'
            ' will be included. Set to 1 for debug logging. If the flag was not'
            ' set or supplied, the value will be changed from the default of -1'
            ' (warning) to 0 (info) after flags are parsed.'
        ),
        short_name='v',
        allow_hide_cpp=True,
    )
)
# --logger_levels: `name:level` pairs parsed by _LoggerLevelsParser and
# applied to the named loggers by _LoggerLevelsFlag.
LOGGER_LEVELS = flags.DEFINE_flag(
    _LoggerLevelsFlag(
        'logger_levels',
        {},
        (
            'Specify log level of loggers. The format is a CSV list of '
            '`name:level`. Where `name` is the logger name used with '
            '`logging.getLogger()`, and `level` is a level name (INFO, DEBUG, '
            'etc). e.g. `myapp.foo:INFO,other.logger:DEBUG`'
        ),
    )
)
# --stderrthreshold: level name or C++ numeric level, normalized by
# _StderrthresholdFlag.
STDERRTHRESHOLD = flags.DEFINE_flag(
    _StderrthresholdFlag(
        'stderrthreshold',
        'fatal',
        (
            'log messages at this level, or more severe, to stderr in '
            'addition to the logfile. Possible values are '
            "'debug', 'info', 'warning', 'error', and 'fatal'. "
            'Obsoletes --alsologtostderr. Using --alsologtostderr '
            'cancels the effect of this flag. Please also note that '
            'this flag is subject to --verbosity and requires logfile '
            'not be stderr.'
        ),
        allow_hide_cpp=True,
    )
)
SHOWPREFIXFORINFO = flags.DEFINE_boolean(
    'showprefixforinfo',
    True,
    (
        'If False, do not prepend prefix to info messages '
        "when it's logged to stderr, "
        '--verbosity is set to INFO level, '
        'and python logging is used.'
    ),
)
def get_verbosity():
  """Returns the current (integer) value of the --verbosity flag."""
  verbosity_flag = FLAGS['verbosity']
  return verbosity_flag.value
def set_verbosity(v):
  """Sets the logging verbosity.

  Afterwards, messages with level <= v are logged and messages with
  level > v are silently discarded.

  Args:
    v: int|str, the verbosity level. Anything int() accepts is used as a
      number. Otherwise it must be a case-insensitive level name: 'debug',
      'info', 'warning', 'error' or 'fatal'.
  """
  try:
    verbosity = int(v)
  except ValueError:
    # Not numeric: treat it as a level name.
    verbosity = converter.ABSL_NAMES[v.upper()]
  FLAGS.verbosity = verbosity
def set_stderrthreshold(s):
  """Sets the stderr threshold to the value passed in.

  Args:
    s: str|int. Integer values are absl levels
      (logging.DEBUG|INFO|WARNING|ERROR|FATAL) and are translated to their
      level name. String values are the case-insensitive names 'debug',
      'info', 'warning', 'error' and 'fatal'.

  Raises:
    ValueError: Raised when s is an invalid value.
  """
  if s in converter.ABSL_LEVELS:
    FLAGS.stderrthreshold = converter.ABSL_LEVELS[s]
    return
  if isinstance(s, str) and s.upper() in converter.ABSL_NAMES:
    FLAGS.stderrthreshold = s
    return
  raise ValueError(
      'set_stderrthreshold only accepts integer absl logging level '
      'from -3 to 1, or case-insensitive string values '
      "'debug', 'info', 'warning', 'error', and 'fatal'. "
      'But found "{}" ({}).'.format(s, type(s)))
def fatal(msg, *args, **kwargs):
  # type: (Any, Any, Any) -> NoReturn
  """Logs a fatal message.

  Forwards to log() at FATAL level. Per the module docstring, unlike the
  standard library's fatal this is meant to terminate the program.
  """
  log(FATAL, msg, *args, **kwargs)
def error(msg, *args, **kwargs):
  """Logs an error message (``msg % args`` at ERROR level via log())."""
  log(ERROR, msg, *args, **kwargs)
def warning(msg, *args, **kwargs):
  """Logs a warning message (``msg % args`` at WARNING level via log())."""
  log(WARNING, msg, *args, **kwargs)
def warn(msg, *args, **kwargs):
  """Deprecated alias of warning(); also emits a DeprecationWarning.

  stacklevel=2 attributes the DeprecationWarning to the caller of warn().
  """
  warnings.warn(
      "The 'warn' function is deprecated, use 'warning' instead",
      category=DeprecationWarning,
      stacklevel=2)
  log(WARNING, msg, *args, **kwargs)
def info(msg, *args, **kwargs):
  """Logs an info message (``msg % args`` at INFO level via log())."""
  log(INFO, msg, *args, **kwargs)
def debug(msg, *args, **kwargs):
  """Logs a debug message (``msg % args`` at DEBUG level via log())."""
  log(DEBUG, msg, *args, **kwargs)
def exception(msg, *args, exc_info=True, **kwargs):
  """Logs ``msg % args`` at ERROR level with exception info attached.

  exc_info defaults to True. Pass a different value to override what is
  attached to the record.
  """
  # exc_info is a named parameter, so it can never already be in kwargs.
  kwargs['exc_info'] = exc_info
  error(msg, *args, **kwargs)
# Counter to keep track of number of log entries per token.
_log_counter_per_token = {}
def _get_next_log_count_per_token(token):
"""Wrapper for _log_counter_per_token. Thread-safe.
Args:
token: The token for which to look up the count.
Returns:
The number of times this function has been called with
*token* as an argument (starting at 0).
"""
# Can't use a defaultdict because defaultdict isn't atomic, whereas
# setdefault is.
return next(_log_counter_per_token.setdefault(token, itertools.count()))
def log_every_n(level, msg, n, *args):
  """Logs ``msg % args`` at level 'level' once per 'n' times.

  Logs the 1st call, (N+1)st call, (2N+1)st call, etc. Calls are counted per
  call site. Not threadsafe.

  Args:
    level: int, the absl logging level at which to log.
    msg: str, the message to be logged.
    n: int, the number of times this should be called before it is logged.
    *args: The args to be substituted into the msg.
  """
  call_site = get_absl_logger().findCaller()
  count = _get_next_log_count_per_token(call_site)
  log_if(level, msg, count % n == 0, *args)
# Keeps track of the last log time of the given token.
# Note: must be a dict since set/get is atomic in CPython.
# Note: entries are never released as their number is expected to be low.
_log_timer_per_token = {}
def _seconds_have_elapsed(token, num_seconds):
"""Tests if 'num_seconds' have passed since 'token' was requested.
Not strictly thread-safe - may log with the wrong frequency if called
concurrently from multiple threads. Accuracy depends on resolution of
'timeit.default_timer()'.
Always returns True on the first call for a given 'token'.
Args:
token: The token for which to look up the count.
num_seconds: The number of seconds to test for.
Returns:
Whether it has been >= 'num_seconds' since 'token' was last requested.
"""
now = timeit.default_timer()
then = _log_timer_per_token.get(token, None)
if then is None or (now - then) >= num_seconds:
_log_timer_per_token[token] = now
return True
else:
return False
def log_every_n_seconds(level, msg, n_seconds, *args):
  """Logs ``msg % args`` at level ``level`` iff ``n_seconds`` elapsed since last call.

  The first call logs. Later calls from the same call site (file + line) log
  only if 'n_seconds' have passed since that site last logged. Not
  thread-safe.

  Args:
    level: int, the absl logging level at which to log.
    msg: str, the message to be logged.
    n_seconds: float or int, seconds which should elapse before logging again.
    *args: The args to be substituted into the msg.
  """
  call_site = get_absl_logger().findCaller()
  log_if(level, msg, _seconds_have_elapsed(call_site, n_seconds), *args)
def log_first_n(level, msg, n, *args):
  """Logs ``msg % args`` at ``level`` for only the first ``n`` calls.

  Calls are counted per token returned by the absl logger's ``findCaller()``,
  i.e. per call site. Not thread-safe.

  Args:
    level: int, the absl logging level at which to log.
    msg: str, the message to be logged.
    n: int, the maximal number of times the message is logged.
    *args: The args to be substituted into the msg.
  """
  call_site = get_absl_logger().findCaller()
  times_seen = _get_next_log_count_per_token(call_site)
  log_if(level, msg, times_seen < n, *args)
def log_if(level, msg, condition, *args):
  """Logs ``msg % args`` at level ``level`` only when ``condition`` is truthy."""
  if not condition:
    return
  log(level, msg, *args)
def log(level, msg, *args, **kwargs):
  """Logs ``msg % args`` at absl logging level ``level``.

  With no ``args`` the message is emitted verbatim; any interpolation
  specifiers in it are left untouched.

  Args:
    level: int, the absl logging level at which to log the message
      (logging.DEBUG|INFO|WARNING|ERROR|FATAL). Values above DEBUG are
      accepted and treated as C++ verbose (vlog) levels, but callers should
      prefer explicit logging.vlog() calls for that purpose. Values below
      FATAL are clamped to FATAL.
    msg: str, the message to be logged.
    *args: The args to be substituted into the msg.
    **kwargs: May contain exc_info to add exception traceback to message.
  """
  if level > converter.ABSL_DEBUG:
    # vlog level 1 corresponds to standard DEBUG; each higher vlog level maps
    # one step further below it.
    standard_level = converter.STANDARD_DEBUG - (level - 1)
  else:
    standard_level = converter.absl_to_standard(
        max(level, converter.ABSL_FATAL))
  # Like the standard library's module-level helpers (logging.info() etc.),
  # install a default handler via basicConfig() when the root logger has
  # none yet (e.g. before use_absl_handler()), so the record is not dropped.
  if not logging.root.handlers:
    logging.basicConfig()
  _absl_logger.log(standard_level, msg, *args, **kwargs)
def vlog(level, msg, *args, **kwargs):
  """Logs ``msg % args`` at C++ vlog level ``level``.

  This forwards to :func:`log`, which interprets levels above DEBUG as vlog
  levels.

  Args:
    level: int, the C++ verbose logging level at which to log the message,
      e.g. 1, 2, 3, 4... While absl level constants are also supported,
      callers should prefer logging.log|debug|info|... calls for such purpose.
    msg: str, the message to be logged.
    *args: The args to be substituted into the msg.
    **kwargs: May contain exc_info to add exception traceback to message.
  """
  return log(level, msg, *args, **kwargs)
def vlog_is_on(level):
  """Tells whether the absl logger would emit a record at ``level``.

  ``level`` is converted exactly as :func:`log` converts it: values above
  DEBUG are vlog levels, values below FATAL are clamped to FATAL.

  Args:
    level: int, the C++ verbose logging level at which to log the message,
      e.g. 1, 2, 3, 4... While absl level constants are also supported,
      callers should prefer level_debug|level_info|... calls for
      checking those.

  Returns:
    True if logging is turned on for that level.
  """
  if level > converter.ABSL_DEBUG:
    # vlog level 1 corresponds to standard DEBUG.
    standard_level = converter.STANDARD_DEBUG - (level - 1)
  else:
    clamped = max(level, converter.ABSL_FATAL)
    standard_level = converter.absl_to_standard(clamped)
  return _absl_logger.isEnabledFor(standard_level)
def flush():
  """Flushes all log files by flushing the absl handler."""
  handler = get_absl_handler()
  handler.flush()
def level_debug():
  """Returns True if the current verbosity lets DEBUG records through."""
  return DEBUG <= get_verbosity()


def level_info():
  """Returns True if the current verbosity lets INFO records through."""
  return INFO <= get_verbosity()


def level_warning():
  """Returns True if the current verbosity lets WARNING records through."""
  return WARNING <= get_verbosity()


level_warn = level_warning  # Deprecated function.


def level_error():
  """Returns True if the current verbosity lets ERROR records through."""
  return ERROR <= get_verbosity()
def get_log_file_name(level=INFO):
  """Returns the name of the file the absl Python handler writes to.

  Python logging uses a single file for all levels, so ``level`` is only
  validated, not used.

  Args:
    level: int, the absl.logging level.

  Returns:
    The stream's ``name`` attribute, or ``''`` when logging goes to
    ``sys.stderr``/``sys.stdout`` or the stream has no ``name`` attribute.

  Raises:
    ValueError: Raised when `level` has an invalid value.
  """
  if level not in converter.ABSL_LEVELS:
    raise ValueError('Invalid absl.logging level {}'.format(level))
  stream = get_absl_handler().python_handler.stream
  is_std_stream = stream == sys.stderr or stream == sys.stdout
  if is_std_stream or not hasattr(stream, 'name'):
    return ''
  return stream.name
def find_log_dir_and_names(program_name=None, log_dir=None):
  """Computes the directory and filename prefix for log file.

  Args:
    program_name: str|None, the filename part of the path to the program that
      is running without its extension. e.g: if your program is called
      ``usr/bin/foobar.py`` this method should probably be called with
      ``program_name='foobar'``. However, this is just a convention, you can
      pass in any string you want, and it will be used as part of the
      log filename. If you don't pass in anything, the default behavior
      is as described in the example, and the derived name is additionally
      prefixed with ``py_``.
    log_dir: str|None, the desired log directory.

  Returns:
    (log_dir, file_prefix, symlink_prefix), where ``file_prefix`` is
    ``<program_name>.<hostname>.<username>.log`` and ``symlink_prefix`` is the
    program name.

  Raises:
    FileNotFoundError: raised when it cannot find a log directory.
  """
  if not program_name:
    # Strip the extension (foobar.par becomes foobar, and
    # fubar.py becomes fubar). We do this so that the log
    # file names are similar to C++ log file names.
    program_name = os.path.splitext(os.path.basename(sys.argv[0]))[0]

    # Prepend py_ to files so that python code gets a unique file, and
    # so that C++ libraries do not try to write to the same log files as us.
    program_name = 'py_%s' % program_name

  actual_log_dir = find_log_dir(log_dir=log_dir)

  try:
    username = getpass.getuser()
  except (KeyError, ImportError, OSError):
    # getpass.getuser() fails when no environment variable names the user and
    # the password database cannot be used (e.g. docker without a passwd
    # file, or no pwd module on Windows). Python 3.13+ raises OSError for
    # this; older versions leak KeyError (missing pwd entry) or ImportError
    # (no pwd module).
    if hasattr(os, 'getuid'):
      # Windows doesn't have os.getuid
      username = str(os.getuid())
    else:
      username = 'unknown'
  hostname = socket.gethostname()
  file_prefix = '%s.%s.%s.log' % (program_name, hostname, username)

  return actual_log_dir, file_prefix, program_name
def find_log_dir(log_dir=None):
  """Picks the directory that log files should be written into.

  Exactly one candidate is considered, chosen by precedence:

  1. ``log_dir``, if it is non-empty;
  2. otherwise the ``--log_dir`` flag, if it is set;
  3. otherwise ``tempfile.gettempdir()``.

  Args:
    log_dir: str|None, an explicit directory that overrides both the flag and
      the default location.

  Returns:
    The chosen directory, if it exists and is writable.

  Raises:
    FileNotFoundError: the chosen directory is missing or not writable.
  """
  # Google's internal implementation can try several directories, which is
  # why a list of candidates is kept even though it holds one entry here.
  if log_dir:
    candidates = [log_dir]
  elif FLAGS['log_dir'].value:
    # Mimics the --log_dir behavior of the C++ logging library.
    candidates = [FLAGS['log_dir'].value]
  else:
    candidates = [tempfile.gettempdir()]

  usable = next(
      (d for d in candidates if os.path.isdir(d) and os.access(d, os.W_OK)),
      None)
  if usable is None:
    raise FileNotFoundError(
        "Can't find a writable directory for logs, tried %s" % candidates)
  return usable
def get_absl_log_prefix(record):
  """Builds the C++-style line prefix for ``record``.

  The prefix looks like ``Lmmdd hh:mm:ss.uuuuuu TTTTT file:line] `` where
  ``L`` is the severity initial and ``TTTTT`` the thread id. A record at
  FATAL level or above that was not produced by absl is shown with ERROR
  severity and ``_CRITICAL_PREFIX`` appended.

  Args:
    record: logging.LogRecord, the record to get prefix for.

  Returns:
    The prefix string.
  """
  if _is_non_absl_fatal_record(record):
    # Only absl's own FATAL records are treated as truly fatal; others are
    # downgraded to ERROR but keep a visible critical marker.
    severity_level, marker = logging.ERROR, _CRITICAL_PREFIX
  else:
    severity_level, marker = record.levelno, ''
  when = time.localtime(record.created)
  micros = int(record.created % 1.0 * 1e6)
  return '%c%02d%02d %02d:%02d:%02d.%06d %5d %s:%d] %s' % (
      converter.get_initial_for_level(severity_level),
      when.tm_mon, when.tm_mday,
      when.tm_hour, when.tm_min, when.tm_sec,
      micros,
      _get_thread_id(),
      record.filename,
      record.lineno,
      marker)
def skip_log_prefix(func):
  """Skips reporting the prefix of a given function or name by :class:`~absl.logging.ABSLLogger`.

  Convenience wrapper (usable as a decorator) around
  :meth:`~absl.logging.ABSLLogger.register_frame_to_skip`:

  * a callable registers exactly that function (file, name and first line);
  * a string registers every function with that name in the file from which
    ``skip_log_prefix`` is called.

  Args:
    func: Callable function or its name as a string.

  Returns:
    func (the input, unchanged).

  Raises:
    ValueError: The input is callable but does not have a function code object.
    TypeError: The input is neither callable nor a string.
  """
  if callable(func):
    code = getattr(func, '__code__', None)
    if code is None:
      raise ValueError('Input callable does not have a function code object.')
    target = (code.co_filename, code.co_name, code.co_firstlineno)
  elif isinstance(func, str):
    # No line number: every same-named function in the caller's file matches.
    target = (get_absl_logger().findCaller()[0], func, None)
  else:
    raise TypeError('Input is neither callable nor a string.')
  ABSLLogger.register_frame_to_skip(*target)
  return func
def _is_non_absl_fatal_record(log_record):
  """Whether ``log_record`` is FATAL-or-worse but lacks absl's fatal marker."""
  if log_record.levelno < logging.FATAL:
    return False
  return not log_record.__dict__.get(_ABSL_LOG_FATAL, False)
def _is_absl_fatal_record(log_record):
  """Whether ``log_record`` is FATAL-or-worse and carries absl's fatal marker.

  Returns False for lower levels, otherwise the stored marker value
  (False when absent).
  """
  if log_record.levelno < logging.FATAL:
    return False
  return log_record.__dict__.get(_ABSL_LOG_FATAL, False)
# Whether the one-time "Logging before flag parsing goes to stderr." warning
# still has to be written. PythonHandler.emit() writes it on the first record
# emitted before flags are parsed and then sets this to False.
_warn_preinit_stderr = True
class PythonHandler(logging.StreamHandler):
  """The handler class used by Abseil Python logging implementation."""

  def __init__(self, stream=None, formatter=None):
    super().__init__(stream)
    if not formatter:
      formatter = PythonFormatter()
    self.setFormatter(formatter)

  def start_logging_to_file(self, program_name=None, log_dir=None):
    """Switches this handler from standard error to a newly opened log file.

    Sets ``FLAGS.logtostderr`` to False and opens
    ``<log_dir>/<file_prefix>.INFO.<YYYYmmdd-HHMMSS>.<pid>`` for appending in
    UTF-8. Where ``os.symlink`` exists, ``<log_dir>/<symlink_prefix>.INFO`` is
    (re)pointed at the new file on a best-effort basis.

    Args:
      program_name: str|None, forwarded to find_log_dir_and_names().
      log_dir: str|None, forwarded to find_log_dir_and_names().
    """
    FLAGS.logtostderr = False

    target_dir, file_prefix, link_prefix = find_log_dir_and_names(
        program_name=program_name, log_dir=log_dir)
    timestamp = time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time()))
    basename = '%s.INFO.%s.%d' % (file_prefix, timestamp, os.getpid())
    filename = os.path.join(target_dir, basename)
    self.stream = open(filename, 'a', encoding='utf-8')

    if not getattr(os, 'symlink', None):
      return
    # Keep a stable, canonical name pointing at the newest log file.
    link = os.path.join(target_dir, link_prefix + '.INFO')
    try:
      if os.path.islink(link):
        os.unlink(link)
      os.symlink(os.path.basename(filename), link)
    except OSError:
      # Best effort only: this commonly fails because the existing link was
      # created by another user and cannot be replaced.
      pass

  def use_absl_log_file(self, program_name=None, log_dir=None):
    """Writes to stderr when ``--logtostderr`` is set, else to a log file."""
    if not FLAGS['logtostderr'].value:
      self.start_logging_to_file(program_name=program_name, log_dir=log_dir)
      return
    self.stream = sys.stderr

  def flush(self):
    """Flushes the current stream, ignoring closed or broken streams."""
    self.acquire()
    try:
      stream = self.stream
      if stream and hasattr(stream, 'flush'):
        stream.flush()
    except (OSError, ValueError):
      # ValueError: the stream has already been closed.
      pass
    finally:
      self.release()

  def _log_to_stderr(self, record):
    """Emits ``record`` to ``sys.stderr`` and then restores the original stream.

    Args:
      record: logging.LogRecord, the record to log.
    """
    # This runs from emit(), and logging.Handler.handle() already holds the
    # handler lock around emit(), so no extra locking is done here.
    saved_stream, self.stream = self.stream, sys.stderr
    try:
      super().emit(record)
    finally:
      self.stream = saved_stream

  def emit(self, record):
    """Writes ``record`` to the appropriate stream(s).

    * Before flags are parsed, the record goes to ``sys.stderr`` only,
      preceded (once per process) by a warning that this is happening.
    * If ``FLAGS.logtostderr`` is set, it goes to ``sys.stderr`` only.
    * Otherwise it goes to this handler's stream, and is copied to
      ``sys.stderr`` when ``FLAGS.alsologtostderr`` is set or its level
      reaches ``FLAGS.stderrthreshold`` (unless the stream already is stderr).

    A record produced by ABSLLogger at FATAL level flushes the handler and
    then aborts the process.

    Args:
      record: :class:`logging.LogRecord`, the record to emit.
    """
    global _warn_preinit_stderr
    level = record.levelno

    if not FLAGS.is_parsed():  # Also implies "before flag has been defined".
      # Logging can happen at import time, before our flags are defined or
      # parsed. Like the C++ library, send such records to stderr with a
      # notice instead of hiding them.
      if _warn_preinit_stderr:
        sys.stderr.write(
            'WARNING: Logging before flag parsing goes to stderr.\n')
        _warn_preinit_stderr = False
      self._log_to_stderr(record)
    elif FLAGS['logtostderr'].value:
      self._log_to_stderr(record)
    else:
      super().emit(record)
      threshold = converter.string_to_standard(
          FLAGS['stderrthreshold'].value)
      copy_to_stderr = FLAGS['alsologtostderr'].value or level >= threshold
      if copy_to_stderr and self.stream != sys.stderr:
        self._log_to_stderr(record)

    if _is_absl_fatal_record(record):
      self.flush()  # Flush the log before dying.
      # os.abort() rather than sys.exit(): in threaded Python, sys.exit()
      # from a non-main thread only exits that thread.
      os.abort()

  def close(self):
    """Flushes, then closes the stream unless it is a standard stream or TTY."""
    self.acquire()
    try:
      self.flush()
      try:
        # Standard streams may be redirected or overridden by users, who
        # stay responsible for closing them.
        protected = (sys.stderr, sys.stdout, sys.__stderr__, sys.__stdout__)
        stream = self.stream
        if stream not in protected and not (
            hasattr(stream, 'isatty') and stream.isatty()):
          stream.close()
      except ValueError:
        # isatty() raises ValueError on an already-closed file.
        pass
      super().close()
    finally:
      self.release()
class ABSLHandler(logging.Handler):
  """Abseil Python logging module's log handler.

  Every operation is delegated to the active handler, which
  activate_python_handler() sets to the PythonHandler built in __init__.
  """

  def __init__(self, python_logging_formatter):
    super().__init__()
    self._python_handler = PythonHandler(formatter=python_logging_formatter)
    self.activate_python_handler()

  def format(self, record):
    """Formats ``record`` with the active handler's formatter."""
    return self._current_handler.format(record)

  def setFormatter(self, fmt):
    """Installs ``fmt`` on the active handler (not on this wrapper)."""
    self._current_handler.setFormatter(fmt)

  def emit(self, record):
    """Emits ``record`` through the active handler."""
    self._current_handler.emit(record)

  def flush(self):
    """Flushes the active handler."""
    self._current_handler.flush()

  def close(self):
    """Closes this wrapper first, then the active handler."""
    super().close()
    self._current_handler.close()

  def handle(self, record):
    """Applies this handler's filters, then lets the active handler handle it."""
    passed = self.filter(record)
    return self._current_handler.handle(record) if passed else passed

  @property
  def python_handler(self):
    """The PythonHandler created at construction time."""
    return self._python_handler

  def activate_python_handler(self):
    """Uses the Python logging handler as the current logging handler."""
    self._current_handler = self._python_handler

  def use_absl_log_file(self, program_name=None, log_dir=None):
    """Forwards to the active handler's ``use_absl_log_file``."""
    self._current_handler.use_absl_log_file(program_name, log_dir)

  def start_logging_to_file(self, program_name=None, log_dir=None):
    """Forwards to the active handler's ``start_logging_to_file``."""
    self._current_handler.start_logging_to_file(program_name, log_dir)
class PythonFormatter(logging.Formatter):
  """Formatter class used by :class:`~absl.logging.PythonHandler`."""

  def format(self, record):
    """Prepends the absl prefix to the standard formatting of ``record``.

    The prefix is left out only for INFO records while
    ``--showprefixforinfo`` is off, verbosity is exactly INFO, and the absl
    Python handler is writing to ``sys.stderr``.

    Args:
      record: logging.LogRecord, the record to be formatted.

    Returns:
      The formatted string representing the record.
    """
    omit_prefix = (
        not FLAGS['showprefixforinfo'].value
        and FLAGS['verbosity'].value == converter.ABSL_INFO
        and record.levelno == logging.INFO
        and _absl_handler.python_handler.stream == sys.stderr)
    prefix = '' if omit_prefix else get_absl_log_prefix(record)
    return prefix + super().format(record)
class ABSLLogger(logging.getLoggerClass()):
  """A logger that will create LogRecords while skipping some stack frames.

  This class maintains an internal list of filenames and method names
  for use when determining who called the currently executing stack
  frame. Any method names from specific source files are skipped when
  walking backwards through the stack.

  Client code should use the register_frame_to_skip method to let the
  ABSLLogger know which method from which file should be
  excluded from the walk backwards through the stack.
  """

  # Holds (file_name, function_name) and
  # (file_name, function_name, first_line_number) tuples.
  _frames_to_skip = set()

  def findCaller(self, stack_info=False, stacklevel=1):
    """Finds the frame of the calling method on the stack.

    This method walks backwards from the caller and skips frames from the
    absl logging implementation itself (file names containing
    ``_LOGGING_FILE_PREFIX``) as well as any frames registered with
    :meth:`register_frame_to_skip`. It then returns the file name, line
    number, method name and optional stack info of the first frame left.

    Args:
      stack_info: bool, when True, the fourth item is the formatted stack
        trace starting at the caller's frame; otherwise it is None.
      stacklevel: int, accepted for signature compatibility with
        :meth:`logging.Logger.findCaller`; frame selection here is driven by
        the skip rules above instead.

    Returns:
      (filename, lineno, methodname, sinfo) of the calling method. If no frame
      qualifies, the standard library's placeholder
      ``('(unknown file)', 0, '(unknown function)', None)`` is returned.
    """
    f_to_skip = ABSLLogger._frames_to_skip
    try:
      # Use sys._getframe(2) instead of logging.currentframe(), it's slightly
      # faster because there is one less frame to traverse.
      frame = sys._getframe(2)  # pylint: disable=protected-access
    except ValueError:
      # The call stack is not deep enough; there is no caller to report.
      frame = None

    while frame:
      code = frame.f_code
      if (_LOGGING_FILE_PREFIX not in code.co_filename and
          (code.co_filename, code.co_name,
           code.co_firstlineno) not in f_to_skip and
          (code.co_filename, code.co_name) not in f_to_skip):
        sinfo = None
        if stack_info:
          out = io.StringIO()
          out.write('Stack (most recent call last):\n')
          traceback.print_stack(frame, file=out)
          sinfo = out.getvalue().rstrip('\n')
        return (code.co_filename, frame.f_lineno, code.co_name, sinfo)
      frame = frame.f_back

    # Every frame was skipped. Return the same placeholder as
    # logging.Logger.findCaller instead of falling off the end with None,
    # which breaks callers that unpack four values (e.g. Logger._log).
    return ('(unknown file)', 0, '(unknown function)', None)

  def critical(self, msg, *args, **kwargs):
    """Logs ``msg % args`` with severity ``CRITICAL``."""
    self.log(logging.CRITICAL, msg, *args, **kwargs)

  def fatal(self, msg, *args, **kwargs):
    """Logs ``msg % args`` with severity ``FATAL``."""
    self.log(logging.FATAL, msg, *args, **kwargs)

  def error(self, msg, *args, **kwargs):
    """Logs ``msg % args`` with severity ``ERROR``."""
    self.log(logging.ERROR, msg, *args, **kwargs)

  def warn(self, msg, *args, **kwargs):
    """Logs ``msg % args`` with severity ``WARN``."""
    warnings.warn("The 'warn' method is deprecated, use 'warning' instead",
                  DeprecationWarning, 2)
    self.log(logging.WARN, msg, *args, **kwargs)

  def warning(self, msg, *args, **kwargs):
    """Logs ``msg % args`` with severity ``WARNING``."""
    self.log(logging.WARNING, msg, *args, **kwargs)

  def info(self, msg, *args, **kwargs):
    """Logs ``msg % args`` with severity ``INFO``."""
    self.log(logging.INFO, msg, *args, **kwargs)

  def debug(self, msg, *args, **kwargs):
    """Logs ``msg % args`` with severity ``DEBUG``."""
    self.log(logging.DEBUG, msg, *args, **kwargs)

  def log(self, level, msg, *args, **kwargs):
    """Logs a message at a certain level substituting in the supplied arguments.

    Records at ``logging.FATAL`` or above get ``_ABSL_LOG_FATAL`` set in their
    ``extra`` attributes, so the absl handler treats them as really FATAL.

    Args:
      level: int, the standard logging level at which to log the message.
      msg: str, the text of the message to log.
      *args: The arguments to substitute in the message.
      **kwargs: Keyword arguments forwarded to :meth:`logging.Logger.log`
        (e.g. ``exc_info``, ``extra``, ``stack_info``).
    """
    if level >= logging.FATAL:
      # Build a fresh ``extra`` rather than mutating the caller's mapping,
      # which may be reused for later records; ``extra=None`` is also valid.
      extra = dict(kwargs.get('extra') or {})
      extra[_ABSL_LOG_FATAL] = True
      kwargs['extra'] = extra
    super(ABSLLogger, self).log(level, msg, *args, **kwargs)

  def handle(self, record):
    """Calls handlers without checking ``Logger.disabled``.

    Non-root loggers are set to disabled after setup with :func:`logging.config`
    if it's not explicitly specified. Historically, absl logging will not be
    disabled by that. To maintain this behavior, this function skips
    checking the ``Logger.disabled`` bit.

    This logger can still be disabled by adding a filter that filters out
    everything.

    Args:
      record: logging.LogRecord, the record to handle.
    """
    if self.filter(record):
      self.callHandlers(record)

  @classmethod
  def register_frame_to_skip(cls, file_name, function_name, line_number=None):
    """Registers a function name to skip when walking the stack.

    The :class:`~absl.logging.ABSLLogger` sometimes skips method calls on the
    stack to make the log messages meaningful in their appropriate context.
    This method registers a function from a particular file as one
    which should be skipped.

    Args:
      file_name: str, the name of the file that contains the function.
      function_name: str, the name of the function to skip.
      line_number: int, if provided, only the function with this starting line
        number will be skipped. Otherwise, all functions with the same name
        in the file will be skipped.
    """
    if line_number is not None:
      cls._frames_to_skip.add((file_name, function_name, line_number))
    else:
      cls._frames_to_skip.add((file_name, function_name))
def _get_thread_id():
  """Returns the current thread's id, masked so it can be logged as unsigned.

  The value of :func:`threading.get_ident` is combined with
  ``_THREAD_ID_MASK`` using bitwise AND.

  Returns:
    Thread ID unique to this process (unsigned).
  """
  return threading.get_ident() & _THREAD_ID_MASK
def get_absl_logger():
  """Returns the absl logger instance created by ``_initialize()``."""
  logger = _absl_logger
  assert logger is not None
  return logger
def get_absl_handler():
  """Returns the absl handler instance created by ``_initialize()``."""
  handler = _absl_handler
  assert handler is not None
  return handler
def use_python_logging(quiet=False):
  """Routes absl records through the handler's Python logging implementation.

  Args:
    quiet: If True, do not log a message announcing the switch.
  """
  get_absl_handler().activate_python_handler()
  if quiet:
    return
  info('Restoring pure python logging')
# Set by use_absl_handler() after its first attempt to remove pre-existing
# stderr StreamHandlers from the root logger, so the cleanup runs only once.
_attempted_to_remove_stderr_stream_handlers = False


def use_absl_handler():
  """Installs the ABSL logging handler on the root logger.

  :func:`app.run()` calls this so that absl apps use the absl handler. On the
  first call, root handlers that stream to ``sys.stderr`` are removed first,
  because the absl handler already logs to stderr by default. Afterwards the
  verbosity and per-logger level flags are re-applied.
  """
  global _attempted_to_remove_stderr_stream_handlers
  if not _attempted_to_remove_stderr_stream_handlers:
    # Such handlers most commonly appear when logging.info/debug is called
    # before use_absl_handler(); keeping them would double-log to stderr.
    root = logging.root
    stderr_handlers = [
        h for h in root.handlers
        if isinstance(h, logging.StreamHandler) and h.stream == sys.stderr
    ]
    for h in stderr_handlers:
      root.removeHandler(h)
    _attempted_to_remove_stderr_stream_handlers = True

  handler = get_absl_handler()
  if handler not in logging.root.handlers:
    logging.root.addHandler(handler)
  FLAGS['verbosity']._update_logging_levels()  # pylint: disable=protected-access
  FLAGS['logger_levels']._update_logger_levels()  # pylint: disable=protected-access
def _initialize():
  """Creates the module-wide absl logger and handler; later calls do nothing."""
  global _absl_logger, _absl_handler

  if _absl_logger:
    return

  # The logging module instantiates new loggers with the currently registered
  # logger class, so ABSLLogger is registered just long enough to create the
  # 'absl' logger and the previous class is restored afterwards.
  saved_class = logging.getLoggerClass()
  logging.setLoggerClass(ABSLLogger)
  _absl_logger = logging.getLogger('absl')
  logging.setLoggerClass(saved_class)

  _absl_handler = ABSLHandler(PythonFormatter())


_initialize()
abseil-py-2.1.0/absl/logging/__init__.pyi 0000664 0000000 0000000 00000013242 14551576331 0020244 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, NoReturn, Optional, Tuple, TypeVar, Union
from absl import flags
# Logging levels, using absl numbering (see absl.logging.converter, where
# ABSL_FATAL..ABSL_DEBUG back these constants).
FATAL: int
ERROR: int
WARNING: int
WARN: int  # Deprecated alias of WARNING.
INFO: int
DEBUG: int

ABSL_LOGGING_PREFIX_REGEX: str

# Holders for the command-line flags defined by absl.logging.
LOGTOSTDERR: flags.FlagHolder[bool]
ALSOLOGTOSTDERR: flags.FlagHolder[bool]
LOG_DIR: flags.FlagHolder[str]
VERBOSITY: flags.FlagHolder[int]
LOGGER_LEVELS: flags.FlagHolder[Dict[str, str]]
STDERRTHRESHOLD: flags.FlagHolder[str]
SHOWPREFIXFORINFO: flags.FlagHolder[bool]
# Verbosity and stderr threshold accessors.
def get_verbosity() -> int: ...
def set_verbosity(v: Union[int, str]) -> None: ...
def set_stderrthreshold(s: Union[int, str]) -> None: ...
# TODO(b/277607978): Provide actual args+kwargs shadowing stdlib's logging functions.
def fatal(msg: Any, *args: Any, **kwargs: Any) -> NoReturn: ...
def error(msg: Any, *args: Any, **kwargs: Any) -> None: ...
def warning(msg: Any, *args: Any, **kwargs: Any) -> None: ...
def warn(msg: Any, *args: Any, **kwargs: Any) -> None: ...
def info(msg: Any, *args: Any, **kwargs: Any) -> None: ...
def debug(msg: Any, *args: Any, **kwargs: Any) -> None: ...
def exception(msg: Any, *args: Any, **kwargs: Any) -> None: ...
# Conditional and rate-limited logging helpers.
def log_every_n(level: int, msg: Any, n: int, *args: Any) -> None: ...
def log_every_n_seconds(
    level: int, msg: Any, n_seconds: float, *args: Any
) -> None: ...
def log_first_n(level: int, msg: Any, n: int, *args: Any) -> None: ...
def log_if(level: int, msg: Any, condition: Any, *args: Any) -> None: ...
def log(level: int, msg: Any, *args: Any, **kwargs: Any) -> None: ...
def vlog(level: int, msg: Any, *args: Any, **kwargs: Any) -> None: ...
def vlog_is_on(level: int) -> bool: ...
def flush() -> None: ...

# Verbosity predicates.
def level_debug() -> bool: ...
def level_info() -> bool: ...
def level_warning() -> bool: ...

level_warn = level_warning  # Deprecated function.

def level_error() -> bool: ...
# Log file location and prefix helpers.
def get_log_file_name(level: int = ...) -> str: ...
def find_log_dir_and_names(
    program_name: Optional[str] = ..., log_dir: Optional[str] = ...
) -> Tuple[str, str, str]: ...
def find_log_dir(log_dir: Optional[str] = ...) -> str: ...
def get_absl_log_prefix(record: logging.LogRecord) -> str: ...

_SkipLogT = TypeVar('_SkipLogT', str, Callable[..., Any])

def skip_log_prefix(func: _SkipLogT) -> _SkipLogT: ...
_StreamT = TypeVar("_StreamT")
class PythonHandler(logging.StreamHandler[_StreamT]):
def __init__(
self,
stream: Optional[_StreamT] = ...,
formatter: Optional[logging.Formatter] = ...,
) -> None:
...
def start_logging_to_file(
self, program_name: Optional[str] = ..., log_dir: Optional[str] = ...
) -> None:
...
def use_absl_log_file(
self, program_name: Optional[str] = ..., log_dir: Optional[str] = ...
) -> None:
...
def flush(self) -> None:
...
def emit(self, record: logging.LogRecord) -> None:
...
def close(self) -> None:
...
class ABSLHandler(logging.Handler):
  def __init__(self, python_logging_formatter: PythonFormatter) -> None:
    ...
  def format(self, record: logging.LogRecord) -> str:
    ...
  # Matches logging.Handler.setFormatter, which accepts None.
  def setFormatter(self, fmt: Optional[logging.Formatter]) -> None:
    ...
  def emit(self, record: logging.LogRecord) -> None:
    ...
  def flush(self) -> None:
    ...
  def close(self) -> None:
    ...
  def handle(self, record: logging.LogRecord) -> bool:
    ...
  @property
  def python_handler(self) -> PythonHandler:
    ...
  def activate_python_handler(self) -> None:
    ...
  def use_absl_log_file(
      self, program_name: Optional[str] = ..., log_dir: Optional[str] = ...
  ) -> None:
    ...
  # Same parameters as use_absl_log_file / PythonHandler.start_logging_to_file.
  def start_logging_to_file(
      self, program_name: Optional[str] = ..., log_dir: Optional[str] = ...
  ) -> None:
    ...
class PythonFormatter(logging.Formatter):
  def format(self, record: logging.LogRecord) -> str: ...
class ABSLLogger(logging.Logger):
  def findCaller(
      self, stack_info: bool = ..., stacklevel: int = ...
  ) -> Tuple[str, int, str, Optional[str]]: ...
  def critical(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
  def fatal(self, msg: Any, *args: Any, **kwargs: Any) -> NoReturn: ...
  def error(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
  def warn(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
  def warning(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
  def info(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
  def debug(self, msg: Any, *args: Any, **kwargs: Any) -> None: ...
  def log(self, level: int, msg: Any, *args: Any, **kwargs: Any) -> None: ...
  def handle(self, record: logging.LogRecord) -> None: ...
  @classmethod
  def register_frame_to_skip(
      cls, file_name: str, function_name: str, line_number: Optional[int] = ...
  ) -> None: ...
# NOTE: The implementation asserts the logger exists rather than returning
# None; it is created by _initialize(), which runs at import time.
def get_absl_logger() -> ABSLLogger: ...

# NOTE: The implementation asserts the handler exists rather than returning
# None; it is created by _initialize(), which runs at import time.
def get_absl_handler() -> ABSLHandler: ...

def use_python_logging(quiet: bool = ...) -> None: ...
def use_absl_handler() -> None: ...
abseil-py-2.1.0/absl/logging/converter.py 0000664 0000000 0000000 00000014321 14551576331 0020342 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module to convert log levels between Abseil Python, C++, and Python standard.
This converter has to convert (best effort) between three different
logging level schemes:
* **cpp**: The C++ logging level scheme used in Abseil C++.
* **absl**: The absl.logging level scheme used in Abseil Python.
* **standard**: The python standard library logging level scheme.
Here is a handy ascii chart for easy mental mapping::
LEVEL | cpp | absl | standard |
---------+-----+--------+----------+
DEBUG | 0 | 1 | 10 |
INFO | 0 | 0 | 20 |
WARNING | 1 | -1 | 30 |
ERROR | 2 | -2 | 40 |
CRITICAL | 3 | -3 | 50 |
FATAL | 3 | -3 | 50 |
Note: standard logging ``CRITICAL`` is mapped to absl/cpp ``FATAL``.
However, only ``CRITICAL`` logs from the absl logger (or absl.logging.fatal)
will terminate the program. ``CRITICAL`` logs from non-absl loggers are treated
as error logs with a message prefix ``"CRITICAL - "``.
Converting from standard to absl or cpp is a lossy conversion.
Converting back to standard will lose granularity. For this reason,
users should always try to convert to standard, the richest
representation, before manipulating the levels, and then only to cpp
or absl if those level schemes are absolutely necessary.
"""
import logging
# Python standard library logging levels.
STANDARD_CRITICAL = logging.CRITICAL
STANDARD_ERROR = logging.ERROR
STANDARD_WARNING = logging.WARNING
STANDARD_INFO = logging.INFO
STANDARD_DEBUG = logging.DEBUG

# These levels are also used to define the constants
# FATAL, ERROR, WARNING, INFO, and DEBUG in the
# absl.logging module.
ABSL_FATAL = -3
ABSL_ERROR = -2
ABSL_WARNING = -1
ABSL_WARN = -1  # Deprecated name.
ABSL_INFO = 0
ABSL_DEBUG = 1

# absl level -> canonical name.
ABSL_LEVELS = {
    ABSL_FATAL: 'FATAL',
    ABSL_ERROR: 'ERROR',
    ABSL_WARNING: 'WARNING',
    ABSL_INFO: 'INFO',
    ABSL_DEBUG: 'DEBUG',
}

# Name -> absl level: the inverse of ABSL_LEVELS, plus the deprecated 'WARN'.
ABSL_NAMES = {
    'FATAL': ABSL_FATAL,
    'ERROR': ABSL_ERROR,
    'WARNING': ABSL_WARNING,
    'WARN': ABSL_WARNING,  # Deprecated name.
    'INFO': ABSL_INFO,
    'DEBUG': ABSL_DEBUG,
}

# absl level -> standard logging level.
ABSL_TO_STANDARD = {
    ABSL_FATAL: STANDARD_CRITICAL,
    ABSL_ERROR: STANDARD_ERROR,
    ABSL_WARNING: STANDARD_WARNING,
    ABSL_INFO: STANDARD_INFO,
    ABSL_DEBUG: STANDARD_DEBUG,
}

# standard logging level -> absl level (the inverse of ABSL_TO_STANDARD).
STANDARD_TO_ABSL = {
    standard: absl for absl, standard in ABSL_TO_STANDARD.items()
}
def get_initial_for_level(level):
  """Returns the one-letter severity initial that starts a log line.

  The mapping is:

  * ``'I'`` when: ``level < STANDARD_WARNING``.
  * ``'W'`` when: ``STANDARD_WARNING <= level < STANDARD_ERROR``.
  * ``'E'`` when: ``STANDARD_ERROR <= level < STANDARD_CRITICAL``.
  * ``'F'`` when: ``level >= STANDARD_CRITICAL``.

  Args:
    level: int, a Python standard logging level.

  Returns:
    The first initial as it would be logged by the C++ logging module.
  """
  # Upper bounds in ascending order: the first bound that ``level`` falls
  # below selects the initial; anything at or above CRITICAL is 'F'.
  for upper_bound, initial in ((STANDARD_WARNING, 'I'),
                               (STANDARD_ERROR, 'W'),
                               (STANDARD_CRITICAL, 'E')):
    if level < upper_bound:
      return initial
  return 'F'
def absl_to_cpp(level):
  """Maps an absl.logging level onto the Abseil C++ level scheme.

  C++ log levels must be >= 0, so every absl level >= 0 (INFO, DEBUG and the
  vlog levels) collapses to 0. Negative absl levels (WARNING=-1, ERROR=-2,
  FATAL=-3) become their positive counterparts.

  Args:
    level: int, an absl.logging level.

  Raises:
    TypeError: Raised when level is not an integer.

  Returns:
    The corresponding integer level for use in Abseil C++.
  """
  if isinstance(level, int):
    return max(0, -level)
  raise TypeError('Expect an int level, found {}'.format(type(level)))
def absl_to_standard(level):
  """Translates an absl.logging level into a Python standard logging level.

  Levels below ABSL_FATAL are clamped to ABSL_FATAL. Levels from ABSL_FATAL
  through ABSL_DEBUG are looked up in ABSL_TO_STANDARD. Anything above
  ABSL_DEBUG is a vlog level and maps one standard level below DEBUG per step
  (absl 2 -> 9, absl 3 -> 8, ...).

  Args:
    level: int, an absl.logging level.

  Raises:
    TypeError: Raised when level is not an integer.

  Returns:
    The corresponding integer level for use in standard logging.
  """
  if not isinstance(level, int):
    raise TypeError('Expect an int level, found {}'.format(type(level)))
  clamped = max(level, ABSL_FATAL)
  if clamped > ABSL_DEBUG:
    # Maps to vlog levels.
    return STANDARD_DEBUG - clamped + 1
  return ABSL_TO_STANDARD[clamped]
class _UnknownLevelNameError(ValueError, TypeError):
  """Raised by string_to_standard() for an unrecognized level name.

  It is a ValueError because the argument has the right type but an invalid
  value. It also derives from TypeError because, previously, an unknown name
  reached absl_to_standard(None) and surfaced as a TypeError. Keeping that
  base class means existing ``except TypeError`` handlers keep working.
  """


def string_to_standard(level):
  """Converts a string level to standard logging level value.

  Args:
    level: str, case-insensitive ``'debug'``, ``'info'``, ``'warning'`` (or
      the deprecated ``'warn'``), ``'error'``, ``'fatal'``.

  Raises:
    ValueError: Raised when level is not one of the recognized names. The
      exception is also a TypeError, for backward compatibility.

  Returns:
    The corresponding integer level for use in standard logging.
  """
  absl_level = ABSL_NAMES.get(level.upper())
  if absl_level is None:
    # Fail with a message that names the bad input, instead of letting None
    # reach absl_to_standard() and produce a confusing NoneType TypeError.
    raise _UnknownLevelNameError(
        'Unknown level name {!r}; expected one of: {}'.format(
            level, ', '.join(ABSL_NAMES)))
  return absl_to_standard(absl_level)
def standard_to_absl(level):
  """Translates a Python standard logging level into an absl.logging level.

  Negative inputs are clamped to 0. Standard levels below DEBUG become absl
  vlog levels (9 -> 2, 8 -> 3, ..., 0 -> 11). Every other level maps to the
  absl level of the standard band it falls in, e.g. anything in
  [STANDARD_INFO, STANDARD_WARNING) becomes ABSL_INFO.

  Args:
    level: int, a Python standard logging level.

  Raises:
    TypeError: Raised when level is not an integer.

  Returns:
    The corresponding integer level for use in absl logging.
  """
  if not isinstance(level, int):
    raise TypeError('Expect an int level, found {}'.format(type(level)))
  clamped = max(level, 0)
  if clamped < STANDARD_DEBUG:
    # Maps to vlog levels.
    return STANDARD_DEBUG - clamped + 1
  for upper_bound, absl_level in ((STANDARD_INFO, ABSL_DEBUG),
                                  (STANDARD_WARNING, ABSL_INFO),
                                  (STANDARD_ERROR, ABSL_WARNING),
                                  (STANDARD_CRITICAL, ABSL_ERROR)):
    if clamped < upper_bound:
      return absl_level
  return ABSL_FATAL
def standard_to_cpp(level):
  """Translates a Python standard logging level into an Abseil C++ level.

  The conversion goes through the absl scheme: standard -> absl -> cpp.

  Args:
    level: int, a Python standard logging level.

  Raises:
    TypeError: Raised when level is not an integer.

  Returns:
    The corresponding integer level for use in cpp logging.
  """
  absl_level = standard_to_absl(level)
  return absl_to_cpp(absl_level)
abseil-py-2.1.0/absl/logging/tests/ 0000775 0000000 0000000 00000000000 14551576331 0017122 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/absl/logging/tests/__init__.py 0000664 0000000 0000000 00000001110 14551576331 0021224 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
abseil-py-2.1.0/absl/logging/tests/converter_test.py 0000664 0000000 0000000 00000013405 14551576331 0022545 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for converter.py."""
import logging
from absl import logging as absl_logging
from absl.logging import converter
from absl.testing import absltest
class ConverterTest(absltest.TestCase):
  """Tests the converter module."""

  def test_absl_to_cpp(self):
    expectations = (
        (0, absl_logging.DEBUG),
        (0, absl_logging.INFO),
        (1, absl_logging.WARN),
        (2, absl_logging.ERROR),
        (3, absl_logging.FATAL),
    )
    for cpp_level, absl_level in expectations:
      self.assertEqual(cpp_level, converter.absl_to_cpp(absl_level))
    with self.assertRaises(TypeError):
      converter.absl_to_cpp('')

  def test_absl_to_standard(self):
    expectations = (
        (logging.DEBUG, absl_logging.DEBUG),
        (logging.INFO, absl_logging.INFO),
        (logging.WARNING, absl_logging.WARN),
        (logging.WARN, absl_logging.WARN),
        (logging.ERROR, absl_logging.ERROR),
        (logging.FATAL, absl_logging.FATAL),
        (logging.CRITICAL, absl_logging.FATAL),
        # vlog levels.
        (9, 2),
        (8, 3),
    )
    for standard_level, absl_level in expectations:
      self.assertEqual(standard_level, converter.absl_to_standard(absl_level))
    with self.assertRaises(TypeError):
      converter.absl_to_standard('')

  def test_standard_to_absl(self):
    expectations = (
        (absl_logging.DEBUG, logging.DEBUG),
        (absl_logging.INFO, logging.INFO),
        (absl_logging.WARN, logging.WARN),
        (absl_logging.WARN, logging.WARNING),
        (absl_logging.ERROR, logging.ERROR),
        (absl_logging.FATAL, logging.FATAL),
        (absl_logging.FATAL, logging.CRITICAL),
        # vlog levels.
        (2, logging.DEBUG - 1),
        (3, logging.DEBUG - 2),
    )
    for absl_level, standard_level in expectations:
      self.assertEqual(absl_level, converter.standard_to_absl(standard_level))
    with self.assertRaises(TypeError):
      converter.standard_to_absl('')

  def test_standard_to_cpp(self):
    expectations = (
        (0, logging.DEBUG),
        (0, logging.INFO),
        (1, logging.WARN),
        (1, logging.WARNING),
        (2, logging.ERROR),
        (3, logging.FATAL),
        (3, logging.CRITICAL),
    )
    for cpp_level, standard_level in expectations:
      self.assertEqual(cpp_level, converter.standard_to_cpp(standard_level))
    with self.assertRaises(TypeError):
      converter.standard_to_cpp('')

  def test_get_initial_for_level(self):
    expectations = (
        ('F', logging.CRITICAL),
        ('E', logging.ERROR),
        ('W', logging.WARNING),
        ('I', logging.INFO),
        ('I', logging.DEBUG),
        ('I', logging.NOTSET),
        ('F', 51),
        ('E', 49),
        ('E', 41),
        ('W', 39),
        ('W', 31),
        ('I', 29),
        ('I', 21),
        ('I', 19),
        ('I', 11),
        ('I', 9),
        ('I', 1),
        ('I', -1),
    )
    for initial, level in expectations:
      self.assertEqual(initial, converter.get_initial_for_level(level))

  def test_string_to_standard(self):
    expectations = (
        (logging.DEBUG, 'debug'),
        (logging.INFO, 'info'),
        (logging.WARNING, 'warn'),
        (logging.WARNING, 'warning'),
        (logging.ERROR, 'error'),
        (logging.CRITICAL, 'fatal'),
    )
    # Lower-case spellings first, then the same names upper-cased.
    for standard_level, name in expectations:
      self.assertEqual(standard_level, converter.string_to_standard(name))
    for standard_level, name in expectations:
      self.assertEqual(standard_level,
                       converter.string_to_standard(name.upper()))


if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/logging/tests/log_before_import_test.py 0000664 0000000 0000000 00000011570 14551576331 0024234 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test of logging behavior before app.run(), aka flag and logging init()."""
import contextlib
import io
import os
import re
import sys
import tempfile
from unittest import mock
from absl import logging
from absl.testing import absltest
# These calls deliberately run at import time, before absltest.main() (bottom
# of this file) parses flags, so flag-backed logging state is accessed in the
# pre-initialization state.
logging.get_verbosity()  # Access --verbosity before flag parsing.
# Access --logtostderr before flag parsing.
logging.get_absl_handler().use_absl_log_file()
class Error(Exception):
  """Exception raised and then logged via logging.exception() at import time."""
@contextlib.contextmanager
def captured_stderr_filename():
  """Captures stderr and writes it to a temporary file.

  This uses os.dup/os.dup2 to redirect the stderr fd for capturing standard
  error of logging at import-time. We cannot mock sys.stderr because on the
  first log call, a default log handler writing to the mock sys.stderr is
  registered, and it will never be removed and subsequent logs go to the mock
  in addition to the real stderr.

  The temporary file is intentionally left on disk: its name is yielded so the
  captured output can be read after the context exits.

  Yields:
    The filename of captured stderr.
  """
  stderr_fd = sys.stderr.fileno()
  capture_fd, capture_name = tempfile.mkstemp()
  try:
    # Flush pending Python-level output so it reaches the real stderr rather
    # than leaking into the capture file.
    sys.stderr.flush()
    original_stderr_fd = os.dup(stderr_fd)
    try:
      os.dup2(capture_fd, stderr_fd)
      try:
        yield capture_name
      finally:
        # Push anything still buffered into the capture file, then restore.
        sys.stderr.flush()
        os.dup2(original_stderr_fd, stderr_fd)
    finally:
      # The duplicate of the original stderr fd is no longer needed once
      # stderr has been restored; close it so it does not leak.
      os.close(original_stderr_fd)
  finally:
    os.close(capture_fd)
# Pre-initialization (aka "import" / __main__ time) test.
# Log at every level while stderr is captured at the fd level; the result is
# read back by LoggingInitWarningTest.test_captured_pre_init_warnings.
with captured_stderr_filename() as before_set_verbosity_filename:
  # Warnings and above go to stderr.
  logging.debug('Debug message at parse time.')
  logging.info('Info message at parse time.')
  logging.error('Error message at parse time.')
  logging.warning('Warning message at parse time.')
  try:
    raise Error('Exception reason.')
  except Error:
    logging.exception('Exception message at parse time.')

# Change verbosity before flag parsing; the effect is checked by
# LoggingInitWarningTest.test_set_verbosity_pre_init.
logging.set_verbosity(logging.ERROR)
with captured_stderr_filename() as after_set_verbosity_filename:
  # Verbosity is set to ERROR, errors and above go to stderr.
  logging.debug('Debug message at parse time.')
  logging.info('Info message at parse time.')
  logging.warning('Warning message at parse time.')
  logging.error('Error message at parse time.')
class LoggingInitWarningTest(absltest.TestCase):
  """Checks the stderr output of the import-time logging calls in this file."""

  def test_captured_pre_init_warnings(self):
    with open(before_set_verbosity_filename) as stderr_capture_file:
      captured_stderr = stderr_capture_file.read()
    self.assertNotIn('Debug message at parse time.', captured_stderr)
    self.assertNotIn('Info message at parse time.', captured_stderr)

    traceback_re = re.compile(
        r'\nTraceback \(most recent call last\):.*?Error: Exception reason.',
        re.MULTILINE | re.DOTALL)
    if not traceback_re.search(captured_stderr):
      self.fail(
          'Cannot find traceback message from logging.exception '
          'in stderr:\n{}'.format(captured_stderr))
    # Remove the traceback so the rest of the stderr is deterministic.
    captured_stderr = traceback_re.sub('', captured_stderr)
    captured_stderr_lines = captured_stderr.splitlines()
    self.assertLen(captured_stderr_lines, 3)
    self.assertIn('Error message at parse time.', captured_stderr_lines[0])
    self.assertIn('Warning message at parse time.', captured_stderr_lines[1])
    self.assertIn('Exception message at parse time.', captured_stderr_lines[2])

  def test_set_verbosity_pre_init(self):
    with open(after_set_verbosity_filename) as stderr_capture_file:
      captured_stderr = stderr_capture_file.read()
    captured_stderr_lines = captured_stderr.splitlines()

    self.assertNotIn('Debug message at parse time.', captured_stderr)
    self.assertNotIn('Info message at parse time.', captured_stderr)
    self.assertNotIn('Warning message at parse time.', captured_stderr)
    self.assertLen(captured_stderr_lines, 1)
    self.assertIn('Error message at parse time.', captured_stderr_lines[0])

  def test_no_more_warnings(self):
    # sys.stderr is a text stream on Python 3, so io.StringIO stands in for
    # it. (The former Python 2 ``bytes is str`` branch was dead code.)
    with mock.patch('sys.stderr', new=io.StringIO()) as mock_stderr:
      self.assertMultiLineEqual('', mock_stderr.getvalue())
      logging.warning('Hello. hello. hello. Is there anybody out there?')
      self.assertNotIn('Logging before flag parsing goes to stderr',
                       mock_stderr.getvalue())
    logging.info('A major purpose of this executable is merely not to crash.')
if __name__ == '__main__':
  # Everything logged at module level above ran before this point, i.e.
  # before flags were parsed.
  absltest.main()  # This calls the app.run() init equivalent.
abseil-py-2.1.0/absl/logging/tests/logging_functional_test.py 0000664 0000000 0000000 00000064445 14551576331 0024420 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functional tests for absl.logging."""
import fnmatch
import os
import re
import shutil
import subprocess
import sys
import tempfile
from absl import logging
from absl.testing import _bazelize_command
from absl.testing import absltest
from absl.testing import parameterized
# Expected helper output, one constant per verbosity band. Both expected and
# actual logs are passed through _munge_log before comparison (see _exec_test),
# which normalizes timestamps, thread ids and prefix line numbers.
_PY_VLOG3_LOG_MESSAGE = """\
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:62] This line is VLOG level 3
"""
_PY_VLOG2_LOG_MESSAGE = """\
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:64] This line is VLOG level 2
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:64] This line is log level 2
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:64] VLOG level 1, but only if VLOG level 2 is active
"""
# VLOG1 is the same as DEBUG logs.
_PY_DEBUG_LOG_MESSAGE = """\
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level 1
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level 1
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:66] This line is DEBUG
"""
_PY_INFO_LOG_MESSAGE = """\
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level 0
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level 0
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:70] Interesting Stuff\0
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:71] Interesting Stuff with Arguments: 42
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:73] Interesting Stuff with Dictionary
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:123] This should appear 5 times.
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:76] Info first 1 of 2
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:77] Info 1 (every 3)
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:76] Info first 2 of 2
I1231 23:59:59.000000 12345 logging_functional_test_helper.py:77] Info 4 (every 3)
"""
# Same INFO lines as above, as printed with --noshowprefixforinfo.
_PY_INFO_LOG_MESSAGE_NOPREFIX = """\
This line is VLOG level 0
This line is log level 0
Interesting Stuff\0
Interesting Stuff with Arguments: 42
Interesting Stuff with Dictionary
This should appear 5 times.
This should appear 5 times.
This should appear 5 times.
This should appear 5 times.
This should appear 5 times.
Info first 1 of 2
Info 1 (every 3)
Info first 2 of 2
Info 4 (every 3)
"""
_PY_WARNING_LOG_MESSAGE = """\
W1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level -1
W1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level -1
W1231 23:59:59.000000 12345 logging_functional_test_helper.py:79] Worrying Stuff
W0000 23:59:59.000000 12345 logging_functional_test_helper.py:81] Warn first 1 of 2
W0000 23:59:59.000000 12345 logging_functional_test_helper.py:82] Warn 1 (every 3)
W0000 23:59:59.000000 12345 logging_functional_test_helper.py:81] Warn first 2 of 2
W0000 23:59:59.000000 12345 logging_functional_test_helper.py:82] Warn 4 (every 3)
"""
# Extra text expected right before the "OSError: Fake Error" line that follows
# the "No traceback" entry in _PY_ERROR_LOG_MESSAGE. Only Python 3.4 was
# expected to print a traceback there; on the Python versions this project
# tests (3.7+), nothing extra is printed, so the former 3.4-only branch was
# dead code.
_FAKE_ERROR_EXTRA_MESSAGE = ''
# Expected ERROR-level helper output. The {fake_error_extra} placeholder is
# filled with _FAKE_ERROR_EXTRA_MESSAGE.
_PY_ERROR_LOG_MESSAGE = """\
E1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is VLOG level -2
E1231 23:59:59.000000 12345 logging_functional_test_helper.py:65] This line is log level -2
E1231 23:59:59.000000 12345 logging_functional_test_helper.py:87] An Exception %s
Traceback (most recent call last):
File "logging_functional_test_helper.py", line 456, in _test_do_logging
raise OSError('Fake Error')
OSError: Fake Error
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Once more, just because
Traceback (most recent call last):
File "./logging_functional_test_helper.py", line 78, in _test_do_logging
raise OSError('Fake Error')
OSError: Fake Error
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Exception 2 %s
Traceback (most recent call last):
File "logging_functional_test_helper.py", line 456, in _test_do_logging
raise OSError('Fake Error')
OSError: Fake Error
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Non-exception
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] Exception 3
Traceback (most recent call last):
File "logging_functional_test_helper.py", line 456, in _test_do_logging
raise OSError('Fake Error')
OSError: Fake Error
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] No traceback
{fake_error_extra}OSError: Fake Error
E1231 23:59:59.000000 12345 logging_functional_test_helper.py:90] Alarming Stuff
E0000 23:59:59.000000 12345 logging_functional_test_helper.py:92] Error first 1 of 2
E0000 23:59:59.000000 12345 logging_functional_test_helper.py:93] Error 1 (every 3)
E0000 23:59:59.000000 12345 logging_functional_test_helper.py:92] Error first 2 of 2
E0000 23:59:59.000000 12345 logging_functional_test_helper.py:93] Error 4 (every 3)
""".format(fake_error_extra=_FAKE_ERROR_EXTRA_MESSAGE)
# Expected output of the 'critical_from_non_absl_logger' helper test: a
# CRITICAL record from a non-absl logger is emitted as an error with a
# "CRITICAL - " prefix.
_CRITICAL_DOWNGRADE_TO_ERROR_MESSAGE = """\
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] CRITICAL - A critical message
"""
# (test-name suffix, absl verbosity) pairs used to parameterize the
# -v/--verbosity tests; 1-3 are vlog levels.
_VERBOSITY_FLAG_TEST_PARAMETERS = (
    ('fatal', logging.FATAL),
    ('error', logging.ERROR),
    ('warning', logging.WARN),
    ('info', logging.INFO),
    ('debug', logging.DEBUG),
    ('vlog1', 1),
    ('vlog2', 2),
    ('vlog3', 3))
def _get_fatal_log_expectation(testcase, message, include_stacktrace):
  """Builds the checker for the stderr output of a fatal-logging helper run.

  Args:
    testcase: The TestCase instance used to make assertions.
    message: The extra fatal logging message.
    include_stacktrace: Whether or not to include stacktrace.

  Returns:
    A callable taking the captured logs as one string. It is meant to be the
    third item of an entry in FunctionalTest._exec_test's expected_logs list;
    see _exec_test's docstring for more information.
  """
  expected_logs = (
      'F1231 23:59:59.000000 12345 logging_functional_test_helper.py:175] '
      '%s message\n' % message)
  if include_stacktrace:
    expected_logs += 'Stack trace:\n'
  faulthandler_start = 'Fatal Python error: Aborted'

  def assert_logs(logs):
    if os.name == 'nt':
      # Windows appends extra lines at the end, something like:
      #   This application has requested the Runtime to terminate it in an
      #   unusual way. Please contact the application's support team for
      #   more information.
      logs = '\n'.join(logs.split('\n')[:-3])
    testcase.assertIn(faulthandler_start, logs)
    # Only the text before the faulthandler dump is compared.
    log_message, _, _ = logs.partition(faulthandler_start)
    testcase.assertEqual(_munge_log(expected_logs), _munge_log(log_message))

  return assert_logs
def _munge_log(buf):
  """Remove timestamps, thread ids, filenames and line numbers from logs.

  Drops everything before the test helper's start marker and the log-file
  greeting lines. It then rewrites the variable parts of log-line prefixes,
  traceback locations and stack traces to fixed placeholders, so expected and
  actual logs can be compared as plain strings.

  Args:
    buf: str, the log text to normalize.

  Returns:
    The normalized log text.

  Raises:
    AssertionError: If the first log line of ``buf`` carries a negative
      thread id.
  """
  # Remove all messages produced before the output to be tested.
  buf = re.sub(r'(?:.|\n)*START OF TEST HELPER LOGS: IGNORE PREVIOUS.\n',
               r'',
               buf)

  # Greeting
  buf = re.sub(r'(?m)^Log file created at: .*\n', '', buf)
  buf = re.sub(r'(?m)^Running on machine: .*\n', '', buf)
  buf = re.sub(r'(?m)^Binary: .*\n', '', buf)
  buf = re.sub(r'(?m)^Log line format: .*\n', '', buf)

  # Verify thread id is logged as a non-negative quantity. re.match anchors at
  # the start of buf, so only the first log line is checked.
  matched = re.match(r'(?m)^(\w)(\d\d\d\d \d\d:\d\d:\d\d\.\d\d\d\d\d\d) '
                     r'([ ]*-?[0-9a-fA-F]+ )?([a-zA-Z<][\w._<>-]+):(\d+)',
                     buf)
  if matched:
    threadid = matched.group(3)
    # The thread id group is optional and may contain hex digits, so parse it
    # as hex (the sign is all that matters) and skip it when absent.
    if threadid is not None and int(threadid, 16) < 0:
      raise AssertionError("Negative threadid '%s' in '%s'" % (threadid, buf))

  # Timestamp. NOTE(review): relies on logging.ABSL_LOGGING_PREFIX_REGEX
  # (defined in absl.logging) exposing the named groups 'severity' and
  # 'filename'; a bare '\g' without '<name>' is not a valid re.sub template.
  buf = re.sub(r'(?m)' + logging.ABSL_LOGGING_PREFIX_REGEX,
               r'\g<severity>0000 00:00:00.000000 12345 \g<filename>:123',
               buf)

  # Traceback
  buf = re.sub(r'(?m)^ File "(.*/)?([^"/]+)", line (\d+),',
               r' File "\g<2>", line 456,',
               buf)

  # Stack trace is too complicated for re, just assume it extends to end of
  # output
  buf = re.sub(r'(?sm)^Stack trace:\n.*',
               r'Stack trace:\n',
               buf)
  buf = re.sub(r'(?sm)^\*\*\* Signal 6 received by PID.*\n.*',
               r'Stack trace:\n',
               buf)
  buf = re.sub((r'(?sm)^\*\*\* ([A-Z]+) received by PID (\d+) '
                r'\(TID 0x([0-9a-f]+)\)'
                r'( from PID \d+)?; stack trace: \*\*\*\n.*'),
               r'Stack trace:\n',
               buf)
  buf = re.sub(r'(?sm)^\*\*\* Check failure stack trace: \*\*\*\n.*',
               r'Stack trace:\n',
               buf)

  if os.name == 'nt':
    # On windows, we calls Python interpreter explicitly, so the file names
    # include the full path. Strip them.
    buf = re.sub(r'( File ").*(logging_functional_test_helper\.py", line )',
                 r'\1\2',
                 buf)

  return buf
def _verify_status(expected, actual, output):
  """Raises AssertionError if the helper's exit status is not ``expected``.

  Args:
    expected: int, the exit status the helper should have returned.
    actual: int, the exit status the helper actually returned.
    output: str, the helper's captured output, included in the error message.

  Raises:
    AssertionError: If ``actual != expected``.
  """
  if expected != actual:
    raise AssertionError(
        'Test exited with unexpected status code %d (expected %d). '
        'Output was:\n%s' % (actual, expected, output))


def _verify_ok(status, output):
  """Check that helper exited with no errors."""
  _verify_status(0, status, output)


def _verify_fatal(status, output):
  """Check that helper died as expected."""
  # os.abort generates a SIGABRT signal (-6). On Windows, the process
  # immediately returns an exit code of 3.
  # See https://docs.python.org/3/library/os.html#os.abort.
  expected_exit_code = 3 if os.name == 'nt' else -6
  _verify_status(expected_exit_code, status, output)


def _verify_assert(status, output):
  """Check that helper failed with assertion."""
  _verify_status(1, status, output)
class FunctionalTest(parameterized.TestCase):
"""Functional tests using the logging_functional_test_helper script."""
def _get_helper(self):
  """Returns the runnable path of the logging_functional_test_helper binary."""
  return _bazelize_command.get_executable_path(
      'absl/logging/tests/logging_functional_test_helper')
def _get_logs(self,
              verbosity,
              include_info_prefix=True):
  """Returns the helper output expected at the given absl verbosity.

  Args:
    verbosity: int, an absl.logging verbosity (logging.FATAL, logging.ERROR,
      logging.WARN, logging.INFO, logging.DEBUG) or a vlog level such as 2 or
      3. Higher values are more verbose.
    include_info_prefix: If False, use the INFO expectation without the
      per-line log prefix (the --noshowprefixforinfo form).

  Returns:
    str, the concatenated expected messages of every enabled level, most
    verbose first. Empty for verbosities below logging.ERROR.
  """
  logs = []
  if verbosity >= 3:
    logs.append(_PY_VLOG3_LOG_MESSAGE)
  if verbosity >= 2:
    logs.append(_PY_VLOG2_LOG_MESSAGE)
  if verbosity >= logging.DEBUG:
    logs.append(_PY_DEBUG_LOG_MESSAGE)
  if verbosity >= logging.INFO:
    if include_info_prefix:
      logs.append(_PY_INFO_LOG_MESSAGE)
    else:
      logs.append(_PY_INFO_LOG_MESSAGE_NOPREFIX)
  if verbosity >= logging.WARN:
    logs.append(_PY_WARNING_LOG_MESSAGE)
  if verbosity >= logging.ERROR:
    logs.append(_PY_ERROR_LOG_MESSAGE)
  # The former ``.replace("", "")`` post-processing step was a no-op and has
  # been dropped.
  return ''.join(logs)
def setUp(self):
  """Gives each test its own fresh temporary --log_dir under TEST_TMPDIR."""
  super().setUp()
  self._log_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)

def tearDown(self):
  """Removes the per-test log directory created in setUp."""
  shutil.rmtree(self._log_dir)
  super().tearDown()
def _exec_test(self,
               verify_exit_fn,
               expected_logs,
               test_name='do_logging',
               pass_logtostderr=False,
               use_absl_log_file=False,
               show_info_prefix=1,
               call_dict_config=False,
               extra_args=()):
  """Execute the helper script and verify its output.

  Args:
    verify_exit_fn: A function taking (status, output).
    expected_logs: List of tuples, or None if output shouldn't be checked.
      Tuple is (log prefix, log type, expected contents):
      - log prefix: A program name, or 'stderr'.
      - log type: 'INFO', 'ERROR', etc.
      - expected: Can be the following:
        - A string
        - A callable, called with the logs as a single argument
        - None, means don't check contents of log file
      Each actual log file can satisfy at most one entry.
    test_name: Name to pass to helper.
    pass_logtostderr: Pass --logtostderr to the helper script if True.
    use_absl_log_file: If True, call
      logging.get_absl_handler().use_absl_log_file() before test_fn in
      logging_functional_test_helper.
    show_info_prefix: --showprefixforinfo value passed to the helper script.
    call_dict_config: True if helper script should call
      logging.config.dictConfig.
    extra_args: Iterable of str (optional, defaults to ()) - extra arguments
      to pass to the helper script.

  Raises:
    AssertionError: Assertion error when test fails.
  """
  args = ['--log_dir=%s' % self._log_dir]
  if pass_logtostderr:
    args.append('--logtostderr')
  if not show_info_prefix:
    args.append('--noshowprefixforinfo')
  args += extra_args

  # Execute helper in subprocess. The helper reads its test configuration
  # from these environment variables.
  env = os.environ.copy()
  env.update({
      'TEST_NAME': test_name,
      'USE_ABSL_LOG_FILE': '%d' % (use_absl_log_file,),
      'CALL_DICT_CONFIG': '%d' % (call_dict_config,),
  })
  cmd = [self._get_helper()] + args

  print('env: %s' % env, file=sys.stderr)
  print('cmd: %s' % cmd, file=sys.stderr)
  # stderr is merged into stdout, so ``output`` is what the 'stderr' entry in
  # expected_logs is compared against.
  process = subprocess.Popen(
      cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=env,
      universal_newlines=True)
  output, _ = process.communicate()
  status = process.returncode

  # Verify exit status.
  verify_exit_fn(status, output)

  # Check outputs?
  if expected_logs is None:
    return

  # Get list of log files.
  logs = os.listdir(self._log_dir)
  logs = fnmatch.filter(logs, '*.log.*')
  logs.append('stderr')

  # Look for a log matching each expected pattern. Only files not yet claimed
  # by an earlier entry are candidates, so two entries matching the same file
  # are reported as "expected but not found" instead of crashing with a
  # ValueError from list.remove().
  matched = []
  unmatched = []
  unexpected = logs[:]
  for log_prefix, log_type, expected in expected_logs:
    # What pattern?
    if log_prefix == 'stderr':
      assert log_type is None
      pattern = 'stderr'
    else:
      pattern = r'%s[.].*[.]log[.]%s[.][\d.-]*$' % (log_prefix, log_type)

    # Is it there
    for basename in unexpected:
      if re.match(pattern, basename):
        matched.append([expected, basename])
        # Safe despite iterating over ``unexpected``: we break right away.
        unexpected.remove(basename)
        break
    else:
      unmatched.append(pattern)

  # Mismatch?
  errors = ''
  if unmatched:
    errors += 'The following log files were expected but not found: %s' % (
        '\n '.join(unmatched))
  if unexpected:
    if errors:
      errors += '\n'
    errors += 'The following log files were not expected: %s' % (
        '\n '.join(unexpected))
  if errors:
    raise AssertionError(errors)

  # Compare contents of matches.
  for (expected, basename) in matched:
    if expected is None:
      continue

    if basename == 'stderr':
      actual = output
    else:
      path = os.path.join(self._log_dir, basename)
      with open(path, encoding='utf-8') as f:
        actual = f.read()

    if callable(expected):
      try:
        expected(actual)
      except AssertionError:
        print('expected_logs assertion failed, actual {} log:\n{}'.format(
            basename, actual), file=sys.stderr)
        raise
    elif isinstance(expected, str):
      self.assertMultiLineEqual(_munge_log(expected), _munge_log(actual),
                                '%s differs' % basename)
    else:
      self.fail(
          'Invalid value found for expected logs: {}, type: {}'.format(
              expected, type(expected)))
@parameterized.named_parameters(
    ('', False),
    ('logtostderr', True))
def test_py_logging(self, logtostderr):
  # By default, Python logging sends everything to stderr.
  stderr_logs = self._get_logs(logging.INFO)
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, stderr_logs]],
      pass_logtostderr=logtostderr)

def test_py_logging_use_absl_log_file(self):
  # With use_absl_log_file(), Python logging writes to an absl log file and
  # nothing reaches stderr.
  file_logs = self._get_logs(logging.INFO)
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, ''],
                     ['absl_log_file', 'INFO', file_logs]],
      use_absl_log_file=True)

def test_py_logging_use_absl_log_file_logtostderr(self):
  # --logtostderr wins even though use_absl_log_file() is called.
  stderr_logs = self._get_logs(logging.INFO)
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, stderr_logs]],
      pass_logtostderr=True,
      use_absl_log_file=True)

@parameterized.named_parameters(
    ('', False),
    ('logtostderr', True))
def test_py_logging_noshowprefixforinfo(self, logtostderr):
  stderr_logs = self._get_logs(logging.INFO, include_info_prefix=False)
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, stderr_logs]],
      pass_logtostderr=logtostderr,
      show_info_prefix=0)

def test_py_logging_noshowprefixforinfo_use_absl_log_file(self):
  # INFO lines written to the log file keep their prefix.
  file_logs = self._get_logs(logging.INFO)
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, ''],
                     ['absl_log_file', 'INFO', file_logs]],
      show_info_prefix=0,
      use_absl_log_file=True)

def test_py_logging_noshowprefixforinfo_use_absl_log_file_logtostderr(self):
  stderr_logs = self._get_logs(logging.INFO, include_info_prefix=False)
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, stderr_logs]],
      pass_logtostderr=True,
      show_info_prefix=0,
      use_absl_log_file=True)

def test_py_logging_noshowprefixforinfo_verbosity(self):
  stderr_logs = self._get_logs(logging.DEBUG)
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, stderr_logs]],
      pass_logtostderr=True,
      show_info_prefix=0,
      use_absl_log_file=True,
      extra_args=['-v=1'])

def test_py_logging_fatal_main_thread_only(self):
  check_logs = _get_fatal_log_expectation(
      self, 'fatal_main_thread_only', False)
  self._exec_test(
      _verify_fatal,
      expected_logs=[['stderr', None, check_logs]],
      test_name='fatal_main_thread_only')

def test_py_logging_fatal_with_other_threads(self):
  check_logs = _get_fatal_log_expectation(
      self, 'fatal_with_other_threads', False)
  self._exec_test(
      _verify_fatal,
      expected_logs=[['stderr', None, check_logs]],
      test_name='fatal_with_other_threads')

def test_py_logging_fatal_non_main_thread(self):
  check_logs = _get_fatal_log_expectation(
      self, 'fatal_non_main_thread', False)
  self._exec_test(
      _verify_fatal,
      expected_logs=[['stderr', None, check_logs]],
      test_name='fatal_non_main_thread')

def test_py_logging_critical_non_absl(self):
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, _CRITICAL_DOWNGRADE_TO_ERROR_MESSAGE]],
      test_name='critical_from_non_absl_logger')

def test_py_logging_skip_log_prefix(self):
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, '']],
      test_name='register_frame_to_skip')

def test_py_logging_flush(self):
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, '']],
      test_name='flush')

@parameterized.named_parameters(*_VERBOSITY_FLAG_TEST_PARAMETERS)
def test_py_logging_verbosity_stderr(self, verbosity):
  """Tests -v/--verbosity flag with python logging to stderr."""
  stderr_logs = self._get_logs(verbosity)
  self._exec_test(
      _verify_ok,
      expected_logs=[['stderr', None, stderr_logs]],
      extra_args=['-v=%d' % verbosity])
@parameterized.named_parameters(*_VERBOSITY_FLAG_TEST_PARAMETERS)
def test_py_logging_verbosity_file(self, verbosity):
  """Tests -v/--verbosity flag with Python logging to an absl log file.

  Stderr must stay empty; the INFO log file must contain the logs expected
  for `verbosity`.
  """
  v_flag = '-v=%d' % verbosity
  self._exec_test(
      _verify_ok,
      [['stderr', None, ''],
       # Python logging only creates a file named INFO, whereas C++ logging
       # also creates WARNING and ERROR files.
       ['absl_log_file', 'INFO', self._get_logs(verbosity)]],
      use_absl_log_file=True,
      extra_args=[v_flag])
def test_stderrthreshold_py_logging(self):
  """Tests --stderrthreshold.

  The helper changes FLAGS.stderrthreshold after flag parsing; for each
  threshold stderr must show only the messages at or above it.
  """
  # Set verbosity to debug to test stderrthreshold == debug.
  extra_args = ['-v=1']
  stderr_logs = '''\
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, debug log
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, info log
W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, warning log
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=debug, error log
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=info, info log
W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=info, warning log
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=info, error log
W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=warning, warning log
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=warning, error log
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] FLAGS.stderrthreshold=error, error log
'''
  self._exec_test(
      _verify_ok,
      [['stderr', None, stderr_logs], ['absl_log_file', 'INFO', None]],
      test_name='stderrthreshold',
      extra_args=extra_args,
      use_absl_log_file=True)
def test_std_logging_py_logging(self):
  """Tests logs from std logging.

  With -v=1 and --logtostderr, all four std logging lines must reach stderr.
  """
  extra_args = ['-v=1', '--logtostderr']
  stderr_logs = '''\
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std debug log
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std info log
W0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std warning log
E0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] std error log
'''
  self._exec_test(
      _verify_ok,
      [['stderr', None, stderr_logs]],
      test_name='std_logging',
      extra_args=extra_args)
def test_bad_exc_info_py_logging(self):
  """A malformed exc_info: expects an IndexError traceback on stderr."""

  def assert_stderr(stderr):
    # The exact message differs among Python versions, so only check that
    # these fragments are present.
    for fragment in ('Traceback (most recent call last):', 'IndexError'):
      self.assertIn(fragment, stderr)

  self._exec_test(
      _verify_ok,
      [['stderr', None, assert_stderr], ['absl_log_file', 'INFO', '']],
      test_name='bad_exc_info',
      use_absl_log_file=True)
def test_verbosity_logger_levels_flag_ordering(self):
  """Make sure last-specified flag wins.

  Runs the std_logging helper with -v=1 and --logger_levels=:ERROR in both
  orders and checks which std logging lines reach stderr.
  """

  def assert_error_level_logged(stderr):
    # --logger_levels=:ERROR given last: only the error line may appear.
    lines = stderr.splitlines()
    # Without this, an empty stderr would make the loop pass vacuously.
    self.assertTrue(lines, 'Expected log lines on stderr, got none.')
    for line in lines:
      self.assertIn('std error log', line)

  self._exec_test(
      _verify_ok,
      test_name='std_logging',
      expected_logs=[('stderr', None, assert_error_level_logged)],
      extra_args=['-v=1', '--logger_levels=:ERROR'])

  def assert_debug_level_logged(stderr):
    # -v=1 given last: every line must be one of the std logging lines.
    lines = stderr.splitlines()
    self.assertTrue(lines, 'Expected log lines on stderr, got none.')
    for line in lines:
      self.assertRegex(line, 'std (debug|info|warning|error) log')

  self._exec_test(
      _verify_ok,
      test_name='std_logging',
      expected_logs=[('stderr', None, assert_debug_level_logged)],
      extra_args=['--logger_levels=:ERROR', '-v=1'])
def test_none_exc_info_py_logging(self):
  """exc_info=True with no active exception.

  Expects empty stderr and an INFO file holding the message followed by
  'NoneType: None'.
  """
  expected_info = '''\
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] None exc_info
''' + 'NoneType: None\n'
  self._exec_test(
      _verify_ok,
      [['stderr', None, ''], ['absl_log_file', 'INFO', expected_info]],
      test_name='none_exc_info',
      use_absl_log_file=True)
def test_unicode_py_logging(self):
  """Tests logging of unicode, bytes and exception arguments.

  The helper writes `-- begin NAME --` / `-- end NAME --` markers to stderr
  around each log call. Nothing may appear between the markers, and the
  formatted messages must appear in the INFO log file.
  """

  def get_stderr_message(stderr, name):
    """Returns the stderr text between the begin/end markers for `name`."""
    # Escape the name so it is always matched literally.
    escaped = re.escape(name)
    match = re.search(
        '-- begin {} --\n(.*)-- end {} --'.format(escaped, escaped),
        stderr, re.MULTILINE | re.DOTALL)
    self.assertTrue(
        match, 'Cannot find stderr message for test {}'.format(name))
    return match.group(1)

  def assert_stderr(stderr):
    """Verifies that no case wrote anything to stderr (no unicode errors).

    Args:
      stderr: the message from stderr.
    """
    # One entry per log() call in the helper's _test_unicode().
    for name in (
        'unicode', 'unicode % unicode', 'bytes % bytes', 'unicode % bytes',
        'bytes % unicode', 'unicode % iso8859-15', 'str % exception'):
      self.assertEqual(
          '', get_stderr_message(stderr, name),
          'Unexpected stderr output for test {}'.format(name))

  expected_logs = [['stderr', None, assert_stderr]]

  info_log = u'''\
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: Ch\u00e2tonnaye
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: Ch\u00e2tonnaye
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] b'G\\xc3\\xaete: b'Ch\\xc3\\xa2tonnaye''
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: b'Ch\\xc3\\xa2tonnaye'
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] b'G\\xc3\\xaete: Ch\u00e2tonnaye'
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] G\u00eete: b'Ch\\xe2tonnaye'
I0000 00:00:00.000000 12345 logging_functional_test_helper.py:123] exception: Ch\u00e2tonnaye
'''
  expected_logs.append(['absl_log_file', 'INFO', info_log])

  self._exec_test(
      _verify_ok,
      expected_logs,
      test_name='unicode',
      use_absl_log_file=True)
if __name__ == '__main__':
  # Run this module's test cases through absl's test runner.
  absltest.main()
abseil-py-2.1.0/absl/logging/tests/logging_functional_test_helper.py 0000664 0000000 0000000 00000022202 14551576331 0025740 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper script for logging_functional_test."""
import logging as std_logging
import logging.config as std_logging_config
import os
import sys
import threading
import time
import timeit
from unittest import mock
from absl import app
from absl import flags
from absl import logging
# Global flag registry. main() reads FLAGS.log_dir and _test_stderrthreshold
# assigns FLAGS.stderrthreshold.
FLAGS = flags.FLAGS
class VerboseDel(object):
  """Dummy class to test __del__ running.

  When an instance is finalized, its message is written to sys.stderr and
  the stream is flushed.
  """

  def __init__(self, msg):
    # Text to emit from __del__.
    self._msg = msg

  def __del__(self):
    stream = sys.stderr
    stream.write(self._msg)
    stream.flush()
def _test_do_logging():
  """Do some log operations.

  In order, it exercises: vlog/log at levels 3 down to -2 (with a
  vlog_is_on(2) gate); debug/info with %-style and dict arguments;
  log_every_n_seconds against a mocked timeit.default_timer;
  log_first_n/log_every_n at INFO, WARNING and ERROR; and exception/error
  logging with several exc_info values. It ends with logging.flush().
  """
  logging.vlog(3, 'This line is VLOG level 3')
  logging.vlog(2, 'This line is VLOG level 2')
  logging.log(2, 'This line is log level 2')
  if logging.vlog_is_on(2):
    logging.log(1, 'VLOG level 1, but only if VLOG level 2 is active')

  logging.vlog(1, 'This line is VLOG level 1')
  logging.log(1, 'This line is log level 1')
  logging.debug('This line is DEBUG')

  logging.vlog(0, 'This line is VLOG level 0')
  logging.log(0, 'This line is log level 0')
  # The message deliberately ends with a NUL character.
  logging.info('Interesting Stuff\0')
  logging.info('Interesting Stuff with Arguments: %d', 42)
  logging.info('%(a)s Stuff with %(b)s',
               {'a': 'Interesting', 'b': 'Dictionary'})

  # Fake clock: starts at 0 and advances by 0.2 per loop iteration until it
  # reaches 9. Presumably log_every_n_seconds reads timeit.default_timer; the
  # message text expects 5 emissions with a 2-second period.
  with mock.patch.object(timeit, 'default_timer') as mock_timer:
    mock_timer.return_value = 0
    while timeit.default_timer() < 9:
      logging.log_every_n_seconds(logging.INFO, 'This should appear 5 times.',
                                  2)
      mock_timer.return_value = mock_timer() + .2

  for i in range(1, 5):
    logging.log_first_n(logging.INFO, 'Info first %d of %d', 2, i, 2)
    logging.log_every_n(logging.INFO, 'Info %d (every %d)', 3, i, 3)

  logging.vlog(-1, 'This line is VLOG level -1')
  logging.log(-1, 'This line is log level -1')
  logging.warning('Worrying Stuff')
  for i in range(1, 5):
    logging.log_first_n(logging.WARNING, 'Warn first %d of %d', 2, i, 2)
    logging.log_every_n(logging.WARNING, 'Warn %d (every %d)', 3, i, 3)

  logging.vlog(-2, 'This line is VLOG level -2')
  logging.log(-2, 'This line is log level -2')
  try:
    raise OSError('Fake Error')
  except OSError:
    # Keep the exception triple for the exc_info=saved_exc_info calls below.
    saved_exc_info = sys.exc_info()
    # No args are passed, so '%s' is not interpolated here.
    logging.exception('An Exception %s')
    logging.exception('Once more, %(reason)s', {'reason': 'just because'})
    logging.error('Exception 2 %s', exc_info=True)
    logging.error('Non-exception', exc_info=False)

  try:
    sys.exc_clear()
  except AttributeError:
    # No sys.exc_clear() in Python 3, but this will clear sys.exc_info() too.
    pass

  logging.error('Exception %s', '3', exc_info=saved_exc_info)
  # Same exception type and value, but with the traceback replaced by None.
  logging.error('No traceback', exc_info=saved_exc_info[:2] + (None,))

  logging.error('Alarming Stuff')
  for i in range(1, 5):
    logging.log_first_n(logging.ERROR, 'Error first %d of %d', 2, i, 2)
    logging.log_every_n(logging.ERROR, 'Error %d (every %d)', 3, i, 3)
  logging.flush()
def _test_fatal_main_thread_only():
  """Test logging.fatal from main thread, no other threads running."""
  # Sentinel that writes to stderr if its __del__ runs, so the caller can
  # see whether finalizers ran around the fatal call.
  v = VerboseDel('fatal_main_thread_only main del called\n')
  try:
    logging.fatal('fatal_main_thread_only message')
  finally:
    del v
def _test_fatal_with_other_threads():
  """Test logging.fatal from main thread, other threads running."""

  # Held by the main thread until the worker releases it, signalling that
  # the worker has started.
  lock = threading.Lock()
  lock.acquire()

  def sleep_forever(lock=lock):
    # Worker: creates its own sentinel, signals readiness, then blocks.
    v = VerboseDel('fatal_with_other_threads non-main del called\n')
    try:
      lock.release()
      while True:
        time.sleep(10000)
    finally:
      del v

  v = VerboseDel('fatal_with_other_threads main del called\n')
  try:
    # Start new thread
    t = threading.Thread(target=sleep_forever)
    t.start()

    # Wait for other thread
    lock.acquire()
    lock.release()

    # Die
    logging.fatal('fatal_with_other_threads message')
    # Only reached if logging.fatal returns.
    while True:
      time.sleep(10000)
  finally:
    del v
def _test_fatal_non_main_thread():
  """Test logging.fatal from non main thread."""

  # Held until the main thread releases it, telling the worker to proceed.
  lock = threading.Lock()
  lock.acquire()

  def die_soon(lock=lock):
    # Worker: waits for the go signal, then calls logging.fatal.
    v = VerboseDel('fatal_non_main_thread non-main del called\n')
    try:
      # Wait for signal from other thread
      lock.acquire()
      lock.release()
      logging.fatal('fatal_non_main_thread message')
      # Only reached if logging.fatal returns.
      while True:
        time.sleep(10000)
    finally:
      del v

  v = VerboseDel('fatal_non_main_thread main del called\n')
  try:
    # Start new thread
    t = threading.Thread(target=die_soon)
    t.start()

    # Signal other thread
    lock.release()

    # Wait for it to die
    while True:
      time.sleep(10000)
  finally:
    del v
def _test_critical_from_non_absl_logger():
  """Test CRITICAL logs from non-absl loggers."""
  # Module-level std logging call, i.e. it goes through the root logger
  # rather than absl's logger.
  std_logging.critical('A critical message')
def _test_register_frame_to_skip():
  """Test skipping frames for line number reporting."""

  # The two-level nesting matters: the reported caller line comes from the
  # frames around findCaller(), and '_getline' is the frame skipped below.
  def _getline():

    def _getline_inner():
      # Index 1 of findCaller()'s result is the line number.
      return logging.get_absl_logger().findCaller()[1]

    return _getline_inner()

  # Check register_frame_to_skip function to see if log frame skipping works.
  line1 = _getline()
  line2 = _getline()
  logging.get_absl_logger().register_frame_to_skip(__file__, '_getline')
  line3 = _getline()
  # Both should be line number of the _getline_inner() call.
  assert (line1 == line2), (line1, line2)
  # line3 should be a line number in this function.
  assert (line2 != line3), (line2, line3)
def _test_flush():
  """Test flush in various difficult cases."""
  # Flush, but one of the logfiles is closed
  log_filename = os.path.join(FLAGS.log_dir, 'a_thread_with_logfile.txt')
  with open(log_filename, 'w') as log_file:
    # Point absl's Python handler at this file. The file is closed when the
    # `with` block exits, before the flush below.
    logging.get_absl_handler().python_handler.stream = log_file
  logging.flush()
def _test_stderrthreshold():
  """Tests modifying --stderrthreshold after flag parsing will work.

  For each threshold, from 'debug' to 'error', one message is logged at each
  of the debug/info/warning/error levels.
  """

  def log_things():
    # Every message names the threshold that was in effect when it was sent.
    threshold = FLAGS.stderrthreshold
    logging.debug('FLAGS.stderrthreshold=%s, debug log', threshold)
    logging.info('FLAGS.stderrthreshold=%s, info log', threshold)
    logging.warning('FLAGS.stderrthreshold=%s, warning log', threshold)
    logging.error('FLAGS.stderrthreshold=%s, error log', threshold)

  for new_threshold in ('debug', 'info', 'warning', 'error'):
    FLAGS.stderrthreshold = new_threshold
    log_things()
def _test_std_logging():
  """Tests logs from std logging.

  Emits 'std <level> log' through the module-level std logging functions,
  DEBUG first and ERROR last.
  """
  for emit, level_name in ((std_logging.debug, 'debug'),
                           (std_logging.info, 'info'),
                           (std_logging.warning, 'warning'),
                           (std_logging.error, 'error')):
    emit('std %s log' % level_name)
def _test_bad_exc_info():
  """Tests when a bad exc_info value is provided.

  exc_info is a 2-tuple instead of the (type, value, traceback) 3-tuple.
  """
  logging.info('Bad exc_info', exc_info=(None, None))
def _test_none_exc_info():
  """Tests when exc_info is requested but not available.

  It logs with exc_info=True while no exception is being handled. (The former
  `sys.exc_clear()` call was Python-2-only; on Python 3 it always raised
  AttributeError, so it has been dropped.)
  """
  logging.info('None exc_info', exc_info=True)
def _test_unicode():
  """Tests unicode handling.

  Each case is logged at INFO between `-- begin NAME --` and
  `-- end NAME --` markers written directly to stderr, so the caller can
  inspect the stderr output of each case separately.
  """
  seen_names = set()

  def log(name, msg, *args):
    """Logs the message, and ensures the same name is not logged again."""
    duplicate_message = ('test_unicode expects unique names to work,'
                         ' found existing name {}').format(name)
    assert name not in seen_names, duplicate_message
    seen_names.add(name)
    # Add line separators so that tests can verify the output for each log
    # message.
    sys.stderr.write('-- begin {} --\n'.format(name))
    logging.info(msg, *args)
    sys.stderr.write('-- end {} --\n'.format(name))

  chatonnaye = 'Ch\u00e2tonnaye'
  gite_fmt = 'G\u00eete: %s'
  # (name, message, format args), logged in this order.
  cases = (
      ('unicode', 'G\u00eete: Ch\u00e2tonnaye', ()),
      ('unicode % unicode', gite_fmt, (chatonnaye,)),
      ('bytes % bytes', gite_fmt.encode('utf-8'),
       (chatonnaye.encode('utf-8'),)),
      ('unicode % bytes', gite_fmt, (chatonnaye.encode('utf-8'),)),
      ('bytes % unicode', gite_fmt.encode('utf-8'), (chatonnaye,)),
      ('unicode % iso8859-15', gite_fmt,
       (chatonnaye.encode('iso-8859-15'),)),
      ('str % exception', 'exception: %s', (Exception(chatonnaye),)),
  )
  for name, msg, args in cases:
    log(name, msg, *args)
def main(argv):
  """Runs the helper function selected by the TEST_NAME environment variable.

  Args:
    argv: positional arguments passed by app.run (unused).

  Raises:
    AssertionError: if `_test_<TEST_NAME>` is not defined in this module.
  """
  del argv  # Unused.

  selected = globals().get('_test_%s' % os.environ.get('TEST_NAME', None))
  if selected is None:
    raise AssertionError('TEST_NAME must be set to a valid value')

  # Flush so previous messages are written to file before we switch to a new
  # file with use_absl_log_file.
  logging.flush()
  wants_absl_log_file = os.environ.get('USE_ABSL_LOG_FILE') == '1'
  if wants_absl_log_file:
    logging.get_absl_handler().use_absl_log_file('absl_log_file', FLAGS.log_dir)
  selected()
if __name__ == '__main__':
  # Fixed program name, presumably so that output derived from argv[0] does
  # not depend on how the helper was launched.
  sys.argv[0] = 'py_argv_0'
  # Optionally run std logging.config.dictConfig() before absl takes over.
  if os.environ.get('CALL_DICT_CONFIG') == '1':
    std_logging_config.dictConfig({'version': 1})
  app.run(main)
abseil-py-2.1.0/absl/logging/tests/logging_test.py 0000664 0000000 0000000 00000111044 14551576331 0022162 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for absl.logging."""
import contextlib
import functools
import getpass
import io
import logging as std_logging
import os
import re
import socket
import sys
import tempfile
import threading
import time
import traceback
import unittest
from unittest import mock
from absl import flags
from absl import logging
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
# Global flag registry; tests read and parse flags such as logger_levels.
FLAGS = flags.FLAGS
class ConfigurationTest(absltest.TestCase):
  """Tests the initial logging configuration."""

  def test_logger_and_handler(self):
    """The std 'absl' logger is absl's ABSLLogger; its handler uses PythonFormatter."""
    absl_logger = std_logging.getLogger('absl')
    self.assertIs(absl_logger, logging.get_absl_logger())
    self.assertIsInstance(absl_logger, logging.ABSLLogger)
    python_handler = logging.get_absl_handler().python_handler
    self.assertIsInstance(python_handler.formatter, logging.PythonFormatter)
class LoggerLevelsTest(parameterized.TestCase):
  """Tests the --logger_levels flag."""

  def setUp(self):
    super().setUp()
    # Since these tests muck with the flag, always save/restore in case the
    # tests forget to clean up properly. enter_context() exits the context
    # manager during test cleanup.
    self.enter_context(self.set_logger_levels({}))

  @contextlib.contextmanager
  def set_logger_levels(self, levels):
    """Temporarily sets --logger_levels to `levels`.

    On exit, the flag is restored via flagsaver, and each logger named in
    `levels` gets back the level it had on entry.

    Args:
      levels: mapping of logger name to level name, e.g. {'foo': 'ERROR'}.
    """
    original_levels = {
        name: std_logging.getLogger(name).level for name in levels
    }
    try:
      with flagsaver.flagsaver(logger_levels=levels):
        yield
    finally:
      for name, level in original_levels.items():
        std_logging.getLogger(name).setLevel(level)

  def assert_logger_level(self, name, expected_level):
    """Asserts the explicitly set (not effective) level of logger `name`."""
    logger = std_logging.getLogger(name)
    self.assertEqual(logger.level, expected_level)

  def assert_logged(self, logger_name, expected_msgs):
    """Logs one message per level and checks which ones are captured.

    Args:
      logger_name: name of the std logger to exercise.
      expected_msgs: exact set of messages expected to be captured.
    """
    logger = std_logging.getLogger(logger_name)
    # NOTE: assertLogs() sets the logger to INFO if not specified.
    with self.assertLogs(logger, logger.level) as cm:
      logger.debug('debug')
      logger.info('info')
      logger.warning('warning')
      logger.error('error')
      logger.critical('critical')
    actual = {r.getMessage() for r in cm.records}
    self.assertEqual(set(expected_msgs), actual)

  def test_setting_levels(self):
    # Other tests change the root logging level, so we can't
    # assume it's the default.
    orig_root_level = std_logging.root.getEffectiveLevel()
    with self.set_logger_levels({'foo': 'ERROR', 'bar': 'DEBUG'}):

      self.assert_logger_level('foo', std_logging.ERROR)
      self.assert_logger_level('bar', std_logging.DEBUG)
      self.assert_logger_level('', orig_root_level)

      self.assert_logged('foo', {'error', 'critical'})
      self.assert_logged('bar',
                         {'debug', 'info', 'warning', 'error', 'critical'})

  @parameterized.named_parameters(
      ('empty', ''),
      ('one_value', 'one:INFO'),
      ('two_values', 'one.a:INFO,two.b:ERROR'),
      ('whitespace_ignored', ' one : DEBUG , two : INFO'),
  )
  def test_serialize_parse(self, levels_str):
    # Round trip: parsing then serializing drops all whitespace.
    fl = FLAGS['logger_levels']
    fl.parse(levels_str)
    expected = levels_str.replace(' ', '')
    actual = fl.serialize()
    self.assertEqual('--logger_levels={}'.format(expected), actual)

  def test_invalid_value(self):
    with self.assertRaisesRegex(ValueError, 'Unknown level.*10'):
      FLAGS['logger_levels'].parse('foo:10')
class PythonHandlerTest(absltest.TestCase):
  """Tests the PythonHandler class."""

  def setUp(self):
    super().setUp()
    # (year, month, day, hour, minute, sec, dunno, dayofyear, dst_flag), as
    # returned by the mocked time.localtime in test_start_logging_to_file.
    # The first six fields give the '19791021-181716' file name stamp.
    self.now_tuple = (1979, 10, 21, 18, 17, 16, 3, 15, 0)
    self.python_handler = logging.PythonHandler()

  def tearDown(self):
    mock.patch.stopall()
    super().tearDown()

  @flagsaver.flagsaver(logtostderr=False)
  def test_set_google_log_file_no_log_to_stderr(self):
    with mock.patch.object(self.python_handler, 'start_logging_to_file'):
      self.python_handler.use_absl_log_file()
      self.python_handler.start_logging_to_file.assert_called_once_with(
          program_name=None, log_dir=None)

  @flagsaver.flagsaver(logtostderr=True)
  def test_set_google_log_file_with_log_to_stderr(self):
    self.python_handler.stream = None
    self.python_handler.use_absl_log_file()
    self.assertEqual(sys.stderr, self.python_handler.stream)

  @mock.patch.object(logging, 'find_log_dir_and_names')
  @mock.patch.object(logging.time, 'localtime')
  @mock.patch.object(logging.time, 'time')
  @mock.patch.object(os.path, 'islink')
  @mock.patch.object(os, 'unlink')
  @mock.patch.object(os, 'getpid')
  def test_start_logging_to_file(
      self, mock_getpid, mock_unlink, mock_islink, mock_time,
      mock_localtime, mock_find_log_dir_and_names):
    mock_find_log_dir_and_names.return_value = ('here', 'prog1', 'prog1')
    mock_time.return_value = '12345'
    mock_localtime.return_value = self.now_tuple
    mock_getpid.return_value = 4321
    symlink = os.path.join('here', 'prog1.INFO')
    mock_islink.return_value = True
    with mock.patch.object(
        logging, 'open', return_value=sys.stdout, create=True):
      # os.symlink is missing on some platforms.
      if getattr(os, 'symlink', None):
        with mock.patch.object(os, 'symlink') as mock_symlink:
          self.python_handler.start_logging_to_file()
          # The stale link is removed and re-pointed at the new log file.
          mock_unlink.assert_called_once_with(symlink)
          mock_symlink.assert_called_once_with(
              'prog1.INFO.19791021-181716.4321', symlink)
      else:
        self.python_handler.start_logging_to_file()

  def test_log_file(self):
    # Default stream is sys.stderr; an explicit stream is kept as given.
    handler = logging.PythonHandler()
    self.assertEqual(sys.stderr, handler.stream)

    stream = mock.Mock()
    handler = logging.PythonHandler(stream)
    self.assertEqual(stream, handler.stream)

  def test_flush(self):
    stream = mock.Mock()
    handler = logging.PythonHandler(stream)
    handler.flush()
    stream.flush.assert_called_once()

  def test_flush_with_value_error(self):
    # A ValueError from the stream's flush must be swallowed.
    stream = mock.Mock()
    stream.flush.side_effect = ValueError
    handler = logging.PythonHandler(stream)
    handler.flush()
    stream.flush.assert_called_once()

  def test_flush_with_environment_error(self):
    # An EnvironmentError (OSError) from the stream's flush must be swallowed.
    stream = mock.Mock()
    stream.flush.side_effect = EnvironmentError
    handler = logging.PythonHandler(stream)
    handler.flush()
    stream.flush.assert_called_once()

  def test_flush_with_assertion_error(self):
    # Other exceptions propagate.
    stream = mock.Mock()
    stream.flush.side_effect = AssertionError
    handler = logging.PythonHandler(stream)
    with self.assertRaises(AssertionError):
      handler.flush()

  def test_ignore_flush_if_stream_is_none(self):
    # Happens if creating a Windows executable without console.
    with mock.patch.object(sys, 'stderr', new=None):
      handler = logging.PythonHandler(None)
      # Test that this does not fail.
      handler.flush()

  def test_ignore_flush_if_stream_does_not_support_flushing(self):
    class BadStream:
      pass

    handler = logging.PythonHandler(BadStream())
    # Test that this does not fail.
    handler.flush()

  def test_log_to_std_err(self):
    record = std_logging.LogRecord(
        'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
    with mock.patch.object(std_logging.StreamHandler, 'emit'):
      self.python_handler._log_to_stderr(record)
      std_logging.StreamHandler.emit.assert_called_once_with(record)

  @flagsaver.flagsaver(logtostderr=True)
  def test_emit_log_to_stderr(self):
    record = std_logging.LogRecord(
        'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
    with mock.patch.object(self.python_handler, '_log_to_stderr'):
      self.python_handler.emit(record)
      self.python_handler._log_to_stderr.assert_called_once_with(record)

  def test_emit(self):
    stream = io.StringIO()
    handler = logging.PythonHandler(stream)
    handler.stderr_threshold = std_logging.FATAL
    record = std_logging.LogRecord(
        'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
    handler.emit(record)
    self.assertEqual(1, stream.getvalue().count('logging_msg'))

  @flagsaver.flagsaver(stderrthreshold='debug')
  def test_emit_and_stderr_threshold(self):
    mock_stderr = io.StringIO()
    stream = io.StringIO()
    handler = logging.PythonHandler(stream)
    record = std_logging.LogRecord(
        'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
    with mock.patch.object(sys, 'stderr', new=mock_stderr):
      handler.emit(record)
      # Written to the handler's stream and, via the threshold, to stderr.
      self.assertEqual(1, stream.getvalue().count('logging_msg'))
      self.assertEqual(1, mock_stderr.getvalue().count('logging_msg'))

  @flagsaver.flagsaver(alsologtostderr=True)
  def test_emit_also_log_to_stderr(self):
    mock_stderr = io.StringIO()
    stream = io.StringIO()
    handler = logging.PythonHandler(stream)
    handler.stderr_threshold = std_logging.FATAL
    record = std_logging.LogRecord(
        'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
    with mock.patch.object(sys, 'stderr', new=mock_stderr):
      handler.emit(record)
      self.assertEqual(1, stream.getvalue().count('logging_msg'))
      self.assertEqual(1, mock_stderr.getvalue().count('logging_msg'))

  def test_emit_on_stderr(self):
    mock_stderr = io.StringIO()
    with mock.patch.object(sys, 'stderr', new=mock_stderr):
      # Created inside the patch so its default stream is mock_stderr.
      handler = logging.PythonHandler()
      handler.stderr_threshold = std_logging.INFO
      record = std_logging.LogRecord(
          'name', std_logging.INFO, 'path', 12, 'logging_msg', [], False)
      handler.emit(record)
      # Logged once even though the stream already is stderr.
      self.assertEqual(1, mock_stderr.getvalue().count('logging_msg'))

  def test_emit_fatal_absl(self):
    stream = io.StringIO()
    handler = logging.PythonHandler(stream)
    record = std_logging.LogRecord(
        'name', std_logging.FATAL, 'path', 12, 'logging_msg', [], False)
    # Marks the record as coming from absl's logging.fatal.
    record.__dict__[logging._ABSL_LOG_FATAL] = True
    with mock.patch.object(handler, 'flush') as mock_flush:
      with mock.patch.object(os, 'abort') as mock_abort:
        handler.emit(record)
        mock_abort.assert_called_once()
        mock_flush.assert_called()  # flush is also called by super class.

  def test_emit_fatal_non_absl(self):
    stream = io.StringIO()
    handler = logging.PythonHandler(stream)
    record = std_logging.LogRecord(
        'name', std_logging.FATAL, 'path', 12, 'logging_msg', [], False)
    with mock.patch.object(os, 'abort') as mock_abort:
      handler.emit(record)
      mock_abort.assert_not_called()

  def test_close(self):
    # A tty stream is left open.
    stream = mock.Mock()
    stream.isatty.return_value = True
    handler = logging.PythonHandler(stream)
    with mock.patch.object(handler, 'flush') as mock_flush:
      with mock.patch.object(std_logging.StreamHandler, 'close') as super_close:
        handler.close()
        mock_flush.assert_called_once()
        super_close.assert_called_once()
        stream.close.assert_not_called()

  def test_close_afile(self):
    # A non-tty stream is closed, and a ValueError from close() is swallowed.
    stream = mock.Mock()
    stream.isatty.return_value = False
    stream.close.side_effect = ValueError
    handler = logging.PythonHandler(stream)
    with mock.patch.object(handler, 'flush') as mock_flush:
      with mock.patch.object(std_logging.StreamHandler, 'close') as super_close:
        handler.close()
        mock_flush.assert_called_once()
        super_close.assert_called_once()
        stream.close.assert_called_once()

  def test_close_stderr(self):
    with mock.patch.object(sys, 'stderr') as mock_stderr:
      mock_stderr.isatty.return_value = False
      handler = logging.PythonHandler(sys.stderr)
      handler.close()
      mock_stderr.close.assert_not_called()

  def test_close_stdout(self):
    with mock.patch.object(sys, 'stdout') as mock_stdout:
      mock_stdout.isatty.return_value = False
      handler = logging.PythonHandler(sys.stdout)
      handler.close()
      mock_stdout.close.assert_not_called()

  def test_close_original_stderr(self):
    with mock.patch.object(sys, '__stderr__') as mock_original_stderr:
      mock_original_stderr.isatty.return_value = False
      handler = logging.PythonHandler(sys.__stderr__)
      handler.close()
      mock_original_stderr.close.assert_not_called()

  def test_close_original_stdout(self):
    with mock.patch.object(sys, '__stdout__') as mock_original_stdout:
      mock_original_stdout.isatty.return_value = False
      handler = logging.PythonHandler(sys.__stdout__)
      handler.close()
      mock_original_stdout.close.assert_not_called()

  def test_close_fake_file(self):

    class FakeFile(object):
      """A file-like object that does not implement "isatty"."""

      def __init__(self):
        self.closed = False

      def close(self):
        self.closed = True

      def flush(self):
        pass

    fake_file = FakeFile()
    handler = logging.PythonHandler(fake_file)
    handler.close()
    self.assertTrue(fake_file.closed)
class ABSLHandlerTest(absltest.TestCase):
  """Tests the ABSLHandler class."""

  def setUp(self):
    super().setUp()
    self.absl_handler = logging.ABSLHandler(logging.PythonFormatter())

  def test_activate_python_handler(self):
    """Activating the Python handler makes it the current handler."""
    handler = self.absl_handler
    handler.activate_python_handler()
    self.assertEqual(handler._current_handler, handler.python_handler)
class ABSLLoggerTest(absltest.TestCase):
"""Tests the ABSLLogger class."""
def set_up_mock_frames(self):
  """Sets up mock frames for use with the testFindCaller methods.

  Patches sys._getframe (stopped via mock.patch.stopall in tearDown) to
  return a chain of five mock frames linked through f_back, innermost first:
  one frame in absl/logging/__init__.py, then Method1, Method2 and Method3 in
  myfile.py, then a second Method2 with a different co_firstlineno/f_lineno.
  """
  logging_file = os.path.join('absl', 'logging', '__init__.py')
  # (co_filename, co_name, co_firstlineno, f_lineno), innermost first.
  frame_specs = [
      (logging_file, 'LoggingLog', 124, 125),
      ('myfile.py', 'Method1', 124, 125),
      ('myfile.py', 'Method2', 124, 125),
      ('myfile.py', 'Method3', 124, 125),
      # Same function name as the Method2 above, different definition line.
      ('myfile.py', 'Method2', 248, 249),
  ]
  # Build from the outermost frame inwards so every frame can point at the
  # one that was built just before it.
  outer = None
  for filename, func_name, first_lineno, lineno in reversed(frame_specs):
    code = mock.Mock()
    code.co_filename = filename
    code.co_name = func_name
    code.co_firstlineno = first_lineno
    frame = mock.Mock()
    frame.f_code = code
    frame.f_lineno = lineno
    frame.f_back = outer
    outer = frame
  mock.patch.object(sys, '_getframe').start()
  sys._getframe.return_value = outer
def setUp(self):
  """Creates a fresh ABSLLogger('') and the message used by the tests."""
  super().setUp()
  self.logger = logging.ABSLLogger('')
  self.message = 'Hello Nurse'
def tearDown(self):
  """Undoes the patches and frame registrations made by the tests."""
  # Stops the sys._getframe patch started by set_up_mock_frames.
  mock.patch.stopall()
  # Drops the frames registered through register_frame_to_skip.
  # NOTE(review): if _frames_to_skip is shared across loggers, this also
  # clears registrations made outside this test — confirm.
  self.logger._frames_to_skip.clear()
  super().tearDown()
def test_constructor_without_level(self):
  """Without a level argument the effective level is NOTSET."""
  self.logger = logging.ABSLLogger('')
  effective_level = self.logger.getEffectiveLevel()
  self.assertEqual(std_logging.NOTSET, effective_level)
def test_constructor_with_level(self):
  """A level passed to the constructor becomes the effective level."""
  self.logger = logging.ABSLLogger('', std_logging.DEBUG)
  effective_level = self.logger.getEffectiveLevel()
  self.assertEqual(std_logging.DEBUG, effective_level)
def test_find_caller_normal(self):
  """With nothing registered, the caller is Method1."""
  self.set_up_mock_frames()
  caller_name = self.logger.findCaller()[2]
  self.assertEqual('Method1', caller_name)
def test_find_caller_skip_method1(self):
  """Skipping Method1 makes the next frame, Method2, the caller."""
  self.set_up_mock_frames()
  self.logger.register_frame_to_skip('myfile.py', 'Method1')
  caller_name = self.logger.findCaller()[2]
  self.assertEqual('Method2', caller_name)
def test_find_caller_skip_method1_and_method2(self):
  """Skipping Method1 and Method2 by name makes Method3 the caller."""
  self.set_up_mock_frames()
  for skipped in ('Method1', 'Method2'):
    self.logger.register_frame_to_skip('myfile.py', skipped)
  caller_name = self.logger.findCaller()[2]
  self.assertEqual('Method3', caller_name)
def test_find_caller_skip_method1_and_method3(self):
  """Skipping Method3 has no effect because Method2 is reached first."""
  self.set_up_mock_frames()
  self.logger.register_frame_to_skip('myfile.py', 'Method1')
  self.logger.register_frame_to_skip('myfile.py', 'Method3')
  caller_name = self.logger.findCaller()[2]
  self.assertEqual('Method2', caller_name)
def test_find_caller_skip_method1_and_method4(self):
  """A line-specific skip of the outer Method2 leaves the inner one intact."""
  self.set_up_mock_frames()
  self.logger.register_frame_to_skip('myfile.py', 'Method1')
  # Only the Method2 defined at line 248 (frame 4) is skipped.
  self.logger.register_frame_to_skip('myfile.py', 'Method2', 248)
  caller_name = self.logger.findCaller()[2]
  self.assertEqual('Method2', caller_name)
  caller_lineno = self.logger.findCaller()[1]
  self.assertEqual(125, caller_lineno)
def test_find_caller_skip_method1_method2_and_method3(self):
  """Skipping the inner Method2 by its line reaches the outer Method2."""
  self.set_up_mock_frames()
  self.logger.register_frame_to_skip('myfile.py', 'Method1')
  # Only the Method2 defined at line 124 (frame 2) is skipped.
  self.logger.register_frame_to_skip('myfile.py', 'Method2', 124)
  self.logger.register_frame_to_skip('myfile.py', 'Method3')
  caller_name = self.logger.findCaller()[2]
  self.assertEqual('Method2', caller_name)
  caller_lineno = self.logger.findCaller()[1]
  self.assertEqual(249, caller_lineno)
def test_find_caller_stack_info(self):
  """stack_info=True: sinfo is the stack header and print_stack runs once.

  traceback.print_stack is mocked, so the expected sinfo is only the header.
  """
  self.set_up_mock_frames()
  self.logger.register_frame_to_skip('myfile.py', 'Method1')
  expected = ('myfile.py', 125, 'Method2', 'Stack (most recent call last):')
  with mock.patch.object(traceback, 'print_stack') as print_stack:
    self.assertEqual(expected, self.logger.findCaller(stack_info=True))
    print_stack.assert_called_once()
def test_critical(self):
with mock.patch.object(self.logger, 'log'):
self.logger.critical(self.message)
self.logger.log.assert_called_once_with(
std_logging.CRITICAL, self.message)
def test_fatal(self):
  """fatal() forwards to log() with the FATAL level."""
  with mock.patch.object(self.logger, 'log') as mock_log:
    self.logger.fatal(self.message)
    mock_log.assert_called_once_with(std_logging.FATAL, self.message)
def test_error(self):
  """error() forwards to log() with the ERROR level."""
  with mock.patch.object(self.logger, 'log') as mock_log:
    self.logger.error(self.message)
    mock_log.assert_called_once_with(std_logging.ERROR, self.message)
def test_warn(self):
  """warn() forwards to log() with the WARN level."""
  with mock.patch.object(self.logger, 'log') as mock_log:
    self.logger.warn(self.message)
    mock_log.assert_called_once_with(std_logging.WARN, self.message)
def test_warning(self):
  """warning() forwards to log() with the WARNING level."""
  with mock.patch.object(self.logger, 'log') as mock_log:
    self.logger.warning(self.message)
    mock_log.assert_called_once_with(std_logging.WARNING, self.message)
def test_info(self):
  """info() forwards to log() with the INFO level."""
  with mock.patch.object(self.logger, 'log') as mock_log:
    self.logger.info(self.message)
    mock_log.assert_called_once_with(std_logging.INFO, self.message)
def test_debug(self):
  """debug() forwards to log() with the DEBUG level."""
  with mock.patch.object(self.logger, 'log') as mock_log:
    self.logger.debug(self.message)
    mock_log.assert_called_once_with(std_logging.DEBUG, self.message)
def test_log_debug_with_python(self):
  """debug() forwards to log(DEBUG, ...) while --verbosity is 1."""
  # Scope the verbosity change with flagsaver. A bare `FLAGS.verbosity = 1`
  # would leave the global flag modified for every test that runs afterwards.
  with mock.patch.object(self.logger, 'log'), \
      flagsaver.flagsaver(verbosity=1):
    self.logger.debug(self.message)
    self.logger.log.assert_called_once_with(std_logging.DEBUG, self.message)
def test_log_fatal_with_python(self):
  """fatal() forwards to log() with the FATAL level."""
  with mock.patch.object(self.logger, 'log') as mock_log:
    self.logger.fatal(self.message)
    mock_log.assert_called_once_with(std_logging.FATAL, self.message)
def test_register_frame_to_skip(self):
  """A registered (file, method) pair shows up in _frames_to_skip."""
  file_and_method = ('file', 'method')
  self.logger.register_frame_to_skip(file_and_method[0], file_and_method[1])
  self.assertIn(file_and_method, self.logger._frames_to_skip)
def test_register_frame_to_skip_with_lineno(self):
  """A registered (file, method, lineno) triple shows up in _frames_to_skip."""
  file_method_lineno = ('file', 'method', 123)
  self.logger.register_frame_to_skip(
      file_method_lineno[0], file_method_lineno[1], file_method_lineno[2])
  self.assertIn(file_method_lineno, self.logger._frames_to_skip)
def test_logger_cannot_be_disabled(self):
  """handle() still reaches callHandlers even with `disabled` set to True."""
  self.logger.disabled = True
  info_record = self.logger.makeRecord(
      'name', std_logging.INFO, 'fn', 20, 'msg', [], False)
  with mock.patch.object(self.logger, 'callHandlers') as call_handlers_mock:
    self.logger.handle(info_record)
    call_handlers_mock.assert_called_once()
class ABSLLogPrefixTest(parameterized.TestCase):
  """Tests for logging.get_absl_log_prefix."""

  def setUp(self):
    super().setUp()
    self.record = std_logging.LogRecord(
        'name', std_logging.INFO, 'path/to/source.py', 13, 'log message',
        None, None)

  def _assert_utc_prefix(self, expected_prefix):
    """Asserts the prefix of self.record, rendering its timestamp in UTC."""
    # Use UTC so the test passes regardless of the local time zone.
    with mock.patch.object(time, 'localtime', side_effect=time.gmtime):
      actual_prefix = logging.get_absl_log_prefix(self.record)
      self.assertEqual(expected_prefix, actual_prefix)
      time.localtime.assert_called_once_with(self.record.created)

  @parameterized.named_parameters(
      ('debug', std_logging.DEBUG, 'I'),
      ('info', std_logging.INFO, 'I'),
      ('warning', std_logging.WARNING, 'W'),
      ('error', std_logging.ERROR, 'E'),
  )
  def test_default_prefixes(self, levelno, level_prefix):
    self.record.levelno = levelno
    self.record.created = 1494293880.378885
    padded_thread_id = '{: >5}'.format(logging._get_thread_id())
    self._assert_utc_prefix(
        f'{level_prefix}0509 01:38:00.378885 {padded_thread_id} source.py:13] ')

  def test_absl_prefix_regex(self):
    self.record.created = 1226888258.0521369
    # Use UTC so the test passes regardless of the local time zone.
    with mock.patch.object(time, 'localtime', side_effect=time.gmtime):
      prefix = logging.get_absl_log_prefix(self.record)
    match = re.search(logging.ABSL_LOGGING_PREFIX_REGEX, prefix)
    self.assertTrue(match)
    expected_groups = {
        'severity': 'I',
        'month': '11',
        'day': '17',
        'hour': '02',
        'minute': '17',
        'second': '38',
        'microsecond': '052136',
        'thread_id': str(logging._get_thread_id()),
        'filename': 'source.py',
        'line': '13',
    }
    actual_groups = {
        group_name: match.group(group_name) for group_name in expected_groups
    }
    self.assertEqual(expected_groups, actual_groups)

  def test_critical_absl(self):
    # A CRITICAL record flagged by absl's fatal path is rendered with 'F'.
    self.record.levelno = std_logging.CRITICAL
    self.record.created = 1494293880.378885
    self.record._absl_log_fatal = True
    padded_thread_id = '{: >5}'.format(logging._get_thread_id())
    self._assert_utc_prefix(
        f'F0509 01:38:00.378885 {padded_thread_id} source.py:13] ')

  def test_critical_non_absl(self):
    # Without the absl fatal marker, CRITICAL renders as 'E' plus a tag.
    self.record.levelno = std_logging.CRITICAL
    self.record.created = 1494293880.378885
    padded_thread_id = '{: >5}'.format(logging._get_thread_id())
    self._assert_utc_prefix(
        f'E0509 01:38:00.378885 {padded_thread_id} source.py:13] CRITICAL - ')
class LogCountTest(absltest.TestCase):
  """Tests for logging._get_next_log_count_per_token."""

  def test_counter_threadsafe(self):
    start_event = threading.Event()
    observed_counts = set()
    token = object()

    def record_next_count():
      # Block until every worker exists so the calls race as much as possible.
      start_event.wait()
      observed_counts.add(logging._get_next_log_count_per_token(token))

    workers = []
    for _ in range(100):
      workers.append(threading.Thread(target=record_next_count))
    for worker in workers:
      worker.start()
    start_event.set()
    for worker in workers:
      worker.join()
    # 100 calls for one token must yield each count 0..99 exactly once.
    self.assertEqual(observed_counts, set(range(100)))
class LoggingTest(absltest.TestCase):
  """Tests for the module-level functions of absl.logging."""

  def test_fatal(self):
    with mock.patch.object(os, 'abort') as mock_abort:
      logging.fatal('Die!')
      mock_abort.assert_called_once()

  def test_find_log_dir_with_arg(self):
    with mock.patch.object(os, 'access'), \
        mock.patch.object(os.path, 'isdir'):
      os.path.isdir.return_value = True
      os.access.return_value = True
      log_dir = logging.find_log_dir(log_dir='./')
      self.assertEqual('./', log_dir)

  @flagsaver.flagsaver(log_dir='./')
  def test_find_log_dir_with_flag(self):
    with mock.patch.object(os, 'access'), \
        mock.patch.object(os.path, 'isdir'):
      os.path.isdir.return_value = True
      os.access.return_value = True
      log_dir = logging.find_log_dir()
      self.assertEqual('./', log_dir)

  @flagsaver.flagsaver(log_dir='')
  def test_find_log_dir_with_hda_tmp(self):
    with mock.patch.object(os, 'access'), \
        mock.patch.object(os.path, 'exists'), \
        mock.patch.object(os.path, 'isdir'):
      os.path.exists.return_value = True
      os.path.isdir.return_value = True
      os.access.return_value = True
      log_dir = logging.find_log_dir()
      self.assertEqual(tempfile.gettempdir(), log_dir)

  @flagsaver.flagsaver(log_dir='')
  def test_find_log_dir_with_tmp(self):
    with mock.patch.object(os, 'access'), \
        mock.patch.object(os.path, 'exists'), \
        mock.patch.object(os.path, 'isdir'):
      os.path.exists.return_value = False
      os.path.isdir.side_effect = lambda path: path == tempfile.gettempdir()
      os.access.return_value = True
      log_dir = logging.find_log_dir()
      self.assertEqual(tempfile.gettempdir(), log_dir)

  def test_find_log_dir_with_nothing(self):
    with mock.patch.object(os.path, 'exists'), \
        mock.patch.object(os.path, 'isdir'):
      os.path.exists.return_value = False
      os.path.isdir.return_value = False
      with self.assertRaises(FileNotFoundError):
        logging.find_log_dir()

  def test_find_log_dir_and_names_with_args(self):
    user = 'test_user'
    host = 'test_host'
    log_dir = 'here'
    program_name = 'prog1'
    with mock.patch.object(getpass, 'getuser'), \
        mock.patch.object(logging, 'find_log_dir') as mock_find_log_dir, \
        mock.patch.object(socket, 'gethostname') as mock_gethostname:
      getpass.getuser.return_value = user
      mock_gethostname.return_value = host
      mock_find_log_dir.return_value = log_dir
      prefix = '%s.%s.%s.log' % (program_name, host, user)
      self.assertEqual((log_dir, prefix, program_name),
                       logging.find_log_dir_and_names(
                           program_name=program_name, log_dir=log_dir))

  def test_find_log_dir_and_names_without_args(self):
    user = 'test_user'
    host = 'test_host'
    log_dir = 'here'
    py_program_name = 'py_prog1'
    # Patch sys.argv rather than assigning to sys.argv[0], so the fake program
    # path is undone when the test ends instead of leaking into later tests.
    with mock.patch.object(sys, 'argv', ['path/to/prog1'] + sys.argv[1:]), \
        mock.patch.object(getpass, 'getuser'), \
        mock.patch.object(logging, 'find_log_dir') as mock_find_log_dir, \
        mock.patch.object(socket, 'gethostname'):
      getpass.getuser.return_value = user
      socket.gethostname.return_value = host
      mock_find_log_dir.return_value = log_dir
      prefix = '%s.%s.%s.log' % (py_program_name, host, user)
      self.assertEqual((log_dir, prefix, py_program_name),
                       logging.find_log_dir_and_names())

  def test_find_log_dir_and_names_wo_username(self):
    # Windows doesn't have os.getuid at all
    if hasattr(os, 'getuid'):
      mock_getuid = mock.patch.object(os, 'getuid')
      uid = 100
      logged_uid = '100'
    else:
      # The function doesn't exist, but our test code still tries to mock
      # it, so just use a fake thing.
      mock_getuid = _mock_windows_os_getuid()
      uid = -1
      logged_uid = 'unknown'
    host = 'test_host'
    log_dir = 'here'
    program_name = 'prog1'
    with mock.patch.object(getpass, 'getuser'), \
        mock_getuid as getuid, \
        mock.patch.object(logging, 'find_log_dir') as mock_find_log_dir, \
        mock.patch.object(socket, 'gethostname') as mock_gethostname:
      getpass.getuser.side_effect = KeyError()
      getuid.return_value = uid
      mock_gethostname.return_value = host
      mock_find_log_dir.return_value = log_dir
      prefix = '%s.%s.%s.log' % (program_name, host, logged_uid)
      self.assertEqual((log_dir, prefix, program_name),
                       logging.find_log_dir_and_names(
                           program_name=program_name, log_dir=log_dir))

  def test_errors_in_logging(self):
    with mock.patch.object(sys, 'stderr', new=io.StringIO()) as stderr:
      logging.info('not enough args: %s %s', 'foo')  # pylint: disable=logging-too-few-args
      self.assertIn('Traceback (most recent call last):', stderr.getvalue())
      self.assertIn('TypeError', stderr.getvalue())

  def test_dict_arg(self):
    # Tests that passing a dictionary as a single argument does not crash.
    logging.info('%(test)s', {'test': 'Hello world!'})

  def test_exception_dict_format(self):
    # Just verify that this doesn't raise a TypeError.
    logging.exception('%(test)s', {'test': 'Hello world!'})

  def test_exception_with_exc_info(self):
    # Just verify that this doesn't raise a KeyError.
    logging.exception('exc_info=True', exc_info=True)
    logging.exception('exc_info=False', exc_info=False)

  def test_logging_levels(self):
    # Restore the original verbosity even if an assertion below fails.
    self.addCleanup(logging.set_verbosity, logging.get_verbosity())

    logging.set_verbosity(logging.DEBUG)
    self.assertEqual(logging.get_verbosity(), logging.DEBUG)
    self.assertTrue(logging.level_debug())
    self.assertTrue(logging.level_info())
    self.assertTrue(logging.level_warning())
    self.assertTrue(logging.level_error())

    logging.set_verbosity(logging.INFO)
    self.assertEqual(logging.get_verbosity(), logging.INFO)
    self.assertFalse(logging.level_debug())
    self.assertTrue(logging.level_info())
    self.assertTrue(logging.level_warning())
    self.assertTrue(logging.level_error())

    logging.set_verbosity(logging.WARNING)
    self.assertEqual(logging.get_verbosity(), logging.WARNING)
    self.assertFalse(logging.level_debug())
    self.assertFalse(logging.level_info())
    self.assertTrue(logging.level_warning())
    self.assertTrue(logging.level_error())

    logging.set_verbosity(logging.ERROR)
    self.assertEqual(logging.get_verbosity(), logging.ERROR)
    self.assertFalse(logging.level_debug())
    self.assertFalse(logging.level_info())
    # At ERROR, warnings are filtered too; this completes the level matrix.
    self.assertFalse(logging.level_warning())
    self.assertTrue(logging.level_error())

  def test_set_verbosity_strings(self):
    # Restore the original verbosity even if an assertion below fails.
    self.addCleanup(logging.set_verbosity, logging.get_verbosity())

    # Lowercase names.
    logging.set_verbosity('debug')
    self.assertEqual(logging.get_verbosity(), logging.DEBUG)
    logging.set_verbosity('info')
    self.assertEqual(logging.get_verbosity(), logging.INFO)
    logging.set_verbosity('warning')
    self.assertEqual(logging.get_verbosity(), logging.WARNING)
    logging.set_verbosity('warn')
    self.assertEqual(logging.get_verbosity(), logging.WARNING)
    logging.set_verbosity('error')
    self.assertEqual(logging.get_verbosity(), logging.ERROR)
    logging.set_verbosity('fatal')
    self.assertEqual(logging.get_verbosity(), logging.FATAL)

    # Uppercase names.
    logging.set_verbosity('DEBUG')
    self.assertEqual(logging.get_verbosity(), logging.DEBUG)
    logging.set_verbosity('INFO')
    self.assertEqual(logging.get_verbosity(), logging.INFO)
    logging.set_verbosity('WARNING')
    self.assertEqual(logging.get_verbosity(), logging.WARNING)
    logging.set_verbosity('WARN')
    self.assertEqual(logging.get_verbosity(), logging.WARNING)
    logging.set_verbosity('ERROR')
    self.assertEqual(logging.get_verbosity(), logging.ERROR)
    logging.set_verbosity('FATAL')
    self.assertEqual(logging.get_verbosity(), logging.FATAL)

    # Integers as strings.
    logging.set_verbosity(str(logging.DEBUG))
    self.assertEqual(logging.get_verbosity(), logging.DEBUG)
    logging.set_verbosity(str(logging.INFO))
    self.assertEqual(logging.get_verbosity(), logging.INFO)
    logging.set_verbosity(str(logging.WARNING))
    self.assertEqual(logging.get_verbosity(), logging.WARNING)
    logging.set_verbosity(str(logging.ERROR))
    self.assertEqual(logging.get_verbosity(), logging.ERROR)
    logging.set_verbosity(str(logging.FATAL))
    self.assertEqual(logging.get_verbosity(), logging.FATAL)

  def test_key_flags(self):
    key_flags = FLAGS.get_key_flags_for_module(logging)
    key_flag_names = [flag.name for flag in key_flags]
    self.assertIn('stderrthreshold', key_flag_names)
    self.assertIn('verbosity', key_flag_names)

  def test_get_absl_logger(self):
    self.assertIsInstance(
        logging.get_absl_logger(), logging.ABSLLogger)

  def test_get_absl_handler(self):
    self.assertIsInstance(
        logging.get_absl_handler(), logging.ABSLHandler)
@mock.patch.object(logging.ABSLLogger, 'register_frame_to_skip')
class LogSkipPrefixTest(absltest.TestCase):
  """Tests for logging.skip_log_prefix.

  The class decorator patches ABSLLogger.register_frame_to_skip for every
  test method. The resulting mock is passed in as `mock_skip_register`.
  """

  def _log_some_info(self):
    """Logging helper function for LogSkipPrefixTest."""
    logging.info('info')

  def _log_nested_outer(self):
    """Returns a nested logging helper function for LogSkipPrefixTest."""

    def _log_nested_inner():
      logging.info('info nested')

    return _log_nested_inner

  def test_skip_log_prefix_with_name(self, mock_skip_register):
    retval = logging.skip_log_prefix('_log_some_info')
    # A plain name is registered against this file with no line number.
    mock_skip_register.assert_called_once_with(__file__, '_log_some_info', None)
    self.assertEqual(retval, '_log_some_info')

  def test_skip_log_prefix_with_func(self, mock_skip_register):
    retval = logging.skip_log_prefix(self._log_some_info)
    mock_skip_register.assert_called_once_with(
        __file__, '_log_some_info', mock.ANY)
    self.assertEqual(retval, self._log_some_info)

  def test_skip_log_prefix_with_functools_partial(self, mock_skip_register):
    partial_input = functools.partial(self._log_some_info)
    with self.assertRaises(ValueError):
      _ = logging.skip_log_prefix(partial_input)
    mock_skip_register.assert_not_called()

  def test_skip_log_prefix_with_lambda(self, mock_skip_register):
    lambda_input = lambda _: self._log_some_info()
    retval = logging.skip_log_prefix(lambda_input)
    # A lambda's code object always has a non-empty name, so an expected
    # name of '' cannot match. Read the expected name from the lambda itself.
    mock_skip_register.assert_called_once_with(
        __file__, lambda_input.__code__.co_name, mock.ANY)
    self.assertEqual(retval, lambda_input)

  def test_skip_log_prefix_with_bad_input(self, mock_skip_register):
    dict_input = {1: 2, 2: 3}
    with self.assertRaises(TypeError):
      _ = logging.skip_log_prefix(dict_input)
    mock_skip_register.assert_not_called()

  def test_skip_log_prefix_with_nested_func(self, mock_skip_register):
    nested_input = self._log_nested_outer()
    retval = logging.skip_log_prefix(nested_input)
    mock_skip_register.assert_called_once_with(
        __file__, '_log_nested_inner', mock.ANY)
    self.assertEqual(retval, nested_input)

  def test_skip_log_prefix_decorator(self, mock_skip_register):

    @logging.skip_log_prefix
    def _log_decorated():
      logging.info('decorated')

    del _log_decorated
    mock_skip_register.assert_called_once_with(
        __file__, '_log_decorated', mock.ANY)
@contextlib.contextmanager
def override_python_handler_stream(stream):
  """Temporarily points the absl handler's python_handler at `stream`.

  The previous stream is put back when the context exits, even on error.
  """
  python_handler = logging.get_absl_handler().python_handler
  saved_stream, python_handler.stream = python_handler.stream, stream
  try:
    yield
  finally:
    python_handler.stream = saved_stream
class GetLogFileNameTest(parameterized.TestCase):
  """Tests for logging.get_log_file_name."""

  @parameterized.named_parameters(
      ('err', sys.stderr),
      ('out', sys.stdout),
  )
  def test_get_log_file_name_py_std(self, stream):
    # The process's standard streams are expected to report an empty name.
    with override_python_handler_stream(stream):
      self.assertEqual('', logging.get_log_file_name())

  def test_get_log_file_name_py_no_name(self):
    # A stream object with no `name` attribute also reports an empty name.
    class FakeFile(object):
      pass

    with override_python_handler_stream(FakeFile()):
      self.assertEqual('', logging.get_log_file_name())

  def test_get_log_file_name_py_file(self):
    fd, filename = tempfile.mkstemp(dir=absltest.TEST_TMPDIR.value)
    # mkstemp returns an already-open OS-level descriptor. Close it so it does
    # not leak, then reopen the path as a regular text stream.
    os.close(fd)
    with open(filename, 'a') as stream:
      with override_python_handler_stream(stream):
        self.assertEqual(filename, logging.get_log_file_name())
@contextlib.contextmanager
def _mock_windows_os_getuid():
yield mock.MagicMock()
# Run this module's tests when it is executed directly.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/logging/tests/verbosity_flag_test.py 0000664 0000000 0000000 00000003514 14551576331 0023555 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests -v/--verbosity flag and logging.root level's sync behavior."""
import logging
# Before anything else touches it, stdlib logging's root logger must still be
# at its WARN default.
assert logging.root.getEffectiveLevel() == logging.WARN, (
    'default logging.root level should be WARN, but found {}'.format(
        logging.root.getEffectiveLevel()))
# This is here to test importing logging won't change the level.
logging.root.setLevel(logging.ERROR)
assert logging.root.getEffectiveLevel() == logging.ERROR, (
    'logging.root level should be changed to ERROR, but found {}'.format(
        logging.root.getEffectiveLevel()))
# The absl imports deliberately come after setLevel() above. The asserts
# below can then check that importing absl.logging leaves the root logger's
# level untouched.
# pylint: disable=g-import-not-at-top
from absl import flags
from absl import logging as _  # pylint: disable=unused-import
from absl.testing import absltest
# pylint: enable=g-import-not-at-top
FLAGS = flags.FLAGS
# Flags have not been parsed yet at import time.
assert FLAGS['verbosity'].value == -1, (
    '-v/--verbosity should be -1 before flags are parsed.')
assert logging.root.getEffectiveLevel() == logging.ERROR, (
    'logging.root level should be kept to ERROR, but found {}'.format(
        logging.root.getEffectiveLevel()))
class VerbosityFlagTest(absltest.TestCase):
  """Checks --verbosity and the root logger level once the test runs."""

  def test_default_value_after_init(self):
    # By now flags are parsed: --verbosity is 0 and the root logger is at INFO.
    parsed_verbosity = FLAGS.verbosity
    self.assertEqual(0, parsed_verbosity)
    root_level = logging.root.getEffectiveLevel()
    self.assertEqual(logging.INFO, root_level)
# Run this module's tests when it is executed directly.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/testing/ 0000775 0000000 0000000 00000000000 14551576331 0016007 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/absl/testing/BUILD 0000664 0000000 0000000 00000014113 14551576331 0016571 0 ustar 00root root 0000000 0000000 load("@rules_python//python:py_library.bzl", "py_library")
load("@rules_python//python:py_test.bzl", "py_test")
load("@rules_python//python:py_binary.bzl", "py_binary")
# Targets below are private unless they set `visibility` themselves.
package(default_visibility = ["//visibility:private"])

licenses(["notice"])

# Public absltest library (absltest.py). Depends on both reporters plus absl
# app, flags and logging.
py_library(
    name = "absltest",
    srcs = ["absltest.py"],
    # NOTE(review): "PY2AND3" looks stale; the CI workflow only tests Python
    # 3.7-3.12. Confirm before changing it here and below.
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":_pretty_print_reporter",
        ":xml_reporter",
        "//absl:app",
        "//absl/flags",
        "//absl/logging",
    ],
)

# Public library for flagsaver.py.
py_library(
    name = "flagsaver",
    srcs = ["flagsaver.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//absl/flags",
    ],
)

# Public library for parameterized.py, built on top of absltest.
py_library(
    name = "parameterized",
    srcs = [
        "parameterized.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":absltest",
    ],
)

# Public XML test reporter; reuses the pretty-print text reporter.
py_library(
    name = "xml_reporter",
    srcs = ["xml_reporter.py"],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        ":_pretty_print_reporter",
    ],
)

# Test-only helper for locating py_binary data dependencies; visible to every
# package in this repository.
py_library(
    name = "_bazelize_command",
    testonly = 1,
    srcs = ["_bazelize_command.py"],
    srcs_version = "PY2AND3",
    visibility = ["//:__subpackages__"],
    deps = [
        "//absl/flags",
    ],
)

# Package-private text reporter used by absltest and xml_reporter.
py_library(
    name = "_pretty_print_reporter",
    srcs = ["_pretty_print_reporter.py"],
    srcs_version = "PY2AND3",
)

# Shared test-only helper module for the tests below.
py_library(
    name = "tests/absltest_env",
    testonly = True,
    srcs = ["tests/absltest_env.py"],
)
# Test-filtering behavior. The helper binary is a data dependency, presumably
# located at runtime through :_bazelize_command.
py_test(
    name = "tests/absltest_filtering_test",
    size = "medium",
    srcs = ["tests/absltest_filtering_test.py"],
    data = [":tests/absltest_filtering_test_helper"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":_bazelize_command",
        ":absltest",
        ":parameterized",
        ":tests/absltest_env",
        "//absl/logging",
    ],
)

# Helper program used by absltest_filtering_test.
py_binary(
    name = "tests/absltest_filtering_test_helper",
    testonly = 1,
    srcs = ["tests/absltest_filtering_test_helper.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":absltest",
        ":parameterized",
        "//absl:app",
    ],
)

# Fail-fast behavior, driven through its helper binary.
py_test(
    name = "tests/absltest_fail_fast_test",
    size = "small",
    srcs = ["tests/absltest_fail_fast_test.py"],
    data = [":tests/absltest_fail_fast_test_helper"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":_bazelize_command",
        ":absltest",
        ":parameterized",
        ":tests/absltest_env",
        "//absl/logging",
    ],
)

# Helper program used by absltest_fail_fast_test.
py_binary(
    name = "tests/absltest_fail_fast_test_helper",
    testonly = 1,
    srcs = ["tests/absltest_fail_fast_test_helper.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":absltest",
        "//absl:app",
    ],
)

# Test-order randomization, driven through its helper binary.
py_test(
    name = "tests/absltest_randomization_test",
    size = "medium",
    srcs = ["tests/absltest_randomization_test.py"],
    data = [":tests/absltest_randomization_testcase"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":_bazelize_command",
        ":absltest",
        ":parameterized",
        ":tests/absltest_env",
        "//absl/flags",
    ],
)

# Helper program used by absltest_randomization_test.
py_binary(
    name = "tests/absltest_randomization_testcase",
    testonly = 1,
    srcs = ["tests/absltest_randomization_testcase.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":absltest",
    ],
)

# Test sharding, driven through two helper binaries (with and without tests).
py_test(
    name = "tests/absltest_sharding_test",
    size = "small",
    srcs = ["tests/absltest_sharding_test.py"],
    data = [
        ":tests/absltest_sharding_test_helper",
        ":tests/absltest_sharding_test_helper_no_tests",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":_bazelize_command",
        ":absltest",
        ":parameterized",
        ":tests/absltest_env",
    ],
)

# Helper program used by absltest_sharding_test.
py_binary(
    name = "tests/absltest_sharding_test_helper",
    testonly = 1,
    srcs = ["tests/absltest_sharding_test_helper.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":absltest"],
)
# Helper program without test cases, used by absltest_sharding_test.
# python_version/srcs_version are set explicitly to match every sibling
# helper binary in this package.
py_binary(
    name = "tests/absltest_sharding_test_helper_no_tests",
    testonly = 1,
    srcs = ["tests/absltest_sharding_test_helper_no_tests.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":absltest"],
)
# Main absltest tests. Two helper binaries are data dependencies, one of
# them for the skipped-tests scenario.
py_test(
    name = "tests/absltest_test",
    size = "small",
    srcs = ["tests/absltest_test.py"],
    data = [
        ":tests/absltest_test_helper",
        ":tests/absltest_test_helper_skipped",
    ],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":_bazelize_command",
        ":absltest",
        ":parameterized",
        ":tests/absltest_env",
    ],
)

# Helper program used by absltest_test.
py_binary(
    name = "tests/absltest_test_helper",
    testonly = 1,
    srcs = ["tests/absltest_test_helper.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":absltest",
        "//absl:app",
        "//absl/flags",
    ],
)
# Helper program for the skipped-tests scenario in absltest_test.
# python_version/srcs_version are set explicitly to match every sibling
# helper binary in this package.
py_binary(
    name = "tests/absltest_test_helper_skipped",
    testonly = 1,
    srcs = ["tests/absltest_test_helper_skipped.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [":absltest"],
)
# Tests for flagsaver.py.
py_test(
    name = "tests/flagsaver_test",
    srcs = ["tests/flagsaver_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":absltest",
        ":flagsaver",
        ":parameterized",
        "//absl/flags",
    ],
)

# Tests for parameterized.py.
py_test(
    name = "tests/parameterized_test",
    srcs = ["tests/parameterized_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":absltest",
        ":parameterized",
    ],
)

# Tests for xml_reporter.py, driven partly through a helper binary.
py_test(
    name = "tests/xml_reporter_test",
    srcs = ["tests/xml_reporter_test.py"],
    data = [":tests/xml_reporter_helper_test"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":_bazelize_command",
        ":absltest",
        ":parameterized",
        ":xml_reporter",
        "//absl/logging",
    ],
)

# Helper program used by xml_reporter_test.
py_binary(
    name = "tests/xml_reporter_helper_test",
    testonly = 1,
    srcs = ["tests/xml_reporter_helper_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":absltest",
        "//absl/flags",
    ],
)
abseil-py-2.1.0/absl/testing/__init__.py 0000664 0000000 0000000 00000001110 14551576331 0020111 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
abseil-py-2.1.0/absl/testing/_bazelize_command.py 0000664 0000000 0000000 00000004376 14551576331 0022035 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal helper for running tests on Windows Bazel."""
import os
from absl import flags
FLAGS = flags.FLAGS
def get_executable_path(py_binary_name):
  """Returns the executable path of a py_binary.

  The py_binary must be listed in the data dependencies of the calling Bazel
  target.

  On Linux/macOS, the binary sits under the same root directory as this
  module's __file__. On Windows, Bazel builds an .exe file instead, and its
  real location has to be looked up in the MANIFEST file under
  --test_srcdir.

  Args:
    py_binary_name: string, the name of a py_binary that is in another Bazel
      target's data dependencies.

  Returns:
    The path of the executable for `py_binary_name`.

  Raises:
    RuntimeError: Raised when it cannot locate the executable path.
  """
  if os.name != 'nt':
    # NOTE: __file__ may be .py or .pyc, depending on how the module was
    # loaded and executed.
    # Strip one directory per dot in the package name, plus one for this
    # module itself, to reach the root directory.
    root_directory = __file__
    for _ in range(__name__.count('.') + 1):
      root_directory = os.path.dirname(root_directory)
    return os.path.join(root_directory, py_binary_name)

  exe_name = py_binary_name + '.exe'
  manifest_file = os.path.join(FLAGS.test_srcdir, 'MANIFEST')
  wanted_entry = '{}/{}'.format(os.environ['TEST_WORKSPACE'], exe_name)
  with open(manifest_file, 'r') as manifest_fd:
    for line in manifest_fd:
      # Each usable MANIFEST line is "<runfiles path> <actual path>".
      fields = line.strip().split(' ')
      if len(fields) == 2 and fields[0] == wanted_entry:
        return fields[1]
  raise RuntimeError(
      'Cannot locate executable path for {}, MANIFEST file: {}.'.format(
          exe_name, manifest_file))
abseil-py-2.1.0/absl/testing/_pretty_print_reporter.py 0000664 0000000 0000000 00000006104 14551576331 0023206 0 ustar 00root root 0000000 0000000 # Copyright 2018 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TestResult implementing default output for test execution status."""
import unittest
class TextTestResult(unittest.TextTestResult):
  """Test result that prints one bracketed status line per test event.

  The superclass is always constructed with verbosity 0, which suppresses its
  own per-test output. Status lines are written by this class instead, and
  only when the requested verbosity is positive.
  """

  def __init__(self, stream, descriptions, verbosity):
    # Silence the superclass's verbose output; it would clash with ours.
    super().__init__(stream, descriptions, 0)
    self._per_test_output = verbosity > 0

  def _print_status(self, tag, test):
    """Writes "[tag] test_id" and flushes, if per-test output is enabled."""
    if not self._per_test_output:
      return
    main_prefix = '__main__.'
    test_id = test.id()
    if test_id.startswith(main_prefix):
      test_id = test_id[len(main_prefix):]
    print(f'[{tag}] {test_id}', file=self.stream)
    self.stream.flush()

  def startTest(self, test):
    super().startTest(test)
    self._print_status(' RUN ', test)

  def addSuccess(self, test):
    super().addSuccess(test)
    self._print_status(' OK ', test)

  def addError(self, test, err):
    super().addError(test, err)
    self._print_status(' FAILED ', test)

  def addFailure(self, test, err):
    super().addFailure(test, err)
    self._print_status(' FAILED ', test)

  def addSkip(self, test, reason):
    super().addSkip(test, reason)
    self._print_status(' SKIPPED ', test)

  def addExpectedFailure(self, test, err):
    super().addExpectedFailure(test, err)
    # An expected failure counts as a pass in the status line.
    self._print_status(' OK ', test)

  def addUnexpectedSuccess(self, test):
    super().addUnexpectedSuccess(test)
    # An unexpected success counts as a failure in the status line.
    self._print_status(' FAILED ', test)
class TextTestRunner(unittest.TextTestRunner):
  """A test runner that produces formatted text results."""

  # Result class instantiated by _makeResult(). Subclasses may override it to
  # plug in their own result type.
  _TEST_RESULT_CLASS = TextTestResult

  # Set this to true at the class or instance level to run tests using a
  # debug-friendly method (e.g, one that doesn't catch exceptions and interacts
  # better with debuggers).
  # Usually this is set using --pdb_post_mortem.
  run_for_debugging = False

  def run(self, test):
    # type: (TestCase) -> TestResult
    """Runs `test`, bypassing result collection when run_for_debugging is set."""
    if self.run_for_debugging:
      return self._run_debug(test)
    return super(TextTestRunner, self).run(test)

  def _run_debug(self, test):
    # type: (TestCase) -> TestResult
    """Runs `test` via TestCase.debug(), letting exceptions propagate."""
    test.debug()
    # Return an empty result to indicate success.
    return self._makeResult()

  def _makeResult(self):
    # Honor _TEST_RESULT_CLASS rather than hard-coding TextTestResult, so the
    # class-level hook declared above actually takes effect.
    return self._TEST_RESULT_CLASS(
        self.stream, self.descriptions, self.verbosity)
abseil-py-2.1.0/absl/testing/absltest.py 0000664 0000000 0000000 00000305377 14551576331 0020221 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base functionality for Abseil Python tests.
This module contains base classes and high-level functions for Abseil-style
tests.
"""
from collections import abc
import contextlib
import dataclasses
import difflib
import enum
import errno
import faulthandler
import getpass
import inspect
import io
import itertools
import json
import os
import random
import re
import shlex
import shutil
import signal
import stat
import subprocess
import sys
import tempfile
import textwrap
import typing
from typing import Any, AnyStr, BinaryIO, Callable, ContextManager, IO, Iterator, List, Mapping, MutableMapping, MutableSequence, NoReturn, Optional, Sequence, Text, TextIO, Tuple, Type, Union
import unittest
from unittest import mock # pylint: disable=unused-import Allow absltest.mock.
from urllib import parse
from absl import app # pylint: disable=g-import-not-at-top
from absl import flags
from absl import logging
from absl.testing import _pretty_print_reporter
from absl.testing import xml_reporter
# Use an if-type-checking block to prevent leakage of type-checking only
# symbols. We don't want people relying on these at runtime.
if typing.TYPE_CHECKING:
# Unbounded TypeVar for general usage
_T = typing.TypeVar('_T')
import unittest.case # pylint: disable=g-import-not-at-top,g-bad-import-order
_OutcomeType = unittest.case._Outcome # pytype: disable=module-attr
# pylint: enable=g-import-not-at-top
# Re-export a bunch of unittest functions we support so that people don't
# have to import unittest to get them
# pylint: disable=invalid-name
skip = unittest.skip
skipIf = unittest.skipIf
skipUnless = unittest.skipUnless
SkipTest = unittest.SkipTest
expectedFailure = unittest.expectedFailure
# pylint: enable=invalid-name
# End unittest re-exports
# Module-level handle on absl's flag values.
FLAGS = flags.FLAGS
# The two string-like types (text and binary), presumably used in isinstance
# checks later in this module.
_TEXT_OR_BINARY_TYPES = (str, bytes)
# Suppress surplus entries in AssertionError stack traces.
__unittest = True  # pylint: disable=invalid-name
def expectedFailureIf(condition, reason):  # pylint: disable=invalid-name
  """Marks a test as an expected failure only when `condition` is true.

  Example usage::

    @expectedFailureIf(sys.version_info.major == 2, "Not yet working in py2")
    def test_foo(self):
      ...

  Args:
    condition: bool, whether the decorated test is expected to fail.
    reason: Text, a human-readable explanation. It documents intent only and
      is otherwise ignored.

  Returns:
    `unittest.expectedFailure` when `condition` is truthy; otherwise an
    identity decorator that returns the test function unchanged.
  """
  del reason  # Documentation only; unused at runtime.
  return unittest.expectedFailure if condition else (lambda f: f)
class TempFileCleanup(enum.Enum):
  """When to delete files/dirs made by create_tempfile()/create_tempdir()."""

  # Delete temp files once the test finishes, whatever its outcome.
  ALWAYS = 'always'
  # Delete temp files only if the test passed, so their contents can be
  # inspected after a failure. absltest.TEST_TMPDIR.value determines where
  # the temp files are created.
  SUCCESS = 'success'
  # Never delete temp files.
  OFF = 'never'
# Many of the methods in this module have names like assertSameElements.
# This kind of name does not comply with PEP8 style,
# but it is consistent with the naming of methods in unittest.py.
# pylint: disable=invalid-name
def _get_default_test_random_seed():
# type: () -> int
random_seed = 301
value = os.environ.get('TEST_RANDOM_SEED', '')
try:
random_seed = int(value)
except ValueError:
pass
return random_seed
def get_default_test_srcdir():
  # type: () -> Text
  """Returns the default test source dir: $TEST_SRCDIR, or '' when unset."""
  srcdir = os.environ.get('TEST_SRCDIR')
  return '' if srcdir is None else srcdir
def get_default_test_tmpdir():
  # type: () -> Text
  """Returns the default test temp dir.

  $TEST_TMPDIR is used when it is set and non-empty; otherwise the result is
  an 'absl_testing' directory inside the system temp directory.
  """
  return os.environ.get('TEST_TMPDIR', '') or os.path.join(
      tempfile.gettempdir(), 'absl_testing')
def _get_default_randomize_ordering_seed():
  # type: () -> int
  """Returns default seed to use for randomizing test order.

  The --test_randomize_ordering_seed flag wins when it was explicitly given;
  otherwise the TEST_RANDOMIZE_ORDERING_SEED environment variable is used.
  The chosen value is interpreted as:

  * (not set) or empty: disable test randomization
  * '0': disable test randomization
  * 'random': choose a random seed in [1, 4294967295] for test order
    randomization
  * positive integer: use this seed for test order randomization

  (The values used are patterned after
  https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED).

  In principle, it would be simpler to return None if no override is provided;
  however, the python random module has no `get_seed()`, only `getstate()`,
  which returns far more data than we want to pass via an environment variable
  or flag.

  Returns:
    A default value for test case randomization (int). 0 means do not
    randomize.

  Raises:
    ValueError: Raised when the flag or env value is not one of the options
      above.
  """
  seed_flag = FLAGS['test_randomize_ordering_seed']
  if seed_flag.present:
    setting = FLAGS.test_randomize_ordering_seed
  else:
    setting = os.environ.get('TEST_RANDOMIZE_ORDERING_SEED', '')

  if not setting or setting == '0':
    return 0
  if setting == 'random':
    return random.Random().randint(1, 4294967295)

  try:
    parsed = int(setting)
  except ValueError:
    parsed = 0  # Non-numeric: falls through to the error below.
  if parsed > 0:
    return parsed
  raise ValueError(
      'Unknown test randomization seed value: {}'.format(setting))
# Root of the test's source tree; defaults to $TEST_SRCDIR (see
# get_default_test_srcdir()).
TEST_SRCDIR = flags.DEFINE_string(
    'test_srcdir',
    get_default_test_srcdir(),
    'Root of directory tree where source files live',
    allow_override_cpp=True)
# Root for temporary test files. TestCase._get_tempdir_path_cls() builds the
# per-class/per-test directories used by create_tempdir()/create_tempfile()
# under TEST_TMPDIR.value. Defaults to get_default_test_tmpdir().
TEST_TMPDIR = flags.DEFINE_string(
    'test_tmpdir',
    get_default_test_tmpdir(),
    'Directory for temporary testing files',
    allow_override_cpp=True)
# Defaults to int($TEST_RANDOM_SEED), or 301 when that is unset/unparseable.
flags.DEFINE_integer(
    'test_random_seed',
    _get_default_test_random_seed(),
    'Random seed for testing. Some test frameworks may '
    'change the default value of this flag between runs, so '
    'it is not appropriate for seeding probabilistic tests.',
    allow_override_cpp=True)
# Interpreted by _get_default_randomize_ordering_seed(), which gives an
# explicitly passed value precedence over $TEST_RANDOMIZE_ORDERING_SEED.
flags.DEFINE_string(
    'test_randomize_ordering_seed',
    '',
    'If positive, use this as a seed to randomize the '
    'execution order for test cases. If "random", pick a '
    'random seed to use. If 0 or not set, do not randomize '
    'test case execution order. This flag also overrides '
    'the TEST_RANDOMIZE_ORDERING_SEED environment variable.',
    allow_override_cpp=True)
# Destination path for XML test results; empty by default.
flags.DEFINE_string('xml_output_file', '', 'File to store XML test results')
# We might need to monkey-patch TestResult so that it stops considering an
# unexpected pass as a "successful result". For details, see
# http://bugs.python.org/issue20165
def _monkey_patch_test_result_for_unexpected_passes():
  # type: () -> None
  """Workaround for http://bugs.python.org/issue20165.

  Probes `unittest.TestResult` with a single unexpected success. Only if that
  result still reports success (the bug is present) is `wasSuccessful`
  replaced by a version that counts unexpected passes as failures. A warning
  is written to stderr if the replacement does not take effect.
  """

  def wasSuccessful(self):
    # type: () -> bool
    """Tells whether or not this result was a success.

    Any unexpected pass is to be counted as a non-success.

    Args:
      self: The TestResult instance.

    Returns:
      Whether or not this result was a success.
    """
    return (len(self.failures) == 0 and len(self.errors) == 0 and
            len(self.unexpectedSuccesses) == 0)

  probe = unittest.TestResult()
  probe.addUnexpectedSuccess(unittest.FunctionTestCase(lambda: None))
  if not probe.wasSuccessful():
    return  # This interpreter already treats unexpected passes as failures.
  unittest.TestResult.wasSuccessful = wasSuccessful
  if probe.wasSuccessful():  # The hot-fix did not take effect; warn the user.
    sys.stderr.write('unittest.result.TestResult monkey patch to report'
                     ' unexpected passes as failures did not work.\n')


_monkey_patch_test_result_for_unexpected_passes()
def _open(filepath, mode, _open_func=open):
# type: (Text, Text, Callable[..., IO]) -> IO
"""Opens a file.
Like open(), but ensure that we can open real files even if tests stub out
open().
Args:
filepath: A filepath.
mode: A mode.
_open_func: A built-in open() function.
Returns:
The opened file object.
"""
return _open_func(filepath, mode, encoding='utf-8')
class _TempDir(object):
"""Represents a temporary directory for tests.
Creation of this class is internal. Using its public methods is OK.
This class implements the `os.PathLike` interface (specifically,
`os.PathLike[str]`). This means, in Python 3, it can be directly passed
to e.g. `os.path.join()`.
"""
def __init__(self, path):
# type: (Text) -> None
"""Module-private: do not instantiate outside module."""
self._path = path
@property
def full_path(self):
# type: () -> Text
"""Returns the path, as a string, for the directory.
TIP: Instead of e.g. `os.path.join(temp_dir.full_path)`, you can simply
do `os.path.join(temp_dir)` because `__fspath__()` is implemented.
"""
return self._path
def __fspath__(self):
# type: () -> Text
"""See os.PathLike."""
return self.full_path
def create_file(self, file_path=None, content=None, mode='w', encoding='utf8',
errors='strict'):
# type: (Optional[Text], Optional[AnyStr], Text, Text, Text) -> _TempFile
"""Create a file in the directory.
NOTE: If the file already exists, it will be made writable and overwritten.
Args:
file_path: Optional file path for the temp file. If not given, a unique
file name will be generated and used. Slashes are allowed in the name;
any missing intermediate directories will be created. NOTE: This path
is the path that will be cleaned up, including any directories in the
path, e.g., 'foo/bar/baz.txt' will `rm -r foo`
content: Optional string or bytes to initially write to the file. If not
specified, then an empty file is created.
mode: Mode string to use when writing content. Only used if `content` is
non-empty.
encoding: Encoding to use when writing string content. Only used if
`content` is text.
errors: How to handle text to bytes encoding errors. Only used if
`content` is text.
Returns:
A _TempFile representing the created file.
"""
tf, _ = _TempFile._create(self._path, file_path, content, mode, encoding,
errors)
return tf
def mkdir(self, dir_path=None):
# type: (Optional[Text]) -> _TempDir
"""Create a directory in the directory.
Args:
dir_path: Optional path to the directory to create. If not given,
a unique name will be generated and used.
Returns:
A _TempDir representing the created directory.
"""
if dir_path:
path = os.path.join(self._path, dir_path)
else:
path = tempfile.mkdtemp(dir=self._path)
# Note: there's no need to clear the directory since the containing
# dir was cleared by the tempdir() function.
os.makedirs(path, exist_ok=True)
return _TempDir(path)
class _TempFile(object):
"""Represents a tempfile for tests.
Creation of this class is internal. Using its public methods is OK.
This class implements the `os.PathLike` interface (specifically,
`os.PathLike[str]`). This means, in Python 3, it can be directly passed
to e.g. `os.path.join()`.
"""
def __init__(self, path):
# type: (Text) -> None
"""Private: use _create instead."""
self._path = path
# pylint: disable=line-too-long
@classmethod
def _create(cls, base_path, file_path, content, mode, encoding, errors):
# type: (Text, Optional[Text], AnyStr, Text, Text, Text) -> Tuple[_TempFile, Text]
# pylint: enable=line-too-long
"""Module-private: create a tempfile instance."""
if file_path:
cleanup_path = os.path.join(base_path, _get_first_part(file_path))
path = os.path.join(base_path, file_path)
os.makedirs(os.path.dirname(path), exist_ok=True)
# The file may already exist, in which case, ensure it's writable so that
# it can be truncated.
if os.path.exists(path) and not os.access(path, os.W_OK):
stat_info = os.stat(path)
os.chmod(path, stat_info.st_mode | stat.S_IWUSR)
else:
os.makedirs(base_path, exist_ok=True)
fd, path = tempfile.mkstemp(dir=str(base_path))
os.close(fd)
cleanup_path = path
tf = cls(path)
if content:
if isinstance(content, str):
tf.write_text(content, mode=mode, encoding=encoding, errors=errors)
else:
tf.write_bytes(content, mode)
else:
tf.write_bytes(b'')
return tf, cleanup_path
@property
def full_path(self):
# type: () -> Text
"""Returns the path, as a string, for the file.
TIP: Instead of e.g. `os.path.join(temp_file.full_path)`, you can simply
do `os.path.join(temp_file)` because `__fspath__()` is implemented.
"""
return self._path
def __fspath__(self):
# type: () -> Text
"""See os.PathLike."""
return self.full_path
def read_text(self, encoding='utf8', errors='strict'):
# type: (Text, Text) -> Text
"""Return the contents of the file as text."""
with self.open_text(encoding=encoding, errors=errors) as fp:
return fp.read()
def read_bytes(self):
# type: () -> bytes
"""Return the content of the file as bytes."""
with self.open_bytes() as fp:
return fp.read()
def write_text(self, text, mode='w', encoding='utf8', errors='strict'):
# type: (Text, Text, Text, Text) -> None
"""Write text to the file.
Args:
text: Text to write. In Python 2, it can be bytes, which will be
decoded using the `encoding` arg (this is as an aid for code that
is 2 and 3 compatible).
mode: The mode to open the file for writing.
encoding: The encoding to use when writing the text to the file.
errors: The error handling strategy to use when converting text to bytes.
"""
with self.open_text(mode, encoding=encoding, errors=errors) as fp:
fp.write(text)
def write_bytes(self, data, mode='wb'):
# type: (bytes, Text) -> None
"""Write bytes to the file.
Args:
data: bytes to write.
mode: Mode to open the file for writing. The "b" flag is implicit if
not already present. It must not have the "t" flag.
"""
with self.open_bytes(mode) as fp:
fp.write(data)
def open_text(self, mode='rt', encoding='utf8', errors='strict'):
# type: (Text, Text, Text) -> ContextManager[TextIO]
"""Return a context manager for opening the file in text mode.
Args:
mode: The mode to open the file in. The "t" flag is implicit if not
already present. It must not have the "b" flag.
encoding: The encoding to use when opening the file.
errors: How to handle decoding errors.
Returns:
Context manager that yields an open file.
Raises:
ValueError: if invalid inputs are provided.
"""
if 'b' in mode:
raise ValueError('Invalid mode {!r}: "b" flag not allowed when opening '
'file in text mode'.format(mode))
if 't' not in mode:
mode += 't'
cm = self._open(mode, encoding, errors)
return cm
def open_bytes(self, mode='rb'):
# type: (Text) -> ContextManager[BinaryIO]
"""Return a context manager for opening the file in binary mode.
Args:
mode: The mode to open the file in. The "b" mode is implicit if not
already present. It must not have the "t" flag.
Returns:
Context manager that yields an open file.
Raises:
ValueError: if invalid inputs are provided.
"""
if 't' in mode:
raise ValueError('Invalid mode {!r}: "t" flag not allowed when opening '
'file in binary mode'.format(mode))
if 'b' not in mode:
mode += 'b'
cm = self._open(mode, encoding=None, errors=None)
return cm
# TODO(b/123775699): Once pytype supports typing.Literal, use overload and
# Literal to express more precise return types. The contained type is
# currently `Any` to avoid [bad-return-type] errors in the open_* methods.
@contextlib.contextmanager
def _open(
self,
mode: str,
encoding: Optional[str] = 'utf8',
errors: Optional[str] = 'strict',
) -> Iterator[Any]:
with io.open(
self.full_path, mode=mode, encoding=encoding, errors=errors) as fp:
yield fp
class _method(object):
"""A decorator that supports both instance and classmethod invocations.
Using similar semantics to the @property builtin, this decorator can augment
an instance method to support conditional logic when invoked on a class
object. This breaks support for invoking an instance method via the class
(e.g. Cls.method(self, ...)) but is still situationally useful.
"""
def __init__(self, finstancemethod):
# type: (Callable[..., Any]) -> None
self._finstancemethod = finstancemethod
self._fclassmethod = None
def classmethod(self, fclassmethod):
# type: (Callable[..., Any]) -> _method
self._fclassmethod = classmethod(fclassmethod)
return self
def __doc__(self):
# type: () -> str
if getattr(self._finstancemethod, '__doc__'):
return self._finstancemethod.__doc__
elif getattr(self._fclassmethod, '__doc__'):
return self._fclassmethod.__doc__
return ''
def __get__(self, obj, type_):
# type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any]
func = self._fclassmethod if obj is None else self._finstancemethod
return func.__get__(obj, type_) # pytype: disable=attribute-error
class TestCase(unittest.TestCase):
"""Extension of unittest.TestCase providing more power."""
# When to cleanup files/directories created by our `create_tempfile()` and
# `create_tempdir()` methods after each test case completes. This does *not*
# affect e.g., files created outside of those methods, e.g., using the stdlib
# tempfile module. This can be overridden at the class level, instance level,
# or with the `cleanup` arg of `create_tempfile()` and `create_tempdir()`. See
# `TempFileCleanup` for details on the different values.
# TODO(b/70517332): Remove the type comment and the disable once pytype has
# better support for enums.
tempfile_cleanup = TempFileCleanup.ALWAYS # type: TempFileCleanup # pytype: disable=annotation-type-mismatch
maxDiff = 80 * 20
longMessage = True
# Exit stacks for per-test and per-class scopes.
if sys.version_info < (3, 11):
_exit_stack = None
_cls_exit_stack = None
def __init__(self, *args, **kwargs):
  """Initializes the test case; all arguments go to unittest.TestCase."""
  super(TestCase, self).__init__(*args, **kwargs)
  # Re-read `_outcome` through getattr() only so that type checkers get an
  # annotation for it; unittest.pyi does not declare the attribute.
  current_outcome = getattr(self, '_outcome')  # type: Optional[_OutcomeType]
  self._outcome = current_outcome
def setUp(self):
  super(TestCase, self).setUp()
  # Python 3.11+ ships TestCase.enterContext, so the per-test ExitStack that
  # backs `self.enter_context` is only created on older versions. (The
  # ExitStack check dates from when Python 2's contextlib lacked it.)
  needs_exit_stack = (
      sys.version_info < (3, 11) and hasattr(contextlib, 'ExitStack'))
  if needs_exit_stack:
    self._exit_stack = contextlib.ExitStack()
    self.addCleanup(self._exit_stack.close)
@classmethod
def setUpClass(cls):
  super(TestCase, cls).setUpClass()
  # Python 3.11+ ships enterClassContext natively. Older versions get a
  # per-class ExitStack backing `cls.enter_context`, which needs both
  # contextlib.ExitStack (Python 3) and addClassCleanup (Python 3.8+).
  if sys.version_info >= (3, 11):
    return
  if not hasattr(contextlib, 'ExitStack'):
    return
  if not hasattr(cls, 'addClassCleanup'):
    return
  cls._cls_exit_stack = contextlib.ExitStack()
  cls.addClassCleanup(cls._cls_exit_stack.close)
def create_tempdir(self, name=None, cleanup=None):
  # type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir
  """Create a temporary directory specific to the test.

  NOTE: The directory and its contents will be recursively cleared before
  creation. This ensures that there is no pre-existing state.

  This creates a named directory on disk that is isolated to this test, and
  will be properly cleaned up by the test. This avoids several pitfalls of
  creating temporary directories for test purposes, as well as makes it easier
  to setup directories and verify their contents. For example::

      def test_foo(self):
        out_dir = self.create_tempdir()
        out_log = out_dir.create_file('output.log')
        expected_outputs = [
            os.path.join(out_dir, 'data-0.txt'),
            os.path.join(out_dir, 'data-1.txt'),
        ]
        code_under_test(out_dir)
        self.assertTrue(os.path.exists(expected_outputs[0]))
        self.assertTrue(os.path.exists(expected_outputs[1]))
        self.assertEqual('foo', out_log.read_text())

  See also: :meth:`create_tempfile` for creating temporary files.

  Args:
    name: Optional name of the directory. If not given, a unique
      name will be generated and used.
    cleanup: Optional cleanup policy on when/if to remove the directory (and
      all its contents) at the end of the test. If None, then uses
      :attr:`tempfile_cleanup`.

  Returns:
    A _TempDir representing the created directory; see _TempDir class docs
    for usage.
  """
  test_path = self._get_tempdir_path_test()
  if not name:
    os.makedirs(test_path, exist_ok=True)
    path = tempfile.mkdtemp(dir=test_path)
    cleanup_path = path
  else:
    path = os.path.join(test_path, name)
    cleanup_path = os.path.join(test_path, _get_first_part(name))

  # Start from an empty directory, discarding anything left behind earlier.
  _rmtree_ignore_errors(cleanup_path)
  os.makedirs(path, exist_ok=True)
  self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
  return _TempDir(path)
# pylint: disable=line-too-long
def create_tempfile(self, file_path=None, content=None, mode='w',
                    encoding='utf8', errors='strict', cleanup=None):
  # type: (Optional[Text], Optional[AnyStr], Text, Text, Text, Optional[TempFileCleanup]) -> _TempFile
  # pylint: enable=line-too-long
  """Create a temporary file specific to the test.

  This creates a named file on disk that is isolated to this test, and will
  be properly cleaned up by the test. This avoids several pitfalls of
  creating temporary files for test purposes, as well as makes it easier
  to setup files, their data, read them back, and inspect them when
  a test fails. For example::

      def test_foo(self):
        output = self.create_tempfile()
        code_under_test(output)
        self.assertGreater(os.path.getsize(output), 0)
        self.assertEqual('foo', output.read_text())

  NOTE: This will zero-out the file. This ensures there is no pre-existing
  state.
  NOTE: If the file already exists, it will be made writable and overwritten.

  See also: :meth:`create_tempdir` for creating temporary directories, and
  ``_TempDir.create_file`` for creating files within a temporary directory.

  Args:
    file_path: Optional file path for the temp file. If not given, a unique
      file name will be generated and used. Slashes are allowed in the name;
      any missing intermediate directories will be created. NOTE: This path is
      the path that will be cleaned up, including any directories in the path,
      e.g., ``'foo/bar/baz.txt'`` will ``rm -r foo``.
    content: Optional string or bytes to initially write to the file. If not
      specified, then an empty file is created.
    mode: Mode string to use when writing content. Only used if `content` is
      non-empty.
    encoding: Encoding to use when writing string content. Only used if
      `content` is text.
    errors: How to handle text to bytes encoding errors. Only used if
      `content` is text.
    cleanup: Optional cleanup policy on when/if to remove the file (and any
      directories created for `file_path`) at the end of the test. If None,
      then uses :attr:`tempfile_cleanup`.

  Returns:
    A _TempFile representing the created file; see _TempFile class docs for
    usage.
  """
  temp_file, cleanup_path = _TempFile._create(
      self._get_tempdir_path_test(),
      file_path,
      content=content,
      mode=mode,
      encoding=encoding,
      errors=errors,
  )
  self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
  return temp_file
@_method
def enter_context(self, manager):
  # type: (ContextManager[_T]) -> _T
  """Returns the CM's value after registering it with the exit stack.

  Entering a context pushes it onto a stack of contexts. When `enter_context`
  is called on the test instance (e.g. `self.enter_context`), the context is
  exited after the test case's tearDown call. When called on the test class
  (e.g. `TestCase.enter_context`), the context is exited after the test
  class's tearDownClass call.

  Contexts are exited in the reverse order of entering. They will always
  be exited, regardless of test failure/success.

  This is useful to eliminate per-test boilerplate when context managers
  are used. For example, instead of decorating every test with `@mock.patch`,
  simply do `self.foo = self.enter_context(mock.patch(...))` in `setUp()`.

  NOTE: The context managers will always be exited without any error
  information. This is an unfortunate implementation detail due to some
  internals of how unittest runs tests.

  Args:
    manager: The context manager to enter.
  """
  if sys.version_info < (3, 11):
    exit_stack = self._exit_stack
    if not exit_stack:
      raise AssertionError(
          'self._exit_stack is not set: enter_context is Py3-only; also make '
          'sure that AbslTest.setUp() is called.')
    return exit_stack.enter_context(manager)
  return self.enterContext(manager)

@enter_context.classmethod
def enter_context(cls, manager):  # pylint: disable=no-self-argument
  # type: (ContextManager[_T]) -> _T
  """Class-level variant: exits `manager` after tearDownClass."""
  if sys.version_info < (3, 11):
    cls_exit_stack = cls._cls_exit_stack
    if not cls_exit_stack:
      raise AssertionError(
          'cls._cls_exit_stack is not set: cls.enter_context requires '
          'Python 3.8+; also make sure that AbslTest.setUpClass() is called.')
    return cls_exit_stack.enter_context(manager)
  return cls.enterClassContext(manager)
@classmethod
def _get_tempdir_path_cls(cls):
  # type: () -> Text
  """Returns this class's temp dir: TEST_TMPDIR/<class qualified name>."""
  class_dir = cls.__qualname__.replace('__main__.', '')
  return os.path.join(TEST_TMPDIR.value, class_dir)
def _get_tempdir_path_test(self):
# type: () -> Text
return os.path.join(self._get_tempdir_path_cls(), self._testMethodName)
def _get_tempfile_cleanup(self, override):
# type: (Optional[TempFileCleanup]) -> TempFileCleanup
if override is not None:
return override
return self.tempfile_cleanup
def _maybe_add_temp_path_cleanup(self, path, cleanup):
  # type: (Text, Optional[TempFileCleanup]) -> None
  """Schedules removal of `path` according to the effective cleanup policy.

  Args:
    path: File or directory to remove with _rmtree_ignore_errors.
    cleanup: Explicit policy, or None to use `self.tempfile_cleanup`.

  Raises:
    AssertionError: if the effective policy is none of ALWAYS, SUCCESS or OFF.
  """
  policy = self._get_tempfile_cleanup(cleanup)
  if policy == TempFileCleanup.ALWAYS:
    self.addCleanup(_rmtree_ignore_errors, path)
  elif policy == TempFileCleanup.SUCCESS:
    self._internal_add_cleanup_on_success(_rmtree_ignore_errors, path)
  elif policy != TempFileCleanup.OFF:
    raise AssertionError('Unexpected cleanup value: {}'.format(policy))
def _internal_add_cleanup_on_success(
self,
function: Callable[..., Any],
*args: Any,
**kwargs: Any,
) -> None:
"""Adds `function` as cleanup when the test case succeeds."""
outcome = self._outcome
assert outcome is not None
previous_failure_count = (
len(outcome.result.failures)
+ len(outcome.result.errors)
+ len(outcome.result.unexpectedSuccesses)
)
def _call_cleaner_on_success(*args, **kwargs):
if not self._internal_ran_and_passed_when_called_during_cleanup(
previous_failure_count):
return
function(*args, **kwargs)
self.addCleanup(_call_cleaner_on_success, *args, **kwargs)
def _internal_ran_and_passed_when_called_during_cleanup(
    self,
    previous_failure_count: int,
) -> bool:
  """Returns whether test is passed. Expected to be called during cleanup.

  Args:
    previous_failure_count: failures + errors + unexpected successes recorded
      on the result before the test ran. Only used on Python 3.11+.
  """
  outcome = self._outcome
  if sys.version_info[:2] < (3, 11):
    # Before Python 3.11 (https://github.com/python/cpython/pull/28180),
    # errors were buffered in _Outcome before cleanups were called, so replay
    # them into a scratch result and ask it instead.
    scratch_result = self.defaultTestResult()
    self._feedErrorsToResult(scratch_result, outcome.errors)  # pytype: disable=attribute-error
    return scratch_result.wasSuccessful()
  assert outcome is not None
  result = outcome.result
  failures_now = (
      len(result.failures)
      + len(result.errors)
      + len(result.unexpectedSuccesses)
  )
  return failures_now == previous_failure_count
def shortDescription(self):
  # type: () -> Text
  """Formats both the test method name and the first line of its docstring.

  If no docstring is given, only returns the method name.

  This method overrides unittest.TestCase.shortDescription(), which
  only returns the first line of the docstring, obscuring the name
  of the test upon failure.

  Returns:
    desc: A short description of a test method.
  """
  main_prefix = '__main__.'
  test_id = self.id()
  # Drop the main module name so the test name can be copy/pasted directly
  # onto the command line.
  if test_id.startswith(main_prefix):
    test_id = test_id[len(main_prefix):]

  # NOTE: super() is used here instead of directly invoking
  # unittest.TestCase.shortDescription(self), because of the
  # following line that occurs later on:
  #       unittest.TestCase = TestCase
  # Because of this, direct invocation of what we think is the
  # superclass will actually cause infinite recursion.
  first_doc_line = super(TestCase, self).shortDescription()
  if first_doc_line is None:
    return test_id
  return '\n'.join((test_id, first_doc_line))
def assertStartsWith(self, actual, expected_start, msg=None):
  """Asserts that `actual` begins with `expected_start`.

  Args:
    actual: str, the value under test.
    expected_start: str, the prefix `actual` must begin with.
    msg: Optional message to report on failure.
  """
  if actual.startswith(expected_start):
    return
  self.fail('%r does not start with %r' % (actual, expected_start), msg)
def assertNotStartsWith(self, actual, unexpected_start, msg=None):
  """Asserts that `actual` does not begin with `unexpected_start`.

  Args:
    actual: str, the value under test.
    unexpected_start: str, a prefix `actual` must not begin with.
    msg: Optional message to report on failure.
  """
  if not actual.startswith(unexpected_start):
    return
  self.fail('%r does start with %r' % (actual, unexpected_start), msg)
def assertEndsWith(self, actual, expected_end, msg=None):
  """Asserts that `actual` finishes with `expected_end`.

  Args:
    actual: str, the value under test.
    expected_end: str, the suffix `actual` must end with.
    msg: Optional message to report on failure.
  """
  if actual.endswith(expected_end):
    return
  self.fail('%r does not end with %r' % (actual, expected_end), msg)
def assertNotEndsWith(self, actual, unexpected_end, msg=None):
  """Asserts that `actual` does not finish with `unexpected_end`.

  Args:
    actual: str, the value under test.
    unexpected_end: str, a suffix `actual` must not end with.
    msg: Optional message to report on failure.
  """
  if not actual.endswith(unexpected_end):
    return
  self.fail('%r does end with %r' % (actual, unexpected_end), msg)
def assertSequenceStartsWith(self, prefix, whole, msg=None):
  """An equality assertion for the beginning of ordered sequences.

  If prefix is an empty sequence, it will raise an error unless whole is also
  an empty sequence.

  If prefix is not a sequence, it will raise an error if the first element of
  whole does not match.

  Args:
    prefix: A sequence expected at the beginning of the whole parameter.
    whole: The sequence in which to look for prefix.
    msg: Optional message to report on failure.
  """
  try:
    prefix_len = len(prefix)
  except (TypeError, NotImplementedError):
    # A scalar prefix is treated as a one-element sequence.
    prefix = [prefix]
    prefix_len = 1

  if isinstance(whole, (abc.Mapping, abc.Set)):
    self.fail(
        'For whole: Mapping or Set objects are not supported, found type: %s'
        % type(whole),
        msg,
    )

  try:
    whole_len = len(whole)
  except (TypeError, NotImplementedError):
    self.fail('For whole: len(%s) is not supported, it appears to be type: '
              '%s' % (whole, type(whole)), msg)

  # Report through self.fail() rather than a bare `assert`: assert statements
  # are stripped under `python -O`, and every other check here goes through
  # the test case's failure path.
  if prefix_len > whole_len:
    self.fail(
        'Prefix length (%d) is longer than whole length (%d).' %
        (prefix_len, whole_len), msg)

  if not prefix_len and whole_len:
    self.fail('Prefix length is 0 but whole length is %d: %s' %
              (whole_len, whole), msg)

  try:
    self.assertSequenceEqual(prefix, whole[:prefix_len], msg)
  except AssertionError:
    self.fail('prefix: %s not found at start of whole: %s.' %
              (prefix, whole), msg)
def assertEmpty(self, container, msg=None):
  """Asserts that `container` has zero length.

  Args:
    container: Anything that implements the collections.abc.Sized interface.
    msg: Optional message to report on failure.
  """
  if not isinstance(container, abc.Sized):
    self.fail('Expected a Sized object, got: '
              '{!r}'.format(type(container).__name__), msg)
  # Compare len() explicitly: some Sized objects (e.g. numpy.ndarray) have
  # unusual __bool__ behavior.
  length = len(container)
  if length != 0:
    self.fail('{!r} has length of {}.'.format(container, length), msg)
def assertNotEmpty(self, container, msg=None):
  """Asserts that `container` has non-zero length.

  Args:
    container: Anything that implements the collections.abc.Sized interface.
    msg: Optional message to report on failure.
  """
  if not isinstance(container, abc.Sized):
    self.fail('Expected a Sized object, got: '
              '{!r}'.format(type(container).__name__), msg)
  # Compare len() explicitly: some Sized objects (e.g. numpy.ndarray) have
  # unusual __bool__ behavior.
  if len(container) == 0:  # pylint: disable=g-explicit-length-test
    self.fail('{!r} has length of 0.'.format(container), msg)
def assertLen(self, container, expected_len, msg=None):
  """Asserts that `container` has exactly `expected_len` elements.

  Args:
    container: Anything that implements the collections.abc.Sized interface.
    expected_len: The expected length of the container.
    msg: Optional message to report on failure.
  """
  if not isinstance(container, abc.Sized):
    self.fail('Expected a Sized object, got: '
              '{!r}'.format(type(container).__name__), msg)
  actual_len = len(container)
  if actual_len != expected_len:
    container_repr = unittest.util.safe_repr(container)  # pytype: disable=module-attr
    self.fail('{} has length of {}, expected {}.'.format(
        container_repr, actual_len, expected_len), msg)
def assertSequenceAlmostEqual(self, expected_seq, actual_seq, places=None,
                              msg=None, delta=None):
  """An approximate equality assertion for ordered sequences.

  Fail if the two sequences are unequal as determined by their value
  differences rounded to the given number of decimal places (default 7) and
  comparing to zero, or by comparing that the difference between each value
  in the two sequences is more than the given delta.

  Note that decimal places (from zero) are usually not the same as significant
  digits (measured from the most significant digit).

  If the two sequences compare equal then they will automatically compare
  almost equal.

  Args:
    expected_seq: A sequence containing elements we are expecting.
    actual_seq: The sequence that we are testing.
    places: The number of decimal places to compare.
    msg: The message to be printed if the test fails.
    delta: The OK difference between compared values.
  """
  if len(expected_seq) != len(actual_seq):
    self.fail('Sequence size mismatch: {} vs {}'.format(
        len(expected_seq), len(actual_seq)), msg)

  mismatches = []
  for index, pair in enumerate(zip(expected_seq, actual_seq)):
    expected_item, actual_item = pair
    try:
      # assertAlmostEqual should be called with at most one of `places` and
      # `delta`. However, it's okay for assertSequenceAlmostEqual to pass
      # both because we want the latter to fail if the former does.
      # pytype: disable=wrong-keyword-args
      self.assertAlmostEqual(expected_item, actual_item, places=places,
                             msg=msg, delta=delta)
      # pytype: enable=wrong-keyword-args
    except self.failureException as err:
      mismatches.append('At index {}: {}'.format(index, err))

  if not mismatches:
    return
  # Keep the report readable: show at most 30 mismatches.
  if len(mismatches) > 30:
    mismatches = mismatches[:30] + ['...']
  self.fail(self._formatMessage(msg, '\n'.join(mismatches)))
def assertContainsSubset(self, expected_subset, actual_set, msg=None):
  """Checks whether actual iterable is a superset of expected iterable.

  Both arguments are compared as sets, so their elements must be hashable.
  """
  missing = set(expected_subset).difference(actual_set)
  if missing:
    self.fail('Missing elements %s\nExpected: %s\nActual: %s' % (
        missing, expected_subset, actual_set), msg)
def assertNoCommonElements(self, expected_seq, actual_seq, msg=None):
  """Checks whether actual iterable and expected iterable are disjoint.

  Both arguments are compared as sets, so their elements must be hashable.
  """
  common = set(expected_seq).intersection(actual_seq)
  if common:
    self.fail('Common elements %s\nExpected: %s\nActual: %s' % (
        common, expected_seq, actual_seq), msg)
def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
  """Deprecated alias of assertCountEqual; use that instead.

  Args:
    expected_seq: A sequence containing elements we are expecting.
    actual_seq: The sequence that we are testing.
    msg: The message to be printed if the test fails.
  """
  count_equal = super().assertCountEqual
  count_equal(expected_seq, actual_seq, msg)
def assertSameElements(self, expected_seq, actual_seq, msg=None):
  """Asserts that two sequences have the same elements (in any order).

  This method, unlike assertCountEqual, doesn't care about any
  duplicates in the expected and actual sequences::

    # Doesn't raise an AssertionError
    assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])

  If possible, you should use assertCountEqual instead of
  assertSameElements.

  Args:
    expected_seq: A sequence containing elements we are expecting.
    actual_seq: The sequence that we are testing.
    msg: The message to be printed if the test fails.
  """
  # `unittest2.TestCase` used to have assertSameElements, but it was
  # removed in favor of assertItemsEqual. As there's a unit test
  # that explicitly checks this behavior, this method is kept as is.

  # Fail on strings: empirically, passing strings to this test method
  # is almost always a bug. If comparing the character sets of two strings
  # is desired, cast the inputs to sets or lists explicitly.
  if (isinstance(expected_seq, _TEXT_OR_BINARY_TYPES) or
      isinstance(actual_seq, _TEXT_OR_BINARY_TYPES)):
    self.fail('Passing string/bytes to assertSameElements is usually a bug. '
              'Did you mean to use assertEqual?\n'
              'Expected: %s\nActual: %s' % (expected_seq, actual_seq))

  def _sorted_if_possible(items):
    """Sorts items for a stable report; keeps first-seen order otherwise."""
    try:
      return sorted(items)
    except TypeError:
      return items

  try:
    # dict.fromkeys de-duplicates while keeping first-seen order, which is
    # the report order when the elements cannot be sorted.
    expected = dict.fromkeys(expected_seq)
    actual = dict.fromkeys(actual_seq)
    missing = [element for element in expected if element not in actual]
    unexpected = [element for element in actual if element not in expected]
  except TypeError:
    # Fall back to slower list-compare if any of the objects are
    # not hashable.
    expected = sorted(expected_seq)
    actual = sorted(actual_seq)
    missing, unexpected = _sorted_list_difference(expected, actual)
  else:
    # Sorting only makes the report deterministic. Hashable but mutually
    # unorderable elements (e.g. 'a' and 2j) used to escape as a TypeError
    # here instead of producing an assertion failure.
    missing = _sorted_if_possible(missing)
    unexpected = _sorted_if_possible(unexpected)

  errors = []
  if msg:
    errors.extend((msg, ':\n'))
  if missing:
    errors.append('Expected, but missing:\n %r\n' % missing)
  if unexpected:
    errors.append('Unexpected, but present:\n %r\n' % unexpected)
  if missing or unexpected:
    self.fail(''.join(errors))
# unittest.TestCase.assertMultiLineEqual works very similarly, but it
# produces a different error format; this ndiff-based one is slightly more
# readable.
def assertMultiLineEqual(self, first, second, msg=None, **kwargs):
  """Asserts that two multi-line strings are equal.

  On failure, the message is an ``difflib.ndiff`` line diff of the inputs.

  Args:
    first: The first string.
    second: The second string.
    msg: Optional message placed before the diff on failure.
    **kwargs: Only ``line_limit`` is accepted. If non-zero, at most that many
      diff lines are included in the failure message, followed by a note
      giving the number of omitted lines.

  Raises:
    TypeError: If an unexpected keyword argument is passed.
    self.failureException: If the strings differ.
  """
  # Plain asserts, so these type checks are skipped under `python -O`.
  assert isinstance(first,
                    str), ('First argument is not a string: %r' % (first,))
  assert isinstance(second,
                    str), ('Second argument is not a string: %r' % (second,))
  line_limit = kwargs.pop('line_limit', 0)
  if kwargs:
    raise TypeError('Unexpected keyword args {}'.format(tuple(kwargs)))

  if first == second:
    return

  if msg:
    failure_message = [msg + ':\n']
  else:
    failure_message = ['\n']
  if line_limit:
    # The header entry does not count towards the user's line limit.
    line_limit += len(failure_message)
  for line in difflib.ndiff(first.splitlines(True), second.splitlines(True)):
    # Keep exactly one list entry per diff line. Then `line_limit` and the
    # omitted count are measured in lines, and truncation never separates a
    # line from its terminating newline.
    if not line.endswith('\n'):
      line += '\n'
    failure_message.append(line)
  if line_limit and len(failure_message) > line_limit:
    n_omitted = len(failure_message) - line_limit
    failure_message = failure_message[:line_limit]
    failure_message.append(
        '(... and {} more delta lines omitted for brevity.)\n'.format(
            n_omitted))

  raise self.failureException(''.join(failure_message))
def assertBetween(self, value, minv, maxv, msg=None):
  """Asserts that minv <= value <= maxv (both bounds inclusive).

  Args:
    value: The value to check.
    minv: Inclusive lower bound.
    maxv: Inclusive upper bound.
    msg: Optional message to report on failure.
  """
  detail = '"%r" unexpectedly not between "%r" and "%r"' % (value, minv, maxv)
  combined = self._formatMessage(msg, detail)
  # The lower bound is checked first, then the upper bound.
  self.assertTrue(minv <= value, combined)
  self.assertTrue(maxv >= value, combined)
def assertRegexMatch(self, actual_str, regexes, message=None):
  r"""Asserts that at least one regex in regexes matches str.

  If possible you should use `assertRegex`, which is a simpler
  version of this method. `assertRegex` takes a single regular
  expression (a string or re compiled object) instead of a list.

  Notes:
  1. This function uses substring matching, i.e. the matching
     succeeds if *any* substring of the error message matches *any*
     regex in the list. This is more convenient for the user than
     full-string matching.

  2. If regexes is the empty list, the matching will always fail.

  3. Use regexes=[''] for a regex that will always pass.

  4. '.' matches any single character *except* the newline. To
     match any character, use '(.|\n)'.

  5. '^' matches the beginning of each line, not just the beginning
     of the string. Similarly, '$' matches the end of each line.

  6. An exception will be thrown if regexes contains an invalid
     regex.

  7. All regexes must share one type (str or bytes). They are
     encoded/decoded as UTF-8 to match the type of actual_str.

  Args:
    actual_str: The string we try to match with the items in regexes.
    regexes: The regular expressions we want to match against str.
      See "Notes" above for detailed notes on how this is interpreted.
    message: The message to be printed if the test fails.
  """
  if isinstance(regexes, _TEXT_OR_BINARY_TYPES):
    self.fail('regexes is string or bytes; use assertRegex instead.',
              message)
  if not regexes:
    self.fail('No regexes specified.', message)

  pattern_type = type(regexes[0])
  if any(type(r) is not pattern_type for r in regexes[1:]):  # pylint: disable=unidiomatic-typecheck
    self.fail('regexes list must all be the same type.', message)

  # Bring the patterns to the same str/bytes type as the subject.
  if pattern_type is bytes and isinstance(actual_str, str):
    regexes = [r.decode('utf-8') for r in regexes]
    pattern_type = str
  elif pattern_type is str and isinstance(actual_str, bytes):
    regexes = [r.encode('utf-8') for r in regexes]
    pattern_type = bytes

  # Combine all alternatives into a single non-capturing alternation.
  if pattern_type is str:
    combined = '(?:%s)' % ')|(?:'.join(regexes)
  elif pattern_type is bytes:
    combined = b'(?:' + (b')|(?:'.join(regexes)) + b')'
  else:
    self.fail('Only know how to deal with unicode str or bytes regexes.',
              message)

  if re.search(combined, actual_str, re.MULTILINE) is None:
    self.fail('"%s" does not contain any of these regexes: %s.' %
              (actual_str, regexes), message)
def assertCommandSucceeds(self, command, regexes=(b'',), env=None,
                          close_fds=True, msg=None):
  """Asserts that a shell command succeeds (i.e. exits with code 0).

  Args:
    command: List or string representing the command to run.
    regexes: List of regular expression byte strings that match success.
      `str` entries are UTF-8 encoded before matching. Pass a list even for
      a single pattern: a bare string is rejected by assertRegexMatch.
    env: Dictionary of environment variable settings. If None, no environment
      variables will be set for the child process. This is to make tests
      more hermetic. NOTE: this behavior is different than the standard
      subprocess module.
    close_fds: Whether or not to close all open fd's in the child after
      forking.
    msg: Optional message to report on failure.
  """
  (ret_code, err) = get_command_stderr(command, env, close_fds)

  # We need bytes regexes here because `err` is bytes. Accommodate code which
  # listed its output regexes w/o the b'' prefix by converting them to bytes
  # for the user. Conversion is element-wise, so empty lists (previously an
  # IndexError) and mixed str/bytes lists (previously an AttributeError) work.
  # A bare str/bytes is left untouched so that assertRegexMatch reports it,
  # instead of it being split into one-character patterns.
  if (not isinstance(regexes, (str, bytes)) and
      any(isinstance(regex, str) for regex in regexes)):
    regexes = [regex.encode('utf-8') if isinstance(regex, str) else regex
               for regex in regexes]

  command_string = get_command_string(command)
  self.assertEqual(
      ret_code, 0,
      self._formatMessage(msg,
                          'Running command\n'
                          '%s failed with error code %s and message\n'
                          '%s' % (_quote_long_string(command_string),
                                  ret_code,
                                  _quote_long_string(err)))
  )
  self.assertRegexMatch(
      err,
      regexes,
      message=self._formatMessage(
          msg,
          'Running command\n'
          '%s failed with error code %s and message\n'
          '%s which matches no regex in %s' % (
              _quote_long_string(command_string),
              ret_code,
              _quote_long_string(err),
              regexes)))
def assertCommandFails(self, command, regexes, env=None, close_fds=True,
                       msg=None):
  """Asserts a shell command fails and the error matches a regex in a list.

  Args:
    command: List or string representing the command to run.
    regexes: the list of regular expression strings. `str` entries are UTF-8
      encoded before matching against the command's bytes output. Pass a
      list even for a single pattern: a bare string is rejected by
      assertRegexMatch.
    env: Dictionary of environment variable settings. If None, no environment
      variables will be set for the child process. This is to make tests
      more hermetic. NOTE: this behavior is different than the standard
      subprocess module.
    close_fds: Whether or not to close all open fd's in the child after
      forking.
    msg: Optional message to report on failure.
  """
  (ret_code, err) = get_command_stderr(command, env, close_fds)

  # We need bytes regexes here because `err` is bytes. Accommodate code which
  # listed its output regexes w/o the b'' prefix by converting them to bytes
  # for the user. Conversion is element-wise, so empty lists (previously an
  # IndexError) and mixed str/bytes lists (previously an AttributeError) work.
  # A bare str/bytes is left untouched so that assertRegexMatch reports it,
  # instead of it being split into one-character patterns.
  if (not isinstance(regexes, (str, bytes)) and
      any(isinstance(regex, str) for regex in regexes)):
    regexes = [regex.encode('utf-8') if isinstance(regex, str) else regex
               for regex in regexes]

  command_string = get_command_string(command)
  self.assertNotEqual(
      ret_code, 0,
      self._formatMessage(msg, 'The following command succeeded '
                          'while expected to fail:\n%s' %
                          _quote_long_string(command_string)))
  self.assertRegexMatch(
      err,
      regexes,
      message=self._formatMessage(
          msg,
          'Running command\n'
          '%s failed with error code %s and message\n'
          '%s which matches no regex in %s' % (
              _quote_long_string(command_string),
              ret_code,
              _quote_long_string(err),
              regexes)))
class _AssertRaisesContext(object):
def __init__(self, expected_exception, test_case, test_func, msg=None):
self.expected_exception = expected_exception
self.test_case = test_case
self.test_func = test_func
self.msg = msg
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, tb):
if exc_type is None:
self.test_case.fail(self.expected_exception.__name__ + ' not raised',
self.msg)
if not issubclass(exc_type, self.expected_exception):
return False
self.test_func(exc_value)
if exc_value:
self.exception = exc_value.with_traceback(None)
return True
# Typing overload for the context-manager form: when callable_obj is omitted,
# an _AssertRaisesContext is returned for use in a `with` statement.
@typing.overload
def assertRaisesWithPredicateMatch(
self, expected_exception, predicate) -> _AssertRaisesContext:
# The purpose of this return statement is to work around
# https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
return self._AssertRaisesContext(None, None, None)
# Typing overload for the direct-call form: callable_obj(*args, **kwargs) is
# invoked and None is returned.
@typing.overload
def assertRaisesWithPredicateMatch(
self, expected_exception, predicate, callable_obj: Callable[..., Any],
*args, **kwargs) -> None:
# The purpose of this return statement is to work around
# https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
return self._AssertRaisesContext(None, None, None)
def assertRaisesWithPredicateMatch(self, expected_exception, predicate,
                                   callable_obj=None, *args, **kwargs):
  """Asserts that exception is thrown and predicate(exception) is true.

  Args:
    expected_exception: Exception class expected to be raised.
    predicate: Function of one argument that inspects the passed-in exception
      and returns True (success) or False (please fail the test).
    callable_obj: Function to be called, or None to get a context manager.
    *args: Extra args passed to callable_obj.
    **kwargs: Extra keyword args passed to callable_obj.

  Returns:
    A context manager if callable_obj is None. Otherwise, None.

  Raises:
    self.failureException if callable_obj does not raise a matching exception.
  """
  def _check_predicate(exc):
    self.assertTrue(predicate(exc),
                    '%r does not match predicate %r' % (exc, predicate))

  ctx = self._AssertRaisesContext(expected_exception, self, _check_predicate)
  if callable_obj is not None:
    with ctx:
      callable_obj(*args, **kwargs)
    return None
  return ctx
# Typing overload for the context-manager form: when callable_obj is omitted,
# an _AssertRaisesContext is returned for use in a `with` statement.
@typing.overload
def assertRaisesWithLiteralMatch(
self, expected_exception, expected_exception_message
) -> _AssertRaisesContext:
# The purpose of this return statement is to work around
# https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
return self._AssertRaisesContext(None, None, None)
# Typing overload for the direct-call form: callable_obj(*args, **kwargs) is
# invoked and None is returned.
@typing.overload
def assertRaisesWithLiteralMatch(
self, expected_exception, expected_exception_message,
callable_obj: Callable[..., Any], *args, **kwargs) -> None:
# The purpose of this return statement is to work around
# https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
return self._AssertRaisesContext(None, None, None)
def assertRaisesWithLiteralMatch(self, expected_exception,
                                 expected_exception_message,
                                 callable_obj=None, *args, **kwargs):
  """Asserts that the message in a raised exception equals the given string.

  Unlike assertRaisesRegex, this method takes a literal string, not
  a regular expression.

  with self.assertRaisesWithLiteralMatch(ExType, 'message'):
    DoSomething()

  Args:
    expected_exception: Exception class expected to be raised.
    expected_exception_message: String message expected in the raised
      exception. For a raise exception e, expected_exception_message must
      equal str(e).
    callable_obj: Function to be called, or None to return a context.
    *args: Extra args passed to callable_obj.
    **kwargs: Extra kwargs passed to callable_obj.

  Returns:
    A context manager if callable_obj is None. Otherwise, None.

  Raises:
    self.failureException if callable_obj does not raise a matching exception.
  """
  def _check_message(exc):
    got = str(exc)
    self.assertTrue(expected_exception_message == got,
                    'Exception message does not match.\n'
                    'Expected: %r\n'
                    'Actual: %r' % (expected_exception_message,
                                    got))

  ctx = self._AssertRaisesContext(expected_exception, self, _check_message)
  if callable_obj is not None:
    with ctx:
      callable_obj(*args, **kwargs)
    return None
  return ctx
def assertContainsInOrder(self, strings, target, msg=None):
  """Asserts that the strings provided are found in the target in order.

  Each string must be found after the end of the previous string's match.
  Matches therefore cannot overlap, and a string listed twice must occur
  twice. This may be useful for checking HTML output.

  Args:
    strings: A list of strings, such as [ 'fox', 'dog' ]. A single string is
      treated as a one-element list. Non-string items are searched for as
      str(item).
    target: A target string in which to look for the strings, such as
      'The quick brown fox jumped over the lazy dog'.
    msg: Optional message to report on failure.
  """
  if isinstance(strings, (bytes, str)):
    strings = (strings,)

  current_index = 0
  last_string = None
  for position, string in enumerate(strings):
    needle = str(string)
    index = target.find(needle, current_index)
    if index == -1:
      # Decide the message by position rather than by search offset: a first
      # match at offset 0 must still produce the "after" message next time.
      if position == 0:
        self.fail("Did not find '%s' in '%s'" %
                  (string, target), msg)
      self.fail("Did not find '%s' after '%s' in '%s'" %
                (string, last_string, target), msg)
    last_string = string
    # Resume after this match. Staying at `index` would let repeated or
    # overlapping entries (e.g. ['a', 'a'] in 'a') match the same text twice.
    current_index = index + len(needle)
def assertContainsSubsequence(self, container, subsequence, msg=None):
  """Asserts that "container" contains "subsequence" as a subsequence.

  Asserts that "container" contains all the elements of "subsequence", in
  order, but possibly with other elements interspersed. For example, [1, 2, 3]
  is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].

  Args:
    container: the list (or any iterable) we're testing for subsequence
      inclusion.
    subsequence: the list we hope will be a subsequence of container.
    msg: Optional message to report on failure.
  """
  subsequence = list(subsequence)
  remaining = iter(container)
  for element in subsequence:
    # `in` on an iterator consumes it up to and including the first match.
    # Each element is therefore only searched for after the previous
    # element's match, in a single O(len(container)) pass. Failing directly,
    # instead of recording the element in a None-initialised variable, means
    # a missing None element is reported too.
    if element not in remaining:
      self.fail('%s not a subsequence of %s. First non-matching element: %s' %
                (subsequence, container, element), msg)
def assertContainsExactSubsequence(self, container, subsequence, msg=None):
  """Asserts that "container" contains "subsequence" as an exact subsequence.

  Asserts that "container" contains all the elements of "subsequence", in
  order, and without other elements interspersed. For example, [1, 2, 3] is an
  exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].
  On failure, the longest prefix of "subsequence" found contiguously in
  "container" is reported.

  Args:
    container: the list we're testing for subsequence inclusion.
    subsequence: the list we hope will be an exact subsequence of container.
    msg: Optional message to report on failure.
  """
  items = list(container)
  needle = list(subsequence)
  needle_len = len(needle)
  last_start = len(items) - needle_len

  best = 0
  start = 0
  # Try every window start, stopping early once the whole needle matched.
  while start <= last_start and best != needle_len:
    matched = 0
    while matched < needle_len and needle[matched] == items[start + matched]:
      matched += 1
    if matched > best:
      best = matched
    start += 1

  if best < needle_len:
    self.fail('%s not an exact subsequence of %s. '
              'Longest matching prefix: %s' %
              (needle, items, needle[:best]), msg)
def assertTotallyOrdered(self, *groups, **kwargs):
  """Asserts that total ordering has been implemented correctly.

  For example, say you have a class A that compares only on its attribute x.
  Comparators other than ``__lt__`` are omitted for brevity::

    class A(object):
      def __init__(self, x, y):
        self.x = x
        self.y = y

      def __hash__(self):
        return hash(self.x)

      def __lt__(self, other):
        try:
          return self.x < other.x
        except AttributeError:
          return NotImplemented

  assertTotallyOrdered will check that instances can be ordered correctly.
  For example::

    self.assertTotallyOrdered(
        [A(1, 'a')],
        [A(2, 'b')],  # 2 is after 1.
        [A(3, 'c'), A(3, 'd')],  # The second argument is irrelevant.
        [A(4, 'z')])

  Every element is compared with the elements of every other group, so all of
  those comparisons must be supported. In Python 3, mixing e.g. None and ints
  raises TypeError.

  Args:
    *groups: A list of groups of elements. Each group of elements is a list
      of objects that are equal. The elements in each group must be less
      than the elements in the group after it. For example, these groups are
      totally ordered: ``[1]``, ``[2, 2]``, ``[3]``.
    **kwargs: optional msg keyword argument can be passed. Other keyword
      arguments are ignored.
  """

  def CheckOrder(small, big):
    """Ensures small is ordered before big."""
    self.assertFalse(small == big,
                     self._formatMessage(msg, '%r unexpectedly equals %r' %
                                         (small, big)))
    self.assertTrue(small != big,
                    self._formatMessage(msg, '%r unexpectedly equals %r' %
                                        (small, big)))
    self.assertLess(small, big, msg)
    self.assertFalse(big < small,
                     self._formatMessage(msg,
                                         '%r unexpectedly less than %r' %
                                         (big, small)))
    self.assertLessEqual(small, big, msg)
    # _formatMessage takes (user message, standard message); passing them in
    # the other order rendered "None : ..." when no msg was given.
    self.assertFalse(big <= small, self._formatMessage(
        msg,
        '%r unexpectedly less than or equal to %r' % (big, small)))
    self.assertGreater(big, small, msg)
    self.assertFalse(small > big,
                     self._formatMessage(msg,
                                         '%r unexpectedly greater than %r' %
                                         (small, big)))
    self.assertGreaterEqual(big, small, msg)
    self.assertFalse(small >= big, self._formatMessage(
        msg,
        '%r unexpectedly greater than or equal to %r' % (small, big)))

  def CheckEqual(a, b):
    """Ensures that a and b are equal."""
    self.assertEqual(a, b, msg)
    self.assertFalse(a != b,
                     self._formatMessage(msg, '%r unexpectedly unequals %r' %
                                         (a, b)))

    # Objects that compare equal must hash to the same value, but this only
    # applies if both objects are hashable.
    if (isinstance(a, abc.Hashable) and
        isinstance(b, abc.Hashable)):
      self.assertEqual(
          hash(a), hash(b),
          self._formatMessage(
              msg, 'hash %d of %r unexpectedly not equal to hash %d of %r' %
              (hash(a), a, hash(b), b)))

    self.assertFalse(a < b,
                     self._formatMessage(msg,
                                         '%r unexpectedly less than %r' %
                                         (a, b)))
    self.assertFalse(b < a,
                     self._formatMessage(msg,
                                         '%r unexpectedly less than %r' %
                                         (b, a)))
    self.assertLessEqual(a, b, msg)
    self.assertLessEqual(b, a, msg)  # pylint: disable=arguments-out-of-order
    self.assertFalse(a > b,
                     self._formatMessage(msg,
                                         '%r unexpectedly greater than %r' %
                                         (a, b)))
    self.assertFalse(b > a,
                     self._formatMessage(msg,
                                         '%r unexpectedly greater than %r' %
                                         (b, a)))
    self.assertGreaterEqual(a, b, msg)
    self.assertGreaterEqual(b, a, msg)  # pylint: disable=arguments-out-of-order

  # Read by the nested helpers above through their closure.
  msg = kwargs.get('msg')

  # For every combination of elements, check the order of every pair of
  # elements.
  for elements in itertools.product(*groups):
    elements = list(elements)
    for index, small in enumerate(elements[:-1]):
      for big in elements[index + 1:]:
        CheckOrder(small, big)

  # Check that every element in each group is equal.
  for group in groups:
    for a in group:
      CheckEqual(a, a)
    for a, b in itertools.product(group, group):
      CheckEqual(a, b)
def assertDictEqual(self, a, b, msg=None):
  """Raises AssertionError if a and b are not equal dictionaries.

  Instead of one textual diff, the failure message lists unexpected,
  differing and missing entries in separate sections.

  Args:
    a: A dict, the expected value.
    b: A dict, the actual value.
    msg: An optional str, the associated message.

  Raises:
    AssertionError: if the dictionaries are not equal.
  """
  self.assertIsInstance(a, dict, self._formatMessage(
      msg,
      'First argument is not a dictionary'
  ))
  self.assertIsInstance(b, dict, self._formatMessage(
      msg,
      'Second argument is not a dictionary'
  ))

  def _sorted_if_orderable(items):
    # Keys of mixed types may not be orderable; keep dict order then.
    try:
      return sorted(items)
    except TypeError:
      return items

  if a == b:
    return

  a_items = _sorted_if_orderable(list(a.items()))
  b_items = _sorted_if_orderable(list(b.items()))

  safe_repr = unittest.util.safe_repr  # pytype: disable=module-attr

  def _stable_repr(dikt):
    """Deterministic repr for dict."""
    # Order entries by their repr instead of by value order, which is
    # non-deterministic across executions for many types.
    rendered = sorted((safe_repr(k), safe_repr(v)) for k, v in dikt.items())
    return '{%s}' % (', '.join('%s: %s' % pair for pair in rendered))

  suffix = ' (%s)' % msg if msg else ''
  message = ['%s != %s%s' % (_stable_repr(a), _stable_repr(b), suffix)]

  # The standard library default output confounds lexical difference with
  # value difference; treat them separately.
  missing = []
  different = []
  for key, a_value in a_items:
    if key not in b:
      missing.append((key, a_value))
    elif a_value != b[key]:
      different.append((key, a_value, b[key]))
  unexpected = []
  for key, b_value in b_items:
    if key not in a:
      unexpected.append((key, b_value))

  if unexpected:
    lines = ''.join(
        '%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in unexpected)
    message.append('Unexpected, but present entries:\n%s' % lines)
  if different:
    lines = ''.join(
        '%s: %s != %s\n' % (safe_repr(k), safe_repr(a_value),
                            safe_repr(b_value))
        for k, a_value, b_value in different)
    message.append('repr() of differing entries:\n%s' % lines)
  if missing:
    lines = ''.join(
        '%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in missing)
    message.append('Missing entries:\n%s' % lines)

  raise self.failureException('\n'.join(message))
def assertDataclassEqual(self, first, second, msg=None):
  """Asserts two dataclasses are equal with more informative errors.

  Arguments must both be dataclass instances (not dataclass types). This
  compares equality of individual fields and takes care to not compare
  fields that are marked as non-comparable (``compare=False``). It gives per
  field differences, which are easier to parse than the comparison of the
  string representations from assertEqual.

  In cases where the dataclass has a custom __eq__, and it is defined in a
  way that is inconsistent with equality of comparable fields, we raise an
  exception without further trying to figure out how they are different.

  Args:
    first: A dataclass, the first value.
    second: A dataclass, the second value.
    msg: An optional str, the associated message.

  Raises:
    AssertionError: if either argument is not a dataclass instance, or if the
      dataclasses are not equal.
  """
  # These go through self.fail, like every other failure below, so the user
  # message is included here as well.
  if not dataclasses.is_dataclass(first) or isinstance(first, type):
    self.fail('First argument is not a dataclass instance.', msg)
  if not dataclasses.is_dataclass(second) or isinstance(second, type):
    self.fail('Second argument is not a dataclass instance.', msg)

  if first == second:
    return

  if type(first) is not type(second):
    self.fail(
        'Found different dataclass types: %s != %s'
        % (type(first), type(second)),
        msg,
    )

  # Make sure to skip fields that are marked compare=False.
  different = [
      (f.name, getattr(first, f.name), getattr(second, f.name))
      for f in dataclasses.fields(first)
      if f.compare and getattr(first, f.name) != getattr(second, f.name)
  ]

  safe_repr = unittest.util.safe_repr  # pytype: disable=module-attr
  message = ['%s != %s' % (safe_repr(first), safe_repr(second))]
  if different:
    message.append('Fields that differ:')
    message.extend(
        '%s: %s != %s' % (k, safe_repr(first_v), safe_repr(second_v))
        for k, first_v, second_v in different
    )
  else:
    message.append(
        'Cannot detect difference by examining the fields of the dataclass.'
    )

  # self.fail always raises; there is no return value to `raise`.
  self.fail('\n'.join(message), msg)
def assertUrlEqual(self, a, b, msg=None):
  """Asserts that urls are equal, ignoring ordering of query params.

  The ';'-separated path parameters are also compared order-insensitively.

  Args:
    a: The first URL string.
    b: The second URL string.
    msg: Optional message to report on failure.
  """
  url_a = parse.urlparse(a)
  url_b = parse.urlparse(b)
  # Compared in this order, so the first differing component is reported.
  for component in ('scheme', 'netloc', 'path', 'fragment'):
    self.assertEqual(getattr(url_a, component), getattr(url_b, component), msg)
  self.assertEqual(sorted(url_a.params.split(';')),
                   sorted(url_b.params.split(';')), msg)
  query_a = parse.parse_qs(url_a.query, keep_blank_values=True)
  query_b = parse.parse_qs(url_b.query, keep_blank_values=True)
  self.assertDictEqual(query_a, query_b, msg)
def assertSameStructure(self, a, b, aname='a', bname='b', msg=None):
  """Asserts that two values contain the same structural content.

  The two arguments should be data trees consisting of trees of dicts and
  lists. They will be deeply compared by walking into the contents of dicts
  and lists; other items will be compared using the == operator.
  If the two structures differ in content, the failure message will indicate
  the location within the structures where the first difference is found.
  This may be helpful when comparing large structures.

  Mixed Sequence and Set types are supported. Mixed Mapping types are
  supported, but the order of the keys will not be considered in the
  comparison.

  When self.maxDiff is not None, the number of reported problems is capped
  at roughly maxDiff // 80, but the first problem is always shown.

  Args:
    a: The first structure to compare.
    b: The second structure to compare.
    aname: Variable name to use for the first structure in assertion messages.
    bname: Variable name to use for the second structure.
    msg: Additional text to include in the failure message.
  """

  # Accumulate all the problems found so we can report all of them at once
  # rather than just stopping at the first.
  problems = []

  _walk_structure_for_problems(a, b, aname, bname, problems,
                               self.assertEqual, self.failureException)

  # Avoid spamming the user too much.
  if self.maxDiff is not None:
    max_problems_to_show = self.maxDiff // 80
    if len(problems) > max_problems_to_show:
      # Show (limit - 1) problems plus '...', but never fewer than one
      # problem. For maxDiff < 80 the old slice [0:-1] kept all but the last
      # problem. For maxDiff in [80, 160) it showed only '...'.
      keep = max(max_problems_to_show - 1, 1)
      if len(problems) > keep:
        problems = problems[:keep] + ['...']

  if problems:
    self.fail('; '.join(problems), msg)
def assertJsonEqual(self, first, second, msg=None):
  """Asserts that the JSON objects defined in two strings are equal.

  A summary of the differences will be included in the failure message
  using assertSameStructure.

  Args:
    first: A string containing JSON to decode and compare to second.
    second: A string containing JSON to decode and compare to first.
    msg: Additional text to include in the failure message.

  Raises:
    ValueError: If either string is not valid JSON.
  """

  def _decode(text, failure_template):
    # Re-raise decoding errors as ValueError carrying the user message.
    try:
      return json.loads(text)
    except ValueError as e:
      raise ValueError(self._formatMessage(msg, failure_template % (text, e)))

  first_structured = _decode(first, 'could not decode first JSON value %s: %s')
  second_structured = _decode(second,
                              'could not decode second JSON value %s: %s')

  self.assertSameStructure(first_structured, second_structured,
                           aname='first', bname='second', msg=msg)
def _getAssertEqualityFunc(self, first, second):
# type: (Any, Any) -> Callable[..., None]
try:
return super(TestCase, self)._getAssertEqualityFunc(first, second)
except AttributeError:
# This is a workaround if unittest.TestCase.__init__ was never run.
# It usually means that somebody created a subclass just for the
# assertions and has overridden __init__. "assertTrue" is a safe
# value that will not make __init__ raise a ValueError.
test_method = getattr(self, '_testMethodName', 'assertTrue')
super(TestCase, self).__init__(test_method)
return super(TestCase, self)._getAssertEqualityFunc(first, second)
def fail(self, msg=None, user_msg=None) -> NoReturn:
  """Fail immediately with the given standard message and user message.

  Args:
    msg: The standard failure message produced by the assertion.
    user_msg: Optional user-supplied message, combined via _formatMessage.
  """
  combined = self._formatMessage(user_msg, msg)
  return super(TestCase, self).fail(combined)
def _sorted_list_difference(expected, actual):
# type: (List[_T], List[_T]) -> Tuple[List[_T], List[_T]]
"""Finds elements in only one or the other of two, sorted input lists.
Returns a two-element tuple of lists. The first list contains those
elements in the "expected" list but not in the "actual" list, and the
second contains those elements in the "actual" list but not in the
"expected" list. Duplicate elements in either input list are ignored.
Args:
expected: The list we expected.
actual: The list we actually got.
Returns:
(missing, unexpected)
missing: items in expected that are not in actual.
unexpected: items in actual that are not in expected.
"""
i = j = 0
missing = []
unexpected = []
while True:
try:
e = expected[i]
a = actual[j]
if e < a:
missing.append(e)
i += 1
while expected[i] == e:
i += 1
elif e > a:
unexpected.append(a)
j += 1
while actual[j] == a:
j += 1
else:
i += 1
try:
while expected[i] == e:
i += 1
finally:
j += 1
while actual[j] == a:
j += 1
except IndexError:
missing.extend(expected[i:])
unexpected.extend(actual[j:])
break
return missing, unexpected
def _are_both_of_integer_type(a, b):
# type: (object, object) -> bool
return isinstance(a, int) and isinstance(b, int)
def _are_both_of_sequence_type(a, b):
# type: (object, object) -> bool
return isinstance(a, abc.Sequence) and isinstance(
b, abc.Sequence) and not isinstance(
a, _TEXT_OR_BINARY_TYPES) and not isinstance(b, _TEXT_OR_BINARY_TYPES)
def _are_both_of_set_type(a, b):
# type: (object, object) -> bool
return isinstance(a, abc.Set) and isinstance(b, abc.Set)
def _are_both_of_mapping_type(a, b):
# type: (object, object) -> bool
return isinstance(a, abc.Mapping) and isinstance(
b, abc.Mapping)
def _walk_structure_for_problems(
a, b, aname, bname, problem_list, leaf_assert_equal_func, failure_exception
):
"""The recursive comparison behind assertSameStructure.

Walks `a` and `b` in parallel and appends a human-readable description of
every difference found to `problem_list`; nothing is returned.

Args:
a: The first structure (or sub-structure) to compare.
b: The second structure (or sub-structure) to compare.
aname: Path expression naming `a` in problem messages, e.g. "a[0]['k']".
bname: Path expression naming `b` in problem messages.
problem_list: List that problem descriptions are appended to (mutated).
leaf_assert_equal_func: Called as f(a, b) on non-container values; a
failure_exception raised by it is recorded as a problem.
failure_exception: Exception type that marks a leaf mismatch.
"""
if type(a) != type(b) and not ( # pylint: disable=unidiomatic-typecheck
_are_both_of_integer_type(a, b) or _are_both_of_sequence_type(a, b) or
_are_both_of_set_type(a, b) or _are_both_of_mapping_type(a, b)):
# Values of different types are only compared further when both are ints
# (e.g. bool vs int), both non-str Sequences, both Sets, or both Mappings.
problem_list.append('%s is a %r but %s is a %r' %
(aname, type(a), bname, type(b)))
# If they have different types there's no point continuing
return
# Sets: report elements present in only one of the two (no recursion).
if isinstance(a, abc.Set):
for k in a:
if k not in b:
problem_list.append(
'%s has %r but %s does not' % (aname, k, bname))
for k in b:
if k not in a:
problem_list.append('%s lacks %r but %s has it' % (aname, k, bname))
# NOTE: a or b could be a defaultdict, so we must take care that the traversal
# doesn't modify the data.
elif isinstance(a, abc.Mapping):
# Recurse into keys present on both sides; report keys present on only one.
for k in a:
if k in b:
_walk_structure_for_problems(
a[k], b[k], '%s[%r]' % (aname, k), '%s[%r]' % (bname, k),
problem_list, leaf_assert_equal_func, failure_exception)
else:
problem_list.append(
"%s has [%r] with value %r but it's missing in %s" %
(aname, k, a[k], bname))
for k in b:
if k not in a:
problem_list.append(
'%s lacks [%r] but %s has it with value %r' %
(aname, k, bname, b[k]))
# Strings/bytes are Sequences, but they are compared whole, as leaves, by
# leaf_assert_equal_func in the final branch.
elif (isinstance(a, abc.Sequence) and
not isinstance(a, _TEXT_OR_BINARY_TYPES)):
# Recurse index by index over the common prefix, then report the extra
# trailing items of whichever side is longer.
minlen = min(len(a), len(b))
for i in range(minlen):
_walk_structure_for_problems(
a[i], b[i], '%s[%d]' % (aname, i), '%s[%d]' % (bname, i),
problem_list, leaf_assert_equal_func, failure_exception)
for i in range(minlen, len(a)):
problem_list.append('%s has [%i] with value %r but %s does not' %
(aname, i, a[i], bname))
for i in range(minlen, len(b)):
problem_list.append('%s lacks [%i] but %s has it with value %r' %
(aname, i, bname, b[i]))
else:
# Leaf values: delegate to the equality assertion, recording a failure
# as a problem instead of letting it propagate.
try:
leaf_assert_equal_func(a, b)
except failure_exception:
problem_list.append('%s is %r but %s is %r' % (aname, a, bname, b))
def get_command_string(command):
  """Returns an escaped string that can be used as a shell command.

  A string command is returned unchanged. For a list command on Windows
  (os.name == 'nt'), the words are joined with spaces without quoting.
  Elsewhere, every word is wrapped in single quotes, even words that would
  not need quoting (unlike shlex.quote). An embedded single quote is written
  as '"'"'.

  Args:
    command: List or string representing the command to run.

  Returns:
    A string suitable for use as a shell command.
  """
  if isinstance(command, str):
    return command
  if os.name == 'nt':
    return ' '.join(command)
  quoted_words = ("'" + word.replace("'", "'\"'\"'") + "'" for word in command)
  return ' '.join(quoted_words)
def get_command_stderr(command, env=None, close_fds=True):
  """Runs the given shell command and returns a tuple.

  A string command is run through the shell. A list command is executed
  directly. The child's stderr is merged into its stdout.

  Args:
    command: List or string representing the command to run.
    env: Dictionary of environment variable settings. If None, no environment
      variables will be set for the child process. This is to make tests
      more hermetic. NOTE: this behavior is different than the standard
      subprocess module.
    close_fds: Whether or not to close all open fd's in the child after forking.
      On Windows, this is ignored and close_fds is always False.

  Returns:
    Tuple of (exit status, bytes printed to stdout and stderr by the command).
  """
  child_env = {} if env is None else env
  # Windows does not support setting close_fds to True while also redirecting
  # standard handles.
  if os.name == 'nt':
    close_fds = False

  process = subprocess.Popen(
      command,
      close_fds=close_fds,
      env=child_env,
      shell=isinstance(command, str),
      stderr=subprocess.STDOUT,
      stdout=subprocess.PIPE)
  combined_output, _ = process.communicate()
  return process.wait(), combined_output
def _quote_long_string(s):
# type: (Union[Text, bytes, bytearray]) -> Text
"""Quotes a potentially multi-line string to make the start and end obvious.
Args:
s: A string.
Returns:
The quoted string.
"""
if isinstance(s, (bytes, bytearray)):
try:
s = s.decode('utf-8')
except UnicodeDecodeError:
s = str(s)
return ('8<-----------\n' +
s + '\n' +
'----------->8\n')
def print_python_version():
  # type: () -> None
  """Writes the running Python version and interpreter path to stderr.

  Having this in the test output logs by default helps debugging when all
  you've got is the log and no other idea of which Python was used.
  """
  major, minor, micro = sys.version_info[:3]
  interpreter = sys.executable or 'embedded.'
  sys.stderr.write('Running tests under Python %d.%d.%d: %s\n'
                   % (major, minor, micro, interpreter))
def main(*args, **kwargs):
  # type: (Text, Any) -> None
  """Executes a set of Python unit tests.

  Typically called with no arguments, in which case the underlying
  unittest.TestProgram uses its defaults and runs every test method of every
  TestCase class in the ``__main__`` module. The Python version is logged to
  stderr first.

  Args:
    *args: Positional arguments forwarded to
      ``unittest.TestProgram.__init__``.
    **kwargs: Keyword arguments forwarded to
      ``unittest.TestProgram.__init__``.
  """
  print_python_version()
  _run_in_app(function=run_tests, args=args, kwargs=kwargs)
def _is_in_app_main():
  # type: () -> bool
  """Reports whether any caller frame is executing app.run."""
  frame = sys._getframe().f_back  # pylint: disable=protected-access
  while frame is not None:
    if frame.f_code == app.run.__code__:
      return True
    frame = frame.f_back
  return False
def _register_sigterm_with_faulthandler():
# type: () -> None
"""Have faulthandler dump stacks on SIGTERM. Useful to diagnose timeouts."""
if getattr(faulthandler, 'register', None):
# faulthandler.register is not available on Windows.
# faulthandler.enable() is already called by app.run.
try:
faulthandler.register(signal.SIGTERM, chain=True) # pytype: disable=module-attr
except Exception as e: # pylint: disable=broad-except
sys.stderr.write('faulthandler.register(SIGTERM) failed '
'%r; ignoring.\n' % e)
def _run_in_app(function, args, kwargs):
  # type: (Callable[..., None], Sequence[Text], Mapping[Text, Any]) -> None
  """Calls ``function(argv, args, kwargs)``, ensuring it runs inside app.run.

  This is a private function; users should call absltest.main().

  Running under app.run guarantees that flags are parsed and stripped and
  that the app module's other initialization has happened, regardless of
  whether absltest.main() was called from inside or outside app.run().

  * If already inside app.run(), sys.argv is re-parsed to recover argv
    without the command-line flags. ``__main__.main`` does not forward its
    argv to absltest.main(), so re-parsing is the only way to obtain it.
    During that re-parse every flag's ``parse`` method is temporarily
    replaced with a no-op so flag values are not touched.
  * Otherwise app.run() is invoked with a main that calls ``function``.

  In both cases the default of FLAGS.alsologtostderr becomes True, so the
  test's stderr contains all log messages. A --alsologtostderr=false on the
  command line, or an earlier explicit assignment to the flag, keeps it False.

  ``function`` (and therefore this helper) may mutate ``kwargs``.

  Args:
    function: absltest.run_tests or a similar callable. It is invoked as
      function(argv, args, kwargs), where argv is sys.argv minus the
      command-line flags.
    args: Positional arguments passed through to unittest.TestProgram.__init__.
    kwargs: Keyword arguments passed through to unittest.TestProgram.__init__.
  """
  if not _is_in_app_main():
    # Send logging to stderr. Use --alsologtostderr instead of --logtostderr
    # in case tests are reading their own logs.
    FLAGS.set_default('alsologtostderr', True)

    def main_function(argv):
      _register_sigterm_with_faulthandler()
      function(argv, args, kwargs)

    app.run(main=main_function)
    return

  _register_sigterm_with_faulthandler()
  FLAGS.set_default('alsologtostderr', True)

  # Snapshot every parse method before stubbing any of them. Several names
  # (short_name=) can point to the same flag object, so stubbing while
  # snapshotting would record the stub for the alias.
  original_parsers = {name: FLAGS[name].parse for name in FLAGS}

  def _ignore_value(unused_value):
    return None

  for name in FLAGS:
    FLAGS[name].parse = _ignore_value
  try:
    argv = FLAGS(sys.argv)
  finally:
    for name in FLAGS:
      FLAGS[name].parse = original_parsers[name]

  sys.stdout.flush()
  function(argv, args, kwargs)
def _is_suspicious_attribute(testCaseClass, name):
# type: (Type, Text) -> bool
"""Returns True if an attribute is a method named like a test method."""
if name.startswith('Test') and len(name) > 4 and name[4].isupper():
attr = getattr(testCaseClass, name)
if inspect.isfunction(attr) or inspect.ismethod(attr):
args = inspect.getfullargspec(attr)
return (len(args.args) == 1 and args.args[0] == 'self' and
args.varargs is None and args.varkw is None and
not args.kwonlyargs)
return False
def skipThisClass(reason):
  # type: (Text) -> Callable[[_T], _T]
  """Skip tests in the decorated TestCase, but not any of its subclasses.

  Useful for sharing test methods or setUp implementations between several
  concrete TestCase classes: only the decorated base class is skipped.

  Example: only ``BaseTest`` is skipped, while ``RealTest`` and
  ``SecondRealTest`` run the shared test method::

    @absltest.skipThisClass("Shared functionality")
    class BaseTest(absltest.TestCase):
      def test_simple_functionality(self):
        self.assertEqual(self.system_under_test.method(), 1)

    class RealTest(BaseTest):
      def setUp(self):
        super().setUp()
        self.system_under_test = MakeSystem(argument)

      def test_specific_behavior(self):
        ...

    class SecondRealTest(BaseTest):
      def setUp(self):
        super().setUp()
        self.system_under_test = MakeSystem(other_arguments)

      def test_other_behavior(self):
        ...

  Args:
    reason: Why the class is skipped, e.g. 'shared test methods' or 'shared
      assertion methods'.

  Returns:
    A class decorator that makes the decorated class's setUpClass raise
    SkipTest, while subclasses keep the original setUpClass behavior.

  Raises:
    TypeError: if ``reason`` is a class (decorator used without arguments),
      or if the decorated object is not a unittest.TestCase subclass.
  """
  if isinstance(reason, type):
    raise TypeError('Got {!r}, expected reason as string'.format(reason))

  def _skip_class(test_case_class):
    if not issubclass(test_case_class, unittest.TestCase):
      raise TypeError(
          'Decorating {!r}, expected TestCase subclass'.format(test_case_class))

    # Hold on to setUpClass only when the decorated class defines it itself;
    # an inherited one is reached through super() instead.
    own_setupclass = test_case_class.__dict__.get('setUpClass', None)

    def replacement_setupclass(cls, *args, **kwargs):
      if cls is test_case_class:
        raise SkipTest(reason)
      if not own_setupclass:
        # Continue up the inheritance chain past the decorated class.
        return super(test_case_class, cls).setUpClass(*args, **kwargs)
      # The stored classmethod descriptor is not directly callable; call its
      # underlying function with the real `cls` so the MRO chain stays intact.
      return own_setupclass.__func__(cls, *args, **kwargs)

    test_case_class.setUpClass = classmethod(replacement_setupclass)
    return test_case_class

  return _skip_class
class TestLoader(unittest.TestLoader):
  """unittest.TestLoader with absltest-specific behavior.

  * Rejects test-like method names: a method whose name starts with `Test`
    is ignored by the unittest runner, which usually means a test was
    accidentally omitted. Loading such a TestCase raises TypeError.
  * Optionally shuffles the order of test case names, using the seed returned
    by _get_default_randomize_ordering_seed().
  """

  _ERROR_MSG = textwrap.dedent("""Method '%s' is named like a test case but
is not one. This is often a bug. If you want it to be a test method,
name it with 'test' in lowercase. If not, rename the method to not begin
with 'Test'.""")

  def __init__(self, *args, **kwds):
    super(TestLoader, self).__init__(*args, **kwds)
    seed = _get_default_randomize_ordering_seed()
    # A falsy seed means "do not randomize".
    self._randomize_ordering_seed = seed if seed else None
    self._random = random.Random(seed) if seed else None

  def getTestCaseNames(self, testCaseClass):  # pylint:disable=invalid-name
    """Validates and returns a (possibly randomized) list of test case names."""
    for attr_name in dir(testCaseClass):
      if _is_suspicious_attribute(testCaseClass, attr_name):
        raise TypeError(TestLoader._ERROR_MSG % attr_name)

    names = list(super(TestLoader, self).getTestCaseNames(testCaseClass))
    seed = self._randomize_ordering_seed
    if seed is None:
      return names

    logging.info('Randomizing test order with seed: %d', seed)
    logging.info('To reproduce this order, re-run with '
                 '--test_randomize_ordering_seed=%d', seed)
    self._random.shuffle(names)
    return names
def get_default_xml_output_filename():
  # type: () -> Optional[Text]
  """Derives the XML report path from the environment.

  Checked in order: XML_OUTPUT_FILE (used verbatim),
  RUNNING_UNDER_TEST_DAEMON (test_detail.xml next to the test tmpdir), and
  TEST_XMLOUTPUTDIR (<dir>/<program basename without extension>.xml).

  Returns:
    The path, or None when none of these variables is set to a non-empty value.
  """
  env = os.environ
  explicit_file = env.get('XML_OUTPUT_FILE')
  if explicit_file:
    return explicit_file
  if env.get('RUNNING_UNDER_TEST_DAEMON'):
    return os.path.join(os.path.dirname(TEST_TMPDIR.value), 'test_detail.xml')
  output_dir = env.get('TEST_XMLOUTPUTDIR')
  if output_dir:
    program = os.path.splitext(os.path.basename(sys.argv[0]))[0]
    return os.path.join(output_dir, program + '.xml')
  return None
def _setup_filtering(argv: MutableSequence[str]) -> bool:
"""Implements the bazel test filtering protocol.
The following environment variable is used in this method:
TESTBRIDGE_TEST_ONLY: string, if set, is forwarded to the unittest
framework to use as a test filter. Its value is split with shlex, then:
1. On Python 3.6 and before, split values are passed as positional
arguments on argv.
2. On Python 3.7+, split values are passed to unittest's `-k` flag. Tests
are matched by glob patterns or substring. See
https://docs.python.org/3/library/unittest.html#cmdoption-unittest-k
Args:
argv: the argv to mutate in-place.
Returns:
Whether test filtering is requested.
"""
test_filter = os.environ.get('TESTBRIDGE_TEST_ONLY')
if argv is None or not test_filter:
return False
filters = shlex.split(test_filter)
if sys.version_info[:2] >= (3, 7):
filters = ['-k=' + test_filter for test_filter in filters]
argv[1:1] = filters
return True
def _setup_test_runner_fail_fast(argv):
# type: (MutableSequence[Text]) -> None
"""Implements the bazel test fail fast protocol.
The following environment variable is used in this method:
TESTBRIDGE_TEST_RUNNER_FAIL_FAST=<1|0>
If set to 1, --failfast is passed to the unittest framework to return upon
first failure.
Args:
argv: the argv to mutate in-place.
"""
if argv is None:
return
if os.environ.get('TESTBRIDGE_TEST_RUNNER_FAIL_FAST') != '1':
return
argv[1:1] = ['--failfast']
def _setup_sharding(
custom_loader: Optional[unittest.TestLoader] = None,
) -> Tuple[unittest.TestLoader, Optional[int]]:
"""Implements the bazel sharding protocol.
The following environment variables are used in this method:
TEST_SHARD_STATUS_FILE: string, if set, points to a file. We write a blank
file to tell the test runner that this test implements the test sharding
protocol.
TEST_TOTAL_SHARDS: int, if set, sharding is requested.
TEST_SHARD_INDEX: int, must be set if TEST_TOTAL_SHARDS is set. Specifies
the shard index for this instance of the test process. Must satisfy:
0 <= TEST_SHARD_INDEX < TEST_TOTAL_SHARDS.
Args:
custom_loader: A TestLoader to be made sharded.
Returns:
A tuple of ``(test_loader, shard_index)``. ``test_loader`` is for
shard-filtering or the standard test loader depending on the sharding
environment variables. ``shard_index`` is the shard index, or ``None`` when
sharding is not used.
"""
# It may be useful to write the shard file even if the other sharding
# environment variables are not set. Test runners may use this functionality
# to query whether a test binary implements the test sharding protocol.
if 'TEST_SHARD_STATUS_FILE' in os.environ:
try:
with open(os.environ['TEST_SHARD_STATUS_FILE'], 'w') as f:
f.write('')
except IOError:
sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.'
% os.environ['TEST_SHARD_STATUS_FILE'])
sys.exit(1)
base_loader = custom_loader or TestLoader()
if 'TEST_TOTAL_SHARDS' not in os.environ:
# Not using sharding, use the expected test loader.
return base_loader, None
total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
shard_index = int(os.environ['TEST_SHARD_INDEX'])
if shard_index < 0 or shard_index >= total_shards:
sys.stderr.write('ERROR: Bad sharding values. index=%d, total=%d\n' %
(shard_index, total_shards))
sys.exit(1)
# Replace the original getTestCaseNames with one that returns
# the test case names for this shard.
delegate_get_names = base_loader.getTestCaseNames
bucket_iterator = itertools.cycle(range(total_shards))
def getShardedTestCaseNames(testCaseClass):
filtered_names = []
# We need to sort the list of tests in order to determine which tests this
# shard is responsible for; however, it's important to preserve the order
# returned by the base loader, e.g. in the case of randomized test ordering.
ordered_names = delegate_get_names(testCaseClass)
for testcase in sorted(ordered_names):
bucket = next(bucket_iterator)
if bucket == shard_index:
filtered_names.append(testcase)
return [x for x in ordered_names if x in filtered_names]
base_loader.getTestCaseNames = getShardedTestCaseNames
return base_loader, shard_index
def _run_and_get_tests_result(
    argv: MutableSequence[str],
    args: Sequence[Any],
    kwargs: MutableMapping[str, Any],
    xml_test_runner_class: Type[unittest.TextTestRunner],
) -> Tuple[unittest.TestResult, bool]:
  """Same as run_tests, but it doesn't exit.

  Args:
    argv: sys.argv with the command-line flags removed from the front, i.e. the
      argv with which :func:`app.run()` has called
      ``__main__.main``. It is passed to
      ``unittest.TestProgram.__init__(argv=)``, which does its own flag parsing.
      It is ignored if kwargs contains an argv entry.
    args: Positional arguments passed through to
      ``unittest.TestProgram.__init__``.
    kwargs: Keyword arguments passed through to
      ``unittest.TestProgram.__init__``. Mutated in place: 'argv' is popped and
      re-added, and 'testLoader', 'testRunner' and 'exit' are set.
    xml_test_runner_class: The type of the test runner class.

  Returns:
    A tuple of ``(test_result, fail_when_no_tests_ran)``.
    ``fail_when_no_tests_ran`` indicates whether the test should fail when
    no tests ran.
  """

  # The entry from kwargs overrides argv.
  argv = kwargs.pop('argv', argv)

  if sys.version_info[:2] >= (3, 12):
    # Python 3.12 unittest changed the behavior from PASS to FAIL in
    # https://github.com/python/cpython/pull/102051. absltest follows this.
    fail_when_no_tests_ran = True
  else:
    # Historically, absltest and unittest before Python 3.12 passes if no tests
    # ran.
    fail_when_no_tests_ran = False

  # Set up test filtering if requested in environment.
  if _setup_filtering(argv):
    # When test filtering is requested, ideally we also want to fail when no
    # tests ran. However, the test filters are usually done when running bazel.
    # When you run multiple targets, e.g. `bazel test //my_dir/...
    # --test_filter=MyTest`, you don't necessarily want individual tests to fail
    # because no tests match in that particular target.
    # Due to this use case, we don't fail when test filtering is requested via
    # the environment variable from bazel.
    fail_when_no_tests_ran = False

  # Set up --failfast as requested in environment
  _setup_test_runner_fail_fast(argv)

  # Shard the (default or custom) loader if sharding is turned on.
  kwargs['testLoader'], shard_index = _setup_sharding(
      kwargs.get('testLoader', None)
  )
  if shard_index is not None and shard_index > 0:
    # When sharding is requested, all the shards except the first one shall not
    # fail when no tests ran. This happens when the shard count is greater than
    # the test case count.
    fail_when_no_tests_ran = False

  # XML file name is based upon (sorted by priority):
  # --xml_output_file flag, XML_OUTPUT_FILE variable,
  # TEST_XMLOUTPUTDIR variable or RUNNING_UNDER_TEST_DAEMON variable.
  if not FLAGS.xml_output_file:
    FLAGS.xml_output_file = get_default_xml_output_filename()
  xml_output_file = FLAGS.xml_output_file

  xml_buffer = None
  if xml_output_file:
    xml_output_dir = os.path.dirname(xml_output_file)
    if xml_output_dir and not os.path.isdir(xml_output_dir):
      try:
        os.makedirs(xml_output_dir)
      except OSError as e:
        # File exists error can occur with concurrent tests
        if e.errno != errno.EEXIST:
          raise
    # Fail early if we can't write to the XML output file. This is so that we
    # don't waste people's time running tests that will just fail anyways.
    with _open(xml_output_file, 'w'):
      pass

    # We can reuse testRunner if it supports XML output (e. g. by inheriting
    # from xml_reporter.TextAndXMLTestRunner). Otherwise we need to use
    # xml_reporter.TextAndXMLTestRunner.
    if (kwargs.get('testRunner') is not None
        and not hasattr(kwargs['testRunner'], 'set_default_xml_stream')):
      # Terminate the warning with a newline so it is not glued to whatever
      # is written to stderr next.
      sys.stderr.write('WARNING: XML_OUTPUT_FILE or --xml_output_file setting '
                       'overrides testRunner=%r setting (possibly from --pdb)\n'
                       % (kwargs['testRunner']))
      # Passing a class object here allows TestProgram to initialize
      # instances based on its kwargs and/or parsed command-line args.
      kwargs['testRunner'] = xml_test_runner_class
    if kwargs.get('testRunner') is None:
      kwargs['testRunner'] = xml_test_runner_class
    # Use an in-memory buffer (not backed by the actual file) to store the XML
    # report, because some tools modify the file (e.g., create a placeholder
    # with partial information, in case the test process crashes).
    xml_buffer = io.StringIO()
    kwargs['testRunner'].set_default_xml_stream(xml_buffer)  # pytype: disable=attribute-error

    # If we've used a seed to randomize test case ordering, we want to record it
    # as a top-level attribute in the `testsuites` section of the XML output.
    randomize_ordering_seed = getattr(
        kwargs['testLoader'], '_randomize_ordering_seed', None)
    setter = getattr(kwargs['testRunner'], 'set_testsuites_property', None)
    if randomize_ordering_seed and setter:
      setter('test_randomize_ordering_seed', randomize_ordering_seed)
  elif kwargs.get('testRunner') is None:
    kwargs['testRunner'] = _pretty_print_reporter.TextTestRunner

  if FLAGS.pdb_post_mortem:
    runner = kwargs['testRunner']
    # testRunner can be a class or an instance, which must be tested for
    # differently.
    # Overriding testRunner isn't uncommon, so only enable the debugging
    # integration if the runner claims it does; we don't want to accidentally
    # clobber something on the runner.
    if ((isinstance(runner, type) and
         issubclass(runner, _pretty_print_reporter.TextTestRunner)) or
        isinstance(runner, _pretty_print_reporter.TextTestRunner)):
      runner.run_for_debugging = True

  # Make sure tmpdir exists.
  if not os.path.isdir(TEST_TMPDIR.value):
    try:
      os.makedirs(TEST_TMPDIR.value)
    except OSError as e:
      # Concurrent test might have created the directory.
      if e.errno != errno.EEXIST:
        raise

  # Let unittest.TestProgram.__init__ do its own argv parsing, e.g. for '-v',
  # on argv, which is sys.argv without the command-line flags.
  kwargs['argv'] = argv

  # Request unittest.TestProgram to not exit. The exit will be handled by
  # `absltest.run_tests`.
  kwargs['exit'] = False

  try:
    test_program = unittest.TestProgram(*args, **kwargs)
    return test_program.result, fail_when_no_tests_ran
  finally:
    # Write the buffered XML report even if TestProgram raised.
    if xml_buffer:
      try:
        with _open(xml_output_file, 'w') as f:
          f.write(xml_buffer.getvalue())
      finally:
        xml_buffer.close()
def run_tests(
    argv: MutableSequence[Text],
    args: Sequence[Any],
    kwargs: MutableMapping[Text, Any],
) -> None:
  """Runs a set of Python unit tests and exits the process with their status.

  Most users should call absltest.main() instead, which also guarantees that
  this runs inside app.run (a requirement). ``kwargs`` may be mutated.

  Exit status: 5 when no tests ran (and none were skipped) and the
  environment requires failing in that case; otherwise 0 on success and 1 on
  failure.

  Args:
    argv: sys.argv with the command-line flags removed from the front, i.e. the
      argv with which :func:`app.run()` has called
      ``__main__.main``. It is passed to
      ``unittest.TestProgram.__init__(argv=)``, which does its own flag parsing.
      It is ignored if kwargs contains an argv entry.
    args: Positional arguments passed through to
      ``unittest.TestProgram.__init__``.
    kwargs: Keyword arguments passed through to
      ``unittest.TestProgram.__init__``.
  """
  runner_class = xml_reporter.TextAndXMLTestRunner
  result, fail_when_no_tests_ran = _run_and_get_tests_result(
      argv, args, kwargs, runner_class)
  if fail_when_no_tests_ran and not result.testsRun and not result.skipped:
    # Python 3.12 unittest exits with 5 when no tests ran; the code itself
    # comes from pytest, which does the same.
    sys.exit(5)
  sys.exit(not result.wasSuccessful())
def _rmtree_ignore_errors(path):
# type: (Text) -> None
if os.path.isfile(path):
try:
os.unlink(path)
except OSError:
pass
else:
shutil.rmtree(path, ignore_errors=True)
def _get_first_part(path):
# type: (Text) -> Text
parts = path.split(os.sep, 1)
return parts[0]
abseil-py-2.1.0/absl/testing/flagsaver.py 0000664 0000000 0000000 00000032120 14551576331 0020331 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decorator and context manager for saving and restoring flag values.
There are many ways to save and restore. Always use the most convenient method
for a given use case.
Here are examples of each method. They all call ``do_stuff()`` while
``FLAGS.someflag`` is temporarily set to ``'foo'``::

  from absl.testing import flagsaver

  # Use a decorator which can optionally override flags via arguments.
  @flagsaver.flagsaver(someflag='foo')
  def some_func():
    do_stuff()

  # Use a decorator which can optionally override flags with flagholders.
  @flagsaver.flagsaver((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, 23))
  def some_func():
    do_stuff()

  # Use a decorator which does not override flags itself.
  @flagsaver.flagsaver
  def some_func():
    FLAGS.someflag = 'foo'
    do_stuff()

  # Use a context manager which can optionally override flags via arguments.
  with flagsaver.flagsaver(someflag='foo'):
    do_stuff()

  # Save and restore the flag values yourself.
  saved_flag_values = flagsaver.save_flag_values()
  try:
    FLAGS.someflag = 'foo'
    do_stuff()
  finally:
    flagsaver.restore_flag_values(saved_flag_values)

  # Use the parsing version to emulate users providing the flags.
  # Note that all flags must be provided as strings (unparsed).
  @flagsaver.as_parsed(some_int_flag='123')
  def some_func():
    # Because the flag was parsed it is considered "present".
    assert FLAGS['some_int_flag'].present
    do_stuff()

  # flagsaver.as_parsed() can also be used as a context manager just like
  # flagsaver.flagsaver()
  with flagsaver.as_parsed(some_int_flag='123'):
    do_stuff()

  # The flagsaver.as_parsed() interface also supports FlagHolder objects.
  @flagsaver.as_parsed((module.FOO_FLAG, 'foo'), (other_mod.BAR_FLAG, '23'))
  def some_func():
    do_stuff()

  # Using as_parsed with a multi_X flag requires a sequence of strings.
  @flagsaver.as_parsed(some_multi_int_flag=['123', '456'])
  def some_func():
    assert FLAGS['some_multi_int_flag'].present
    do_stuff()

  # If a flag name includes non-identifier characters it can be specified like
  # so:
  @flagsaver.as_parsed(**{'i-like-dashes': 'true'})
  def some_func():
    do_stuff()

We save and restore a shallow copy of each Flag object's ``__dict__`` attribute.
This preserves all attributes of the flag, such as whether or not it was
overridden from its default value.
WARNING: Currently a flag that is saved and then deleted cannot be restored. An
exception will be raised. However if you *add* a flag after saving flag values,
and then restore flag values, the added flag will be deleted with no errors.
"""
import collections
import functools
import inspect
from typing import overload, Any, Callable, Mapping, Tuple, TypeVar, Type, Sequence, Union
from absl import flags
FLAGS = flags.FLAGS
# The type of pre/post wrapped functions.
_CallableT = TypeVar('_CallableT', bound=Callable)
@overload
def flagsaver(*args: Tuple[flags.FlagHolder, Any],
              **kwargs: Any) -> '_FlagOverrider':
  ...


@overload
def flagsaver(func: _CallableT) -> _CallableT:
  ...


def flagsaver(*args, **kwargs):
  """Saves flag state, optionally overrides flags, and restores on exit.

  Usable as a bare decorator (``@flagsaver``), as a decorator or context
  manager with overrides given as ``name=value`` keywords and/or
  ``(FlagHolder, value)`` tuples. Override values are assigned directly (not
  parsed). See the module docstring for examples.
  """
  overrider_cls = _FlagOverrider
  return _construct_overrider(overrider_cls, *args, **kwargs)
@overload
def as_parsed(*args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
              **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
  ...


@overload
def as_parsed(func: _CallableT) -> _CallableT:
  ...


def as_parsed(*args, **kwargs):
  """Overrides flags by parsing strings; saves and restores like flagsaver.

  Works as a decorator or context manager exactly like
  flagsaver.flagsaver(). The difference is that flagsaver.flagsaver() assigns
  the given values directly, whereas this function parses them as though
  they came from the command line. Among other things this makes
  `FLAGS['flag_name'].present == True`.

  About unparsed input: most flag types take a single string, but multi_x
  flags (multi_string, multi_integer, multi_enum) take a Sequence of strings.

  Args:
    *args: (FlagHolder, unparsed value) tuples.
    **kwargs: Flag names mapped to unparsed values.

  Returns:
    A _ParsingFlagOverrider usable as a context manager or decorator. It saves
    the previous flag state, parses the new values, and restores the previous
    state on cleanup.
  """
  overrider_cls = _ParsingFlagOverrider
  return _construct_overrider(overrider_cls, *args, **kwargs)
# NOTE: the order of these overload declarations matters. The type checker will
# pick the first match which could be incorrect.
@overload
def _construct_overrider(
    flag_overrider_cls: Type['_ParsingFlagOverrider'],
    *args: Tuple[flags.FlagHolder, Union[str, Sequence[str]]],
    **kwargs: Union[str, Sequence[str]]) -> '_ParsingFlagOverrider':
  ...


@overload
def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
                         *args: Tuple[flags.FlagHolder, Any],
                         **kwargs: Any) -> '_FlagOverrider':
  ...


@overload
def _construct_overrider(flag_overrider_cls: Type['_FlagOverrider'],
                         func: _CallableT) -> _CallableT:
  ...


def _construct_overrider(flag_overrider_cls, *args, **kwargs):
  """Builds a ``flag_overrider_cls`` (or a wrapped function) from the arguments.

  For _FlagOverrider the values should be native Python values matching the
  flag types; for _ParsingFlagOverrider they should be strings or sequences
  of strings.

  Args:
    flag_overrider_cls: The class that will do the overriding.
    *args: Either a single callable (bare ``@flagsaver`` usage) or
      (FlagHolder, new value) tuples.
    **kwargs: Keyword args mapping flag name to new flag value.

  Returns:
    A flag_overrider_cls instance usable as a decorator or context manager,
    or, for bare decorator usage, the wrapped callable.

  Raises:
    ValueError: a callable is combined with keyword overrides, a positional
      argument is not a (FlagHolder, value) pair, or a flag is given twice.
    TypeError: the bare decorator is applied to a class.
  """
  if not args:
    return flag_overrider_cls(**kwargs)

  # `@flagsaver` (no parentheses) passes the decorated callable as the only arg.
  if len(args) == 1 and callable(args[0]):
    if kwargs:
      raise ValueError(
          "It's invalid to specify both positional and keyword parameters.")
    func = args[0]
    if inspect.isclass(func):
      raise TypeError('@flagsaver.flagsaver cannot be applied to a class.')
    return _wrap(flag_overrider_cls, func, {})

  # Otherwise each positional arg is a (FlagHolder, value) pair that augments
  # the keyword overrides.
  for pair in args:
    is_pair = isinstance(pair, tuple) and len(pair) == 2
    if not is_pair or not isinstance(pair[0], flags.FlagHolder):
      raise ValueError('Expected (FlagHolder, value) pair, found %r' % (pair,))
    holder, value = pair
    flag_name = holder.name
    if flag_name in kwargs:
      raise ValueError('Cannot set --%s multiple times' % flag_name)
    kwargs[flag_name] = value

  return flag_overrider_cls(**kwargs)
def save_flag_values(
    flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]:
  """Snapshots the state of every flag.

  Args:
    flag_values: FlagValues, the FlagValues instance whose flags are saved.
      This should almost never need to be overridden.

  Returns:
    A dict keyed by flag name whose values are copies of each flag's
    ``__dict__`` (as produced by _copy_flag_dict), e.g.
    ``{'key': value_dict, ...}``.
  """
  snapshot = {}
  for flag_name in flag_values:
    snapshot[flag_name] = _copy_flag_dict(flag_values[flag_name])
  return snapshot
def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
                        flag_values: flags.FlagValues = FLAGS):
  """Restores flag values based on the dictionary of flag values.

  Every flag currently registered in ``flag_values`` is handled as follows:
  if ``saved_flag_values`` has an entry for it, the flag's state is restored
  from a copy of that entry; otherwise the flag is considered new and is
  deleted. Entries for flags that are no longer registered are not touched.

  The saved entries are copied rather than installed as-is. Installing the
  mapping itself would let later assignments to the flag write back into
  ``saved_flag_values``, so restoring the same snapshot a second time would
  silently restore the modified value.

  Args:
    saved_flag_values: {'flag_name': value_dict, ...}
    flag_values: FlagValues, the FlagValues instance from which the flag will be
      restored. This should almost never need to be overridden.
  """
  for name in list(flag_values):
    saved = saved_flag_values.get(name)
    if saved is None:
      # If __dict__ was not saved delete "new" flag.
      delattr(flag_values, name)
      continue
    flag = flag_values[name]
    if flag.value != saved['_value']:
      flag.value = saved['_value']  # Ensure C++ value is set.
    restored = dict(saved)
    # The snapshot holds its own validator list; keep it unshared as well.
    if 'validators' in restored:
      restored['validators'] = list(restored['validators'])
    flag.__dict__ = restored
@overload
def _wrap(flag_overrider_cls: Type['_FlagOverrider'], func: _CallableT,
          overrides: Mapping[str, Any]) -> _CallableT:
  ...


@overload
def _wrap(flag_overrider_cls: Type['_ParsingFlagOverrider'], func: _CallableT,
          overrides: Mapping[str, Union[str, Sequence[str]]]) -> _CallableT:
  ...


def _wrap(flag_overrider_cls, func, overrides):
  """Returns ``func`` wrapped so each call runs inside a flag overrider.

  Args:
    flag_overrider_cls: The context-manager class constructed as
      ``flag_overrider_cls(**overrides)`` around every call.
    func: The callable invoked between saving and restoring flags.
    overrides: Flag names mapped to the values applied after the original
      state is saved. Their type (native values vs. unparsed strings) depends
      on ``flag_overrider_cls``.

  Returns:
    A functools.wraps-decorated wrapper around ``func``.
  """

  def _flagsaver_wrapper(*args, **kwargs):
    """Wrapper function that saves and restores flags."""
    with flag_overrider_cls(**overrides):
      return func(*args, **kwargs)

  return functools.wraps(func)(_flagsaver_wrapper)
class _FlagOverrider(object):
  """Temporarily assigns flag values, restoring the originals afterwards.

  Usable as a context manager, or as a decorator (via ``__call__``) that
  applies the same overrides around each call of the decorated function.
  """

  def __init__(self, **overrides: Any):
    self._overrides = overrides
    self._saved_flag_values = None

  def __call__(self, func: _CallableT) -> _CallableT:
    if inspect.isclass(func):
      raise TypeError('flagsaver cannot be applied to a class.')
    return _wrap(type(self), func, self._overrides)

  def __enter__(self):
    snapshot = save_flag_values(FLAGS)
    self._saved_flag_values = snapshot
    try:
      FLAGS._set_attributes(**self._overrides)  # pylint: disable=protected-access
    except BaseException:
      # Assignment can fail, e.g. because of flag validators; undo any
      # partial changes before propagating.
      restore_flag_values(snapshot, FLAGS)
      raise

  def __exit__(self, exc_type, exc_value, traceback):
    restore_flag_values(self._saved_flag_values, FLAGS)
class _ParsingFlagOverrider(_FlagOverrider):
  """Context manager for overriding flags by simulating command line parsing.

  This is similar to _FlagOverrider except that every value in **overrides
  must be a string or a sequence of strings. On entering the context, each
  flag's ``.parse(value)`` is called (instead of assigning the value
  directly), so the flags end up with ``.present`` set properly.
  """

  def __init__(self, **overrides: Union[str, Sequence[str]]):
    """Validates that every override value is parseable text.

    Args:
      **overrides: Flag names mapped to a string or a sequence of strings.

    Raises:
      TypeError: If any value is neither a str nor a sequence of str.
    """
    for flag_name, new_value in overrides.items():
      if isinstance(new_value, str):
        continue
      if (isinstance(new_value, collections.abc.Sequence) and
          all(isinstance(single_value, str) for single_value in new_value)):
        continue
      raise TypeError(
          f'flagsaver.as_parsed() cannot parse {flag_name}. Expected a single '
          f'string or sequence of strings but {type(new_value)} was provided.')
    super().__init__(**overrides)

  def __enter__(self):
    self._saved_flag_values = save_flag_values(FLAGS)
    try:
      for flag_name, unparsed_value in self._overrides.items():
        # LINT.IfChange(flag_override_parsing)
        FLAGS[flag_name].parse(unparsed_value)
        FLAGS[flag_name].using_default_value = False
        # LINT.ThenChange()

      # Perform the validation on all modified flags. This is something that
      # FLAGS._set_attributes() does for you in _FlagOverrider.
      for flag_name in self._overrides:
        FLAGS._assert_validators(FLAGS[flag_name].validators)
    except KeyError as e:
      # If a flag doesn't exist, an UnrecognizedFlagError is more specific.
      restore_flag_values(self._saved_flag_values, FLAGS)
      # UnrecognizedFlagError expects the offending flag *name* (it builds its
      # own "Unknown command line flag '...'" message). flag_name is the flag
      # whose lookup was in progress when the KeyError surfaced.
      raise flags.UnrecognizedFlagError(flag_name) from e
    except:
      # It may fail because of flag validators or general parsing issues.
      restore_flag_values(self._saved_flag_values, FLAGS)
      raise
def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
  """Snapshots a flag object's attribute dictionary.

  The result is a shallow copy of ``flag.__dict__`` with two adjustments:
  ``_value`` is taken from ``flag.value`` (so C++ flags restore correctly), and
  ``validators`` is copied into a new list so the snapshot does not share the
  flag's list object.

  Args:
    flag: flags.Flag, the flag to snapshot.

  Returns:
    A new dict holding the flag's attributes.
  """
  snapshot = dict(flag.__dict__)
  snapshot.update(_value=flag.value, validators=list(flag.validators))
  return snapshot
abseil-py-2.1.0/absl/testing/parameterized.py 0000664 0000000 0000000 00000066237 14551576331 0021233 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Adds support for parameterized tests to Python's unittest TestCase class.
A parameterized test is a method in a test case that is invoked with different
argument tuples.
A simple example::
class AdditionExample(parameterized.TestCase):
@parameterized.parameters(
(1, 2, 3),
(4, 5, 9),
(1, 1, 3))
def testAddition(self, op1, op2, result):
self.assertEqual(result, op1 + op2)
Each invocation is a separate test case and properly isolated just
like a normal test method, with its own setUp/tearDown cycle. In the
example above, there are three separate testcases, one of which will
fail due to an assertion error (1 + 1 != 3).
Parameters for individual test cases can be tuples (with positional parameters)
or dictionaries (with named parameters)::
class AdditionExample(parameterized.TestCase):
@parameterized.parameters(
{'op1': 1, 'op2': 2, 'result': 3},
{'op1': 4, 'op2': 5, 'result': 9},
)
def testAddition(self, op1, op2, result):
self.assertEqual(result, op1 + op2)
If a parameterized test fails, the error message will show the
original test name and the parameters for that test.
The id method of the test, used internally by the unittest framework, is also
modified to show the arguments (but note that the name reported by `id()`
doesn't match the actual test name, see below). To make sure that test names
stay the same across several invocations, object representations like::
>>> class Foo(object):
... pass
>>> repr(Foo())
'<__main__.Foo object at 0x23d8610>'
are turned into ``<__main__.Foo>``. When selecting a subset of test cases to run
on the command-line, the test cases contain an index suffix for each argument
in the order they were passed to :func:`parameters` (e.g. testAddition0,
testAddition1, etc.) This naming scheme is subject to change; for more reliable
and stable names, especially in test logs, use :func:`named_parameters` instead.
Tests using :func:`named_parameters` are similar to :func:`parameters`, except
only tuples or dicts of args are supported. For tuples, the first parameter arg
has to be a string (or an object that returns an apt name when converted via
``str()``). For dicts, a value for the key ``testcase_name`` must be present and
must be a string (or an object that returns an apt name when converted via
``str()``)::
class NamedExample(parameterized.TestCase):
@parameterized.named_parameters(
('Normal', 'aa', 'aaa', True),
('EmptyPrefix', '', 'abc', True),
('BothEmpty', '', '', True))
def testStartsWith(self, prefix, string, result):
self.assertEqual(result, string.startswith(prefix))
class NamedExample(parameterized.TestCase):
@parameterized.named_parameters(
{'testcase_name': 'Normal',
'result': True, 'string': 'aaa', 'prefix': 'aa'},
{'testcase_name': 'EmptyPrefix',
'result': True, 'string': 'abc', 'prefix': ''},
{'testcase_name': 'BothEmpty',
'result': True, 'string': '', 'prefix': ''})
def testStartsWith(self, prefix, string, result):
self.assertEqual(result, string.startswith(prefix))
Named tests also have the benefit that they can be run individually
from the command line::
$ testmodule.py NamedExample.testStartsWithNormal
.
--------------------------------------------------------------------
Ran 1 test in 0.000s
OK
Parameterized Classes
=====================
If invocation arguments are shared across test methods in a single
TestCase class, instead of decorating all test methods
individually, the class itself can be decorated::
@parameterized.parameters(
(1, 2, 3),
(4, 5, 9))
class ArithmeticTest(parameterized.TestCase):
def testAdd(self, arg1, arg2, result):
self.assertEqual(arg1 + arg2, result)
def testSubtract(self, arg1, arg2, result):
self.assertEqual(result - arg1, arg2)
Inputs from Iterables
=====================
If parameters should be shared across several test cases, or are dynamically
created from other sources, a single non-tuple iterable can be passed into
the decorator. This iterable will be used to obtain the test cases::
class AdditionExample(parameterized.TestCase):
@parameterized.parameters(
(c.op1, c.op2, c.result) for c in testcases
)
def testAddition(self, op1, op2, result):
self.assertEqual(result, op1 + op2)
Single-Argument Test Methods
============================
If a test method takes only one argument, the single arguments must not be
wrapped into a tuple::
class NegativeNumberExample(parameterized.TestCase):
@parameterized.parameters(
-1, -3, -4, -5
)
def testIsNegative(self, arg):
self.assertTrue(IsNegative(arg))
List/tuple as a Single Argument
===============================
If a test method takes a single argument of a list/tuple, it must be wrapped
inside a tuple::
class ZeroSumExample(parameterized.TestCase):
@parameterized.parameters(
([-1, 0, 1], ),
([-2, 0, 2], ),
)
def testSumIsZero(self, arg):
self.assertEqual(0, sum(arg))
Cartesian product of Parameter Values as Parameterized Test Cases
=================================================================
If required to test method over a cartesian product of parameters,
`parameterized.product` may be used to facilitate generation of parameter
test combinations::
class TestModuloExample(parameterized.TestCase):
@parameterized.product(
num=[0, 20, 80],
modulo=[2, 4],
expected=[0]
)
def testModuloResult(self, num, modulo, expected):
self.assertEqual(expected, num % modulo)
This results in 6 test cases being created - one for each combination of the
parameters. It is also possible to supply sequences of keyword argument dicts
as elements of the cartesian product::
@parameterized.product(
(dict(num=5, modulo=3, expected=2),
dict(num=7, modulo=4, expected=3)),
dtype=(int, float)
)
def testModuloResult(self, num, modulo, expected, dtype):
self.assertEqual(expected, dtype(num) % modulo)
This results in 4 test cases being created - for each of the two sets of test
data (supplied as kwarg dicts) and for each of the two data types (supplied as
a named parameter). Multiple keyword argument dicts may be supplied if required.
Async Support
=============
If a test needs to call async functions, it can inherit from both
parameterized.TestCase and another TestCase that supports async calls, such
as `asynctest <https://github.com/Martiusweb/asynctest>`_::
import asynctest
class AsyncExample(parameterized.TestCase, asynctest.TestCase):
@parameterized.parameters(
('a', 1),
('b', 2),
)
async def testSomeAsyncFunction(self, arg, expected):
actual = await someAsyncFunction(arg)
self.assertEqual(actual, expected)
"""
from collections import abc
import functools
import inspect
import itertools
import re
import types
import unittest
import warnings
from absl.testing import absltest
_ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>')
_NAMED = object()
_ARGUMENT_REPR = object()
_NAMED_DICT_KEY = 'testcase_name'
class NoTestsError(Exception):
  """Signals that a parameterized decorator produced zero test cases.

  Usually caused by passing empty parameters or by reusing an exhausted
  generator.
  """
class DuplicateTestNameError(Exception):
  """Signals that a generated test name collides with an existing one."""

  def __init__(self, test_class_name, new_test_name, original_test_name):
    """Builds the error message.

    Args:
      test_class_name: Name of the class the tests are being added to.
      new_test_name: The generated name that already exists.
      original_test_name: The name the colliding test was generated from.
    """
    message = (
        'Duplicate parameterized test name in {}: generated test name {!r} '
        '(generated from {!r}) already exists. Consider using '
        'named_parameters() to give your tests unique names and/or renaming '
        'the conflicting test method.'
    ).format(test_class_name, new_test_name, original_test_name)
    super().__init__(message)
def _clean_repr(obj):
  """Returns ``repr(obj)`` with object memory addresses removed.

  Default reprs such as ``<pkg.Foo object at 0x7f3a>`` become ``<pkg.Foo>``,
  keeping generated test descriptions stable across invocations.
  """
  raw = repr(obj)
  return _ADDR_RE.sub(r'<\1>', raw)
def _non_string_or_bytes_iterable(obj):
return (isinstance(obj, abc.Iterable) and not isinstance(obj, str) and
not isinstance(obj, bytes))
def _format_parameter_list(testcase_params):
  """Renders test-case parameters as a comma-separated string.

  Mappings render as ``name=value`` pairs, other non-string iterables as their
  elements, and anything else (including str and bytes) as a single value.
  Every value passes through _clean_repr so memory addresses are dropped.
  """
  if isinstance(testcase_params, abc.Mapping):
    pieces = ['%s=%s' % (arg_name, _clean_repr(arg_value))
              for arg_name, arg_value in testcase_params.items()]
  elif _non_string_or_bytes_iterable(testcase_params):
    pieces = [_clean_repr(item) for item in testcase_params]
  else:
    pieces = [_clean_repr(testcase_params)]
  return ', '.join(pieces)
def _async_wrapped(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
return await func(*args, **kwargs)
return wrapper
class _ParameterizedTestIter(object):
"""Callable and iterable class for producing new test cases."""
def __init__(self, test_method, testcases, naming_type, original_name=None):
"""Returns concrete test functions for a test and a list of parameters.
The naming_type is used to determine the name of the concrete
functions as reported by the unittest framework. If naming_type is
_FIRST_ARG, the testcases must be tuples, and the first element must
have a string representation that is a valid Python identifier.
Args:
test_method: The decorated test method.
testcases: (list of tuple/dict) A list of parameter tuples/dicts for
individual test invocations.
naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR.
original_name: The original test method name. When decorated on a test
method, None is passed to __init__ and test_method.__name__ is used.
Note test_method.__name__ might be different than the original defined
test method because of the use of other decorators. A more accurate
value is set by TestGeneratorMetaclass.__new__ later.
"""
self._test_method = test_method
self.testcases = testcases
self._naming_type = naming_type
if original_name is None:
original_name = test_method.__name__
self._original_name = original_name
self.__name__ = _ParameterizedTestIter.__name__
def __call__(self, *args, **kwargs):
raise RuntimeError('You appear to be running a parameterized test case '
'without having inherited from parameterized.'
'TestCase. This is bad because none of '
'your test cases are actually being run. You may also '
'be using another decorator before the parameterized '
'one, in which case you should reverse the order.')
def __iter__(self):
test_method = self._test_method
naming_type = self._naming_type
def make_bound_param_test(testcase_params):
@functools.wraps(test_method)
def bound_param_test(self):
if isinstance(testcase_params, abc.Mapping):
return test_method(self, **testcase_params)
elif _non_string_or_bytes_iterable(testcase_params):
return test_method(self, *testcase_params)
else:
return test_method(self, testcase_params)
if naming_type is _NAMED:
# Signal the metaclass that the name of the test function is unique
# and descriptive.
bound_param_test.__x_use_name__ = True
testcase_name = None
if isinstance(testcase_params, abc.Mapping):
if _NAMED_DICT_KEY not in testcase_params:
raise RuntimeError(
'Dict for named tests must contain key "%s"' % _NAMED_DICT_KEY)
# Create a new dict to avoid modifying the supplied testcase_params.
testcase_name = testcase_params[_NAMED_DICT_KEY]
testcase_params = {
k: v for k, v in testcase_params.items() if k != _NAMED_DICT_KEY
}
elif _non_string_or_bytes_iterable(testcase_params):
if not isinstance(testcase_params[0], str):
raise RuntimeError(
'The first element of named test parameters is the test name '
'suffix and must be a string')
testcase_name = testcase_params[0]
testcase_params = testcase_params[1:]
else:
raise RuntimeError(
'Named tests must be passed a dict or non-string iterable.')
test_method_name = self._original_name
# Support PEP-8 underscore style for test naming if used.
if (test_method_name.startswith('test_')
and testcase_name
and not testcase_name.startswith('_')):
test_method_name += '_'
bound_param_test.__name__ = test_method_name + str(testcase_name)
elif naming_type is _ARGUMENT_REPR:
# If it's a generator, convert it to a tuple and treat them as
# parameters.
if isinstance(testcase_params, types.GeneratorType):
testcase_params = tuple(testcase_params)
# The metaclass creates a unique, but non-descriptive method name for
# _ARGUMENT_REPR tests using an indexed suffix.
# To keep test names descriptive, only the original method name is used.
# To make sure test names are unique, we add a unique descriptive suffix
# __x_params_repr__ for every test.
params_repr = '(%s)' % (_format_parameter_list(testcase_params),)
bound_param_test.__x_params_repr__ = params_repr
else:
raise RuntimeError('%s is not a valid naming type.' % (naming_type,))
bound_param_test.__doc__ = '%s(%s)' % (
bound_param_test.__name__, _format_parameter_list(testcase_params))
if test_method.__doc__:
bound_param_test.__doc__ += '\n%s' % (test_method.__doc__,)
if inspect.iscoroutinefunction(test_method):
return _async_wrapped(bound_param_test)
return bound_param_test
return (make_bound_param_test(c) for c in self.testcases)
def _modify_class(class_object, testcases, naming_type):
assert not getattr(class_object, '_test_params_reprs', None), (
'Cannot add parameters to %s. Either it already has parameterized '
'methods, or its super class is also a parameterized class.' % (
class_object,))
# NOTE: _test_params_repr is private to parameterized.TestCase and it's
# metaclass; do not use it outside of those classes.
class_object._test_params_reprs = test_params_reprs = {}
for name, obj in class_object.__dict__.copy().items():
if (name.startswith(unittest.TestLoader.testMethodPrefix)
and isinstance(obj, types.FunctionType)):
delattr(class_object, name)
methods = {}
_update_class_dict_for_param_test_case(
class_object.__name__, methods, test_params_reprs, name,
_ParameterizedTestIter(obj, testcases, naming_type, name))
for meth_name, meth in methods.items():
setattr(class_object, meth_name, meth)
def _parameter_decorator(naming_type, testcases):
"""Implementation of the parameterization decorators.
Args:
naming_type: The naming type.
testcases: Testcase parameters.
Raises:
NoTestsError: Raised when the decorator generates no tests.
Returns:
A function for modifying the decorated object.
"""
def _apply(obj):
if isinstance(obj, type):
_modify_class(obj, testcases, naming_type)
return obj
else:
return _ParameterizedTestIter(obj, testcases, naming_type)
if (len(testcases) == 1 and
not isinstance(testcases[0], tuple) and
not isinstance(testcases[0], abc.Mapping)):
# Support using a single non-tuple parameter as a list of test cases.
# Note that the single non-tuple parameter can't be Mapping either, which
# means a single dict parameter case.
assert _non_string_or_bytes_iterable(testcases[0]), (
'Single parameter argument must be a non-string non-Mapping iterable')
testcases = testcases[0]
if not isinstance(testcases, abc.Sequence):
testcases = list(testcases)
if not testcases:
raise NoTestsError(
'parameterized test decorators did not generate any tests. '
'Make sure you specify non-empty parameters, '
'and do not reuse generators more than once.')
return _apply
def parameters(*testcases):
  """Decorator that generates one test per entry in ``testcases``.

  Generated tests are named by appending an index to the method name, and a
  repr of each test's parameters is attached for reporting. See the module
  docstring for a usage example.

  Args:
    *testcases: Parameters for the decorated method, either a single
      iterable, or a list of tuples/dicts/objects (for tests with only one
      argument).

  Raises:
    NoTestsError: Raised when the decorator generates no tests.

  Returns:
    A test generator to be handled by TestGeneratorMetaclass.
  """
  return _parameter_decorator(naming_type=_ARGUMENT_REPR, testcases=testcases)
def named_parameters(*testcases):
  """Decorator that generates one descriptively named test per testcase.

  For a tuple testcase, the first element must be a string and is appended to
  the test method name. For a dict testcase, the value under "testcase_name"
  is converted with str() and appended. See the module docstring for a usage
  example.

  Args:
    *testcases: Parameters for the decorated method, either a single iterable,
      or a list of tuples or dicts.

  Raises:
    NoTestsError: Raised when the decorator generates no tests.

  Returns:
    A test generator to be handled by TestGeneratorMetaclass.
  """
  return _parameter_decorator(naming_type=_NAMED, testcases=testcases)
def product(*kwargs_seqs, **testgrid):
  """A decorator for running tests over a cartesian product of parameter values.

  See the module docstring for a usage example. The test will be run for every
  possible combination of the parameters.

  Args:
    *kwargs_seqs: Each positional parameter is a sequence of keyword arg dicts;
      every test case generated will include exactly one kwargs dict from each
      positional parameter; these will then be merged to form an overall list
      of arguments for the test case.
    **testgrid: A mapping of parameter names and their possible values. Possible
      values should be given as either a list or a tuple.

  Raises:
    NoTestsError: Raised when the decorator generates no tests.
    AssertionError: When assertions are enabled and the arguments are
      malformed (non-list/tuple values, non-dict entries, inconsistent keys,
      or argument names supplied more than once).

  Returns:
    A test generator to be handled by TestGeneratorMetaclass.
  """

  for name, values in testgrid.items():
    assert isinstance(values, (list, tuple)), (
        'Values of {} must be given as list or tuple, found {}'.format(
            name, type(values)))

  prior_arg_names = set()
  for kwargs_seq in kwargs_seqs:
    assert ((isinstance(kwargs_seq, (list, tuple))) and
            all(isinstance(kwargs, dict) for kwargs in kwargs_seq)), (
                'Positional parameters must be a sequence of keyword arg '
                'dicts, found {}'
                .format(kwargs_seq))
    if kwargs_seq:
      arg_names = set(kwargs_seq[0])
      assert all(set(kwargs) == arg_names for kwargs in kwargs_seq), (
          'Keyword argument dicts within a single parameter must all have the '
          'same keys, found {}'.format(kwargs_seq))
      assert not (arg_names & prior_arg_names), (
          'Keyword argument dict sequences must all have distinct argument '
          'names, found duplicate(s) {}'
          .format(sorted(arg_names & prior_arg_names)))
      prior_arg_names |= arg_names

  assert not (prior_arg_names & set(testgrid)), (
      'Arguments supplied in kwargs dicts in positional parameters must not '
      'overlap with arguments supplied as named parameters; found duplicate '
      'argument(s) {}'.format(sorted(prior_arg_names & set(testgrid))))

  # Convert testgrid into a sequence of sequences of kwargs dicts and combine
  # with the positional parameters.
  # So foo=[1,2], bar=[3,4] --> [[{foo: 1}, {foo: 2}], [{bar: 3}, {bar: 4}]]
  grid_seqs = tuple(tuple({k: v} for v in vs) for k, vs in testgrid.items())
  all_seqs = tuple(kwargs_seqs) + grid_seqs

  # Create all possible combinations of parameters as a cartesian product
  # of parameter values; each combination's dicts are merged into one.
  testcases = [
      dict(itertools.chain.from_iterable(case.items()
                                         for case in cases))
      for cases in itertools.product(*all_seqs)
  ]
  return _parameter_decorator(_ARGUMENT_REPR, testcases)
class TestGeneratorMetaclass(type):
  """Metaclass for adding tests generated by parameterized decorators.

  At class creation, every attribute whose name starts with the unittest
  test-method prefix and whose value is a non-string iterable (normally a
  _ParameterizedTestIter) is removed and replaced by the test methods it
  yields. The params repr of each generated method is recorded in the class's
  ``_test_params_reprs`` dict, which also inherits entries from
  parameterized.TestCase base classes.
  """

  def __new__(cls, class_name, bases, dct):
    # NOTE: _test_params_reprs is private to parameterized.TestCase and its
    # metaclass; do not use it outside of those classes.
    test_params_reprs = dct.setdefault('_test_params_reprs', {})
    for name, obj in dct.copy().items():
      if (name.startswith(unittest.TestLoader.testMethodPrefix) and
          _non_string_or_bytes_iterable(obj)):
        # NOTE: `obj` might not be a _ParameterizedTestIter in two cases:
        # 1. A class-level iterable named test* that isn't a test, such as a
        #    list of something. Such attributes are popped from the class
        #    dict; any items they yield that were not produced by the
        #    parameterized decorators are rejected by
        #    _update_class_dict_for_param_test_case.
        #
        # 2. A decorator is applied on top of the parameterized test, e.g.
        #    @morestuff
        #    @parameterized.parameters(...)
        #    def test_foo(...): ...
        #
        #    This is OK so long as the underlying parameterized function state
        #    is forwarded (e.g. using functools.wraps()) **without** explicitly
        #    accessing the internal attributes.
        if isinstance(obj, _ParameterizedTestIter):
          # Update the original test method name so it's more accurate.
          # The mismatch might happen when another decorator is used inside
          # the parameterized decorators, and the inner decorator doesn't
          # preserve its __name__.
          obj._original_name = name
        iterator = iter(obj)
        dct.pop(name)
        _update_class_dict_for_param_test_case(
            class_name, dct, test_params_reprs, name, iterator)
    # If the base class is a subclass of parameterized.TestCase, inherit its
    # _test_params_reprs too.
    for base in bases:
      # Check if the base has _test_params_reprs first, then check if it's a
      # subclass of parameterized.TestCase. Otherwise when this is called for
      # the parameterized.TestCase definition itself, this raises because
      # itself is not defined yet. This works as long as absltest.TestCase does
      # not define _test_params_reprs.
      base_test_params_reprs = getattr(base, '_test_params_reprs', None)
      if base_test_params_reprs and issubclass(base, TestCase):
        for test_method, test_method_id in base_test_params_reprs.items():
          # test_method may exist in both the base and this class; this
          # class's entry overrides the base's, so only inherit the base's
          # entry when this class does not already have one.
          test_params_reprs.setdefault(test_method, test_method_id)

    return type.__new__(cls, class_name, bases, dct)
def _update_class_dict_for_param_test_case(
test_class_name, dct, test_params_reprs, name, iterator):
"""Adds individual test cases to a dictionary.
Args:
test_class_name: The name of the class tests are added to.
dct: The target dictionary.
test_params_reprs: The dictionary for mapping names to test IDs.
name: The original name of the test case.
iterator: The iterator generating the individual test cases.
Raises:
DuplicateTestNameError: Raised when a test name occurs multiple times.
RuntimeError: If non-parameterized functions are generated.
"""
for idx, func in enumerate(iterator):
assert callable(func), 'Test generators must yield callables, got %r' % (
func,)
if not (getattr(func, '__x_use_name__', None) or
getattr(func, '__x_params_repr__', None)):
raise RuntimeError(
'{}.{} generated a test function without using the parameterized '
'decorators. Only tests generated using the decorators are '
'supported.'.format(test_class_name, name))
if getattr(func, '__x_use_name__', False):
original_name = func.__name__
new_name = original_name
else:
original_name = name
new_name = '%s%d' % (original_name, idx)
if new_name in dct:
raise DuplicateTestNameError(test_class_name, new_name, original_name)
dct[new_name] = func
test_params_reprs[new_name] = getattr(func, '__x_params_repr__', '')
class TestCase(absltest.TestCase, metaclass=TestGeneratorMetaclass):
  """Base class for test cases that use the parameterized decorators.

  ``str()`` and ``id()`` of a test include the params repr recorded for the
  running test method, when one exists.
  """

  # visibility: private; do not call outside this class.
  def _get_params_repr(self):
    reprs_by_method = self._test_params_reprs
    return reprs_by_method.get(self._testMethodName, '')

  def __str__(self):
    params_repr = self._get_params_repr()
    shown_params = (' ' + params_repr) if params_repr else params_repr
    return '{}{} ({})'.format(
        self._testMethodName, shown_params,
        unittest.util.strclass(self.__class__))

  def id(self):
    """Returns the descriptive ID of the test.

    The unittest framework uses this to name the test in reports. When the
    test has a params repr, it is appended after a single space so that the
    value written to test.xml is more informative than just "test_foo0" while
    the real test id stays easy to copy/paste.

    Returns:
      The test id.
    """
    test_id = super().id()
    params_repr = self._get_params_repr()
    if not params_repr:
      return test_id
    return '{} {}'.format(test_id, params_repr)
# This function is kept CamelCase because it's used as a class's base class.
def CoopTestCase(other_base_class):  # pylint: disable=invalid-name
  """Returns a new base class with a cooperative metaclass base.

  This enables the TestCase to be used in combination
  with other base classes that have custom metaclasses, such as
  ``mox.MoxTestBase``.

  Only works with metaclasses that do not override ``type.__new__``.

  Example::

    from absl.testing import parameterized

    class ExampleTest(parameterized.CoopTestCase(OtherTestCase)):
      ...

  Args:
    other_base_class: (class) A test case base class.

  Returns:
    A new class object.
  """
  # If the other base class has a metaclass of 'type' then trying to combine
  # the metaclasses will result in an MRO error. So simply combine them and
  # return.
  if type(other_base_class) is type:
    warnings.warn(
        'CoopTestCase is only necessary when combining with a class that uses'
        ' a metaclass. Use multiple inheritance like this instead: class'
        f' ExampleTest(parameterized.TestCase, {other_base_class.__name__}):',
        stacklevel=2,
    )

    class CoopTestCaseBase(other_base_class, TestCase):
      pass

    return CoopTestCaseBase
  else:

    class CoopMetaclass(type(other_base_class), TestGeneratorMetaclass):  # pylint: disable=unused-variable
      pass

    class CoopTestCaseBase(other_base_class, TestCase, metaclass=CoopMetaclass):
      pass

    return CoopTestCaseBase
abseil-py-2.1.0/absl/testing/tests/ 0000775 0000000 0000000 00000000000 14551576331 0017151 5 ustar 00root root 0000000 0000000 abseil-py-2.1.0/absl/testing/tests/__init__.py 0000664 0000000 0000000 00000001110 14551576331 0021253 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
abseil-py-2.1.0/absl/testing/tests/absltest_env.py 0000664 0000000 0000000 00000001636 14551576331 0022222 0 ustar 00root root 0000000 0000000 """Helper library to get environment variables for absltest helper binaries."""
import os
_INHERITED_ENV_KEYS = frozenset({
# This is needed to correctly use the Python interpreter determined by
# bazel.
'PATH',
# This is used by the random module on Windows to locate crypto
# libraries.
'SYSTEMROOT',
})
def inherited_env():
  """Builds the environment for absltest helper subprocesses.

  Only variables listed in _INHERITED_ENV_KEYS that are set in the current
  process are copied. An explicit allow-list is used instead of inheriting the
  whole parent environment because absltest itself interprets variables set by
  bazel (e.g. XML_OUTPUT_FILE, TESTBRIDGE_TEST_ONLY); when testing absltest's
  own behavior those must not reach the helper subprocess.

  Returns:
    A new dict mapping each inherited variable name to its current value.
  """
  return {
      key: os.environ[key] for key in _INHERITED_ENV_KEYS if key in os.environ
  }
abseil-py-2.1.0/absl/testing/tests/absltest_fail_fast_test.py 0000664 0000000 0000000 00000006633 14551576331 0024423 0 ustar 00root root 0000000 0000000 # Copyright 2020 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for test fail fast protocol."""
import subprocess
from absl import logging
from absl.testing import _bazelize_command
from absl.testing import absltest
from absl.testing import parameterized
from absl.testing.tests import absltest_env
@parameterized.named_parameters(
    ('use_app_run', True),
    ('no_argv', False),
)
class TestFailFastTest(parameterized.TestCase):
  """Integration tests: runs a test binary with fail fast.

  The helper binary runs in a subprocess, and fail-fast is selected through
  the TESTBRIDGE_TEST_RUNNER_FAIL_FAST environment variable.
  """

  # Marker lines written by the helper's ClassA tests, testA through testE.
  _MARKERS = (
      'class A test A',
      'class A test B',
      'class A test C',
      'class A test D',
      'class A test E',
  )

  def setUp(self):
    super().setUp()
    self._test_name = 'absl/testing/tests/absltest_fail_fast_test_helper'

  def _run_fail_fast(self, fail_fast, use_app_run):
    """Runs the py_test binary in a subprocess.

    Args:
      fail_fast: string or None; value for TESTBRIDGE_TEST_RUNNER_FAIL_FAST,
        or None to leave that variable unset.
      use_app_run: bool, whether the test helper should call
        `absltest.main(argv=)` inside `app.run`.

    Returns:
      (stdout, exit_code) tuple of (string, int); stderr is merged into stdout.
    """
    env = absltest_env.inherited_env()
    if fail_fast is not None:
      env['TESTBRIDGE_TEST_RUNNER_FAIL_FAST'] = fail_fast
    env['USE_APP_RUN'] = '1' if use_app_run else '0'

    completed = subprocess.run(
        args=[_bazelize_command.get_executable_path(self._test_name)],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True)
    logging.info('output: %s', completed.stdout)
    return completed.stdout, completed.returncode

  def _assert_markers(self, out, present, absent=()):
    """Asserts which helper-test markers do and do not appear in ``out``."""
    for marker in present:
      self.assertIn(marker, out)
    for marker in absent:
      self.assertNotIn(marker, out)

  def test_no_fail_fast(self, use_app_run):
    out, exit_code = self._run_fail_fast(None, use_app_run)
    self.assertEqual(1, exit_code)
    self._assert_markers(out, present=self._MARKERS)

  def test_empty_fail_fast(self, use_app_run):
    out, exit_code = self._run_fail_fast('', use_app_run)
    self.assertEqual(1, exit_code)
    self._assert_markers(out, present=self._MARKERS)

  def test_fail_fast_1(self, use_app_run):
    out, exit_code = self._run_fail_fast('1', use_app_run)
    self.assertEqual(1, exit_code)
    # Nothing after the forced failure in testC is expected to run.
    self._assert_markers(
        out, present=self._MARKERS[:3], absent=self._MARKERS[3:])

  def test_fail_fast_0(self, use_app_run):
    out, exit_code = self._run_fail_fast('0', use_app_run)
    self.assertEqual(1, exit_code)
    self._assert_markers(out, present=self._MARKERS)
# Run the fail-fast integration tests when executed as a script.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_fail_fast_test_helper.py 0000664 0000000 0000000 00000002444 14551576331 0025756 0 ustar 00root root 0000000 0000000 # Copyright 2020 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A helper test program for absltest_fail_fast_test."""
import os
import sys
from absl import app
from absl.testing import absltest
class ClassA(absltest.TestCase):
  """Helper test case A for absltest_fail_fast_test.

  Each test writes a marker line to stderr. testC fails on purpose, so the
  caller can check whether testD and testE still run when fail fast is on.
  """

  def testA(self):
    sys.stderr.write('\nclass A test A\n')

  def testB(self):
    sys.stderr.write('\nclass A test B\n')

  def testC(self):
    sys.stderr.write('\nclass A test C\n')
    # Deliberate failure: the point at which fail fast should stop the run.
    self.fail('Force failure')

  def testD(self):
    sys.stderr.write('\nclass A test D\n')

  def testE(self):
    sys.stderr.write('\nclass A test E\n')
def main(argv):
  """Entry point for `app.run`: forwards the parsed argv to absltest.main."""
  absltest.main(argv=argv)
if __name__ == '__main__':
  # USE_APP_RUN is set by absltest_fail_fast_test to choose how the tests are
  # launched. Treat a missing variable as '0' (plain absltest.main()) so that
  # running the helper directly does not crash with KeyError.
  if os.environ.get('USE_APP_RUN', '0') == '1':
    app.run(main)
  else:
    absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_filtering_test.py 0000664 0000000 0000000 00000016671 14551576331 0024461 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for test filtering protocol."""
import subprocess
import sys
from absl import logging
from absl.testing import _bazelize_command
from absl.testing import absltest
from absl.testing import parameterized
from absl.testing.tests import absltest_env
@parameterized.named_parameters(
    ('as_env_variable_use_app_run', True, True),
    ('as_env_variable_no_argv', True, False),
    ('as_commandline_args_use_app_run', False, True),
    ('as_commandline_args_no_argv', False, False),
)
class TestFilteringTest(parameterized.TestCase):
  """Integration tests: Runs a test binary with filtering.

  This is done by either setting the filtering environment variable, or passing
  the filters as command line arguments. The class-level parameters supply
  `use_env_variable` and `use_app_run` to every test method.
  """

  def setUp(self):
    super().setUp()
    self._test_name = 'absl/testing/tests/absltest_filtering_test_helper'

  def _run_filtered(self, test_filter, use_env_variable, use_app_run):
    """Runs the py_test binary in a subprocess.

    Args:
      test_filter: string, the filter argument to use (space-separated
        patterns), or None to apply no filter at all.
      use_env_variable: bool, pass the test filter as environment variable if
        True, otherwise pass as command line arguments.
      use_app_run: bool, whether the test helper should call
        `absltest.main(argv=)` inside `app.run`.

    Returns:
      (stdout, exit_code) tuple of (string, int).
    """
    env = absltest_env.inherited_env()
    env['USE_APP_RUN'] = '1' if use_app_run else '0'
    additional_args = []
    if test_filter is not None:
      if use_env_variable:
        env['TESTBRIDGE_TEST_ONLY'] = test_filter
      elif test_filter:
        if sys.version_info[:2] >= (3, 7):
          # The -k flags are passed as positional arguments to absl.flags.
          additional_args.append('--')
          additional_args.extend(['-k=' + f for f in test_filter.split(' ')])
        else:
          additional_args.extend(test_filter.split(' '))

    proc = subprocess.Popen(
        args=([_bazelize_command.get_executable_path(self._test_name)] +
              additional_args),
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True)
    stdout = proc.communicate()[0]

    logging.info('output: %s', stdout)
    return stdout, proc.wait()

  def test_no_filter(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered(None, use_env_variable, use_app_run)
    self.assertEqual(1, exit_code)
    self.assertIn('class B test E', out)

  def test_empty_filter(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered('', use_env_variable, use_app_run)
    self.assertEqual(1, exit_code)
    self.assertIn('class B test E', out)

  def test_class_filter(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered('ClassA', use_env_variable, use_app_run)
    self.assertEqual(0, exit_code)
    self.assertNotIn('class B', out)

  def test_method_filter(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered('ClassB.testA', use_env_variable,
                                        use_app_run)
    self.assertEqual(0, exit_code)
    self.assertNotIn('class A', out)
    self.assertNotIn('class B test B', out)

    out, exit_code = self._run_filtered('ClassB.testE', use_env_variable,
                                        use_app_run)
    self.assertEqual(1, exit_code)
    self.assertNotIn('class A', out)

  def test_multiple_class_and_method_filter(self, use_env_variable,
                                            use_app_run):
    out, exit_code = self._run_filtered(
        'ClassA.testA ClassA.testB ClassB.testC', use_env_variable, use_app_run)
    self.assertEqual(0, exit_code)
    self.assertIn('class A test A', out)
    self.assertIn('class A test B', out)
    self.assertNotIn('class A test C', out)
    self.assertIn('class B test C', out)
    self.assertNotIn('class B test A', out)

  @absltest.skipIf(
      sys.version_info[:2] < (3, 7),
      'Only Python 3.7+ does glob and substring matching.')
  def test_substring(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered(
        'testA', use_env_variable, use_app_run)
    self.assertEqual(0, exit_code)
    self.assertIn('Ran 2 tests', out)
    self.assertIn('ClassA.testA', out)
    self.assertIn('ClassB.testA', out)

  @absltest.skipIf(
      sys.version_info[:2] < (3, 7),
      'Only Python 3.7+ does glob and substring matching.')
  def test_glob_pattern(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered(
        '__main__.Class*.testA', use_env_variable, use_app_run)
    self.assertEqual(0, exit_code)
    self.assertIn('Ran 2 tests', out)
    self.assertIn('ClassA.testA', out)
    self.assertIn('ClassB.testA', out)

  @absltest.skipIf(
      sys.version_info[:2] >= (3, 7),
      "Python 3.7+ uses unittest's -k flag and doesn't fail if no tests match.")
  def test_not_found_filters_py36(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered('NotExistedClass.not_existed_method',
                                        use_env_variable, use_app_run)
    self.assertEqual(1, exit_code)
    self.assertIn("has no attribute 'NotExistedClass'", out)

  @absltest.skipIf(
      sys.version_info[:2] < (3, 7),
      'Python 3.6 passes the filter as positional arguments and fails if no '
      'tests match.'
  )
  def test_not_found_filters_py37(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered('NotExistedClass.not_existed_method',
                                        use_env_variable, use_app_run)
    if not use_env_variable and sys.version_info[:2] >= (3, 12):
      # When test filter is requested with the unittest `-k` flag, absltest
      # respect unittest to fail when no tests run on Python 3.12+.
      self.assertEqual(5, exit_code)
    else:
      self.assertEqual(0, exit_code)
    self.assertIn('Ran 0 tests', out)

  @absltest.skipIf(
      sys.version_info[:2] < (3, 7),
      'Python 3.6 passes the filter as positional arguments and matches by name'
  )
  def test_parameterized_unnamed(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered('ParameterizedTest.test_unnamed',
                                        use_env_variable, use_app_run)
    self.assertEqual(0, exit_code)
    self.assertIn('Ran 2 tests', out)
    self.assertIn('parameterized unnamed 1', out)
    self.assertIn('parameterized unnamed 2', out)

  @absltest.skipIf(
      sys.version_info[:2] < (3, 7),
      'Python 3.6 passes the filter as positional arguments and matches by name'
  )
  def test_parameterized_named(self, use_env_variable, use_app_run):
    out, exit_code = self._run_filtered('ParameterizedTest.test_named',
                                        use_env_variable, use_app_run)
    self.assertEqual(0, exit_code)
    self.assertIn('Ran 2 tests', out)
    self.assertIn('parameterized named 1', out)
    self.assertIn('parameterized named 2', out)
# Run the filtering integration tests when executed as a script.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_filtering_test_helper.py 0000664 0000000 0000000 00000004003 14551576331 0026002 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A helper test program for absltest_filtering_test."""
import os
import sys
from absl import app
from absl.testing import absltest
from absl.testing import parameterized
class ClassA(absltest.TestCase):
  """Helper test case A for absltest_filtering_test.

  Every test passes and writes a 'class A test X' marker line to stderr.
  """

  def testA(self):
    sys.stderr.write('\nclass A test A\n')

  def testB(self):
    sys.stderr.write('\nclass A test B\n')

  def testC(self):
    sys.stderr.write('\nclass A test C\n')
class ClassB(absltest.TestCase):
  """Helper test case B for absltest_filtering_test.

  Each test writes a 'class B test X' marker to stderr; testE always fails, so
  any run that includes it exits non-zero.
  """

  def testA(self):
    sys.stderr.write('\nclass B test A\n')

  def testB(self):
    sys.stderr.write('\nclass B test B\n')

  def testC(self):
    sys.stderr.write('\nclass B test C\n')

  def testD(self):
    sys.stderr.write('\nclass B test D\n')

  def testE(self):
    sys.stderr.write('\nclass B test E\n')
    # Deliberate failure used by the caller to detect that testE was run.
    self.fail('Force failure')
class ParameterizedTest(parameterized.TestCase):
  """Helper parameterized test case for absltest_filtering_test."""

  # Unnamed parameterization: one generated test per value (1 and 2).
  @parameterized.parameters([1, 2])
  def test_unnamed(self, value):
    # Note: no trailing newline, unlike the ClassA/ClassB markers.
    sys.stderr.write('\nparameterized unnamed %s' % value)

  # Named parameterization with cases 'test1' (value 1) and 'test2' (value 2).
  @parameterized.named_parameters(
      ('test1', 1),
      ('test2', 2),
  )
  def test_named(self, value):
    sys.stderr.write('\nparameterized named %s' % value)
def main(argv):
  """Entry point for `app.run`: forwards the parsed argv to absltest.main."""
  absltest.main(argv=argv)
if __name__ == '__main__':
  # USE_APP_RUN is set by absltest_filtering_test to choose how the tests are
  # launched. Treat a missing variable as '0' (plain absltest.main()) so that
  # running the helper directly does not crash with KeyError.
  if os.environ.get('USE_APP_RUN', '0') == '1':
    app.run(main)
  else:
    absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_randomization_test.py 0000664 0000000 0000000 00000011472 14551576331 0025346 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for test randomization."""
import random
import subprocess
from absl import flags
from absl.testing import _bazelize_command
from absl.testing import absltest
from absl.testing import parameterized
from absl.testing.tests import absltest_env
# Module-level alias for absl's global flag values; not referenced elsewhere in
# this module.
FLAGS = flags.FLAGS
class TestOrderRandomizationTest(parameterized.TestCase):
  """Integration tests: Runs a py_test binary with randomization.

  This is done by setting flags and environment variables.
  """

  def setUp(self):
    super().setUp()
    self._test_name = 'absl/testing/tests/absltest_randomization_testcase'

  def _run_test(self, extra_argv, extra_env):
    """Runs the py_test binary in a subprocess, with the given args or env.

    Args:
      extra_argv: list of extra args to pass to the test, or None for none.
      extra_env: dict of extra env vars to set when running the test, or None.

    Returns:
      (stdout, test_cases, exit_code) tuple of (str, list of strs, int), where
      test_cases are the output lines starting with 'class ', in the order
      they were printed.
    """
    env = absltest_env.inherited_env()
    # If *this* test is being run with this flag, we don't want to
    # automatically set it for all tests we run.
    env.pop('TEST_RANDOMIZE_ORDERING_SEED', '')
    if extra_env is not None:
      env.update(extra_env)

    command = [_bazelize_command.get_executable_path(self._test_name)]
    # Tolerate None for symmetry with `extra_env`.
    command.extend(extra_argv or ())

    proc = subprocess.Popen(
        args=command,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True)
    stdout, _ = proc.communicate()

    test_lines = [l for l in stdout.splitlines() if l.startswith('class ')]
    return stdout, test_lines, proc.wait()

  def test_no_args(self):
    output, tests, exit_code = self._run_test([], None)
    self.assertEqual(0, exit_code, msg='command output: ' + output)
    self.assertNotIn('Randomizing test order with seed:', output)
    cases = ['class A test ' + t for t in ('A', 'B', 'C')]
    self.assertEqual(cases, tests)

  @parameterized.parameters(
      {
          'argv': ['--test_randomize_ordering_seed=random'],
          'env': None,
      },
      {
          'argv': [],
          'env': {
              'TEST_RANDOMIZE_ORDERING_SEED': 'random',
          },
      },)
  def test_simple_randomization(self, argv, env):
    output, tests, exit_code = self._run_test(argv, env)
    self.assertEqual(0, exit_code, msg='command output: ' + output)
    self.assertIn('Randomizing test order with seed: ', output)
    cases = ['class A test ' + t for t in ('A', 'B', 'C')]
    # This may come back in any order; we just know it'll be the same
    # set of elements.
    self.assertSameElements(cases, tests)

  @parameterized.parameters(
      {
          'argv': ['--test_randomize_ordering_seed=1'],
          'env': None,
      },
      {
          'argv': [],
          'env': {
              'TEST_RANDOMIZE_ORDERING_SEED': '1'
          },
      },
      {
          'argv': [],
          'env': {
              'LATE_SET_TEST_RANDOMIZE_ORDERING_SEED': '1'
          },
      },
  )
  def test_fixed_seed(self, argv, env):
    output, tests, exit_code = self._run_test(argv, env)
    self.assertEqual(0, exit_code, msg='command output: ' + output)
    self.assertIn('Randomizing test order with seed: 1', output)
    # Compute the expected order with the same seeded shuffle instead of
    # hard-coding a particular permutation.
    shuffled_cases = ['A', 'B', 'C']
    random.Random(1).shuffle(shuffled_cases)
    cases = ['class A test ' + t for t in shuffled_cases]
    # We know what order this will come back for the random seed we've
    # specified.
    self.assertEqual(cases, tests)

  @parameterized.parameters(
      {
          'argv': ['--test_randomize_ordering_seed=0'],
          'env': {
              'TEST_RANDOMIZE_ORDERING_SEED': 'random'
          },
      },
      {
          'argv': [],
          'env': {
              'TEST_RANDOMIZE_ORDERING_SEED': '0'
          },
      },)
  def test_disabling_randomization(self, argv, env):
    output, tests, exit_code = self._run_test(argv, env)
    self.assertEqual(0, exit_code, msg='command output: ' + output)
    self.assertNotIn('Randomizing test order with seed:', output)
    cases = ['class A test ' + t for t in ('A', 'B', 'C')]
    self.assertEqual(cases, tests)
# Run the randomization integration tests when executed as a script.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_randomization_testcase.py 0000664 0000000 0000000 00000002352 14551576331 0026177 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stub tests, only for use in absltest_randomization_test.py."""
import os
import sys
from absl.testing import absltest
# This stanza exercises setting $TEST_RANDOMIZE_ORDERING_SEED *after* importing
# the absltest library: when LATE_SET_TEST_RANDOMIZE_ORDERING_SEED is non-empty,
# its value is copied into TEST_RANDOMIZE_ORDERING_SEED at module import time.
if os.environ.get('LATE_SET_TEST_RANDOMIZE_ORDERING_SEED', ''):
  os.environ['TEST_RANDOMIZE_ORDERING_SEED'] = os.environ[
      'LATE_SET_TEST_RANDOMIZE_ORDERING_SEED']
class ClassA(absltest.TestCase):
  """Stub test case whose tests write 'class A test X' markers to stderr.

  absltest_randomization_test reads these marker lines to observe run order.
  """

  def test_a(self):
    sys.stderr.write('\nclass A test A\n')

  def test_b(self):
    sys.stderr.write('\nclass A test B\n')

  def test_c(self):
    sys.stderr.write('\nclass A test C\n')
# Run the stub tests when executed as a script.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_sharding_test.py 0000664 0000000 0000000 00000014770 14551576331 0024273 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for test sharding protocol."""
import os
import subprocess
import sys
from absl.testing import _bazelize_command
from absl.testing import absltest
from absl.testing import parameterized
from absl.testing.tests import absltest_env
# Hard-coded, based on absltest_sharding_test_helper.py: ClassA defines 3 test
# methods and ClassB defines 5. Keep in sync if the helper changes.
NUM_TEST_METHODS = 8
class TestShardingTest(parameterized.TestCase):
  """Integration tests: Runs a test binary with sharding.

  This is done by setting the sharding environment variables.
  """

  def setUp(self):
    super().setUp()
    self._shard_file = None

  def tearDown(self):
    super().tearDown()
    if self._shard_file is not None and os.path.exists(self._shard_file):
      os.unlink(self._shard_file)

  def _run_sharded(
      self,
      total_shards,
      shard_index,
      shard_file=None,
      additional_env=None,
      helper_name='absltest_sharding_test_helper',
  ):
    """Runs the py_test binary in a subprocess.

    Args:
      total_shards: int, the total number of shards.
      shard_index: int, the shard index.
      shard_file: string, if not 'None', the path to the shard file. This method
        asserts it is properly created.
      additional_env: Additional environment variables to be set for the py_test
        binary.
      helper_name: The name of the helper binary.

    Returns:
      (stdout, exit_code) tuple of (string, int).
    """
    env = absltest_env.inherited_env()
    if additional_env:
      env.update(additional_env)
    env.update({
        'TEST_TOTAL_SHARDS': str(total_shards),
        'TEST_SHARD_INDEX': str(shard_index)
    })
    if shard_file:
      # Remembered so tearDown can remove it.
      self._shard_file = shard_file
      env['TEST_SHARD_STATUS_FILE'] = shard_file
      if os.path.exists(shard_file):
        os.unlink(shard_file)

    helper = 'absl/testing/tests/' + helper_name
    proc = subprocess.Popen(
        args=[_bazelize_command.get_executable_path(helper)],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    )
    stdout = proc.communicate()[0]

    if shard_file:
      self.assertTrue(os.path.exists(shard_file))

    return (stdout, proc.wait())

  def _assert_sharding_correctness(self, total_shards):
    """Assert the primary correctness and performance of sharding.

    1. Completeness (all methods are run)
    2. Partition (each method run at most once)
    3. Balance (for performance)

    Args:
      total_shards: int, total number of shards.
    """

    outerr_by_shard = []  # A list of lists of strings
    combined_outerr = []  # A list of strings
    exit_code_by_shard = []  # A list of ints

    for i in range(total_shards):
      (out, exit_code) = self._run_sharded(total_shards, i)
      method_list = [x for x in out.split('\n') if x.startswith('class')]
      outerr_by_shard.append(method_list)
      combined_outerr.extend(method_list)
      exit_code_by_shard.append(exit_code)

    self.assertLen([x for x in exit_code_by_shard if x != 0], 1,
                   'Expected exactly one failure')

    # Completeness: every method ran on some shard. Checked first so that a
    # missing method is reported as a completeness failure instead of tripping
    # the total-count check below.
    self.assertLen(set(combined_outerr), NUM_TEST_METHODS,
                   'Completeness requirement not met')
    # Partition: with every method present, any extra entry means some method
    # ran on more than one shard.
    self.assertLen(combined_outerr, NUM_TEST_METHODS,
                   'Partition requirement not met')

    # Test balance:
    for i, shard_methods in enumerate(outerr_by_shard):
      self.assertGreaterEqual(len(shard_methods),
                              (NUM_TEST_METHODS / total_shards) - 1,
                              'Shard %d of %d out of balance' %
                              (i, len(outerr_by_shard)))

  def test_shard_file(self):
    self._run_sharded(3, 1, os.path.join(
        absltest.TEST_TMPDIR.value, 'shard_file'))

  def test_zero_shards(self):
    out, exit_code = self._run_sharded(0, 0)
    self.assertEqual(1, exit_code)
    self.assertGreaterEqual(out.find('Bad sharding values. index=0, total=0'),
                            0, 'Bad output: %s' % (out))

  def test_with_four_shards(self):
    self._assert_sharding_correctness(4)

  def test_with_one_shard(self):
    self._assert_sharding_correctness(1)

  def test_with_ten_shards(self):
    shards = 10
    # This test relies on the shard count to be greater than the number of
    # tests, to ensure that the non-zero shards won't fail even if no tests ran
    # on Python 3.12+.
    self.assertGreater(shards, NUM_TEST_METHODS)
    self._assert_sharding_correctness(shards)

  def test_sharding_with_randomization(self):
    # If we're both sharding *and* randomizing, we need to confirm that we
    # randomize within the shard; we use two seeds to confirm we're seeing the
    # same tests (sharding is consistent) in a different order.
    tests_seen = []
    for seed in ('7', '17'):
      out, exit_code = self._run_sharded(
          2, 0, additional_env={'TEST_RANDOMIZE_ORDERING_SEED': seed})
      self.assertEqual(0, exit_code)
      tests_seen.append([x for x in out.splitlines() if x.startswith('class')])
    first_tests, second_tests = tests_seen  # pylint: disable=unbalanced-tuple-unpacking
    self.assertEqual(set(first_tests), set(second_tests))
    self.assertNotEqual(first_tests, second_tests)

  @parameterized.named_parameters(
      ('total_1_index_0', 1, 0, None),
      ('total_2_index_0', 2, 0, None),
      # The 2nd shard (index=1) should not fail.
      ('total_2_index_1', 2, 1, 0),
  )
  def test_no_tests_ran(
      self, total_shards, shard_index, override_expected_exit_code
  ):
    if override_expected_exit_code is not None:
      expected_exit_code = override_expected_exit_code
    elif sys.version_info[:2] >= (3, 12):
      expected_exit_code = 5
    else:
      expected_exit_code = 0
    out, exit_code = self._run_sharded(
        total_shards,
        shard_index,
        helper_name='absltest_sharding_test_helper_no_tests',
    )
    self.assertEqual(
        expected_exit_code,
        exit_code,
        'Unexpected exit code, output:\n{}'.format(out),
    )
# Run the sharding integration tests when executed as a script.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_sharding_test_helper.py 0000664 0000000 0000000 00000002652 14551576331 0025626 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A helper test program for absltest_sharding_test."""
import sys
from absl.testing import absltest
class ClassA(absltest.TestCase):
  """Helper test case A for absltest_sharding_test.

  Three passing tests, each writing a 'class A test X' marker to stderr.
  """

  def testA(self):
    sys.stderr.write('\nclass A test A\n')

  def testB(self):
    sys.stderr.write('\nclass A test B\n')

  def testC(self):
    sys.stderr.write('\nclass A test C\n')
class ClassB(absltest.TestCase):
  """Helper test case B for absltest_sharding_test.

  Five tests, each writing a 'class B test X' marker to stderr; testE always
  fails, so exactly one shard in a sharded run should exit non-zero.
  """

  def testA(self):
    sys.stderr.write('\nclass B test A\n')

  def testB(self):
    sys.stderr.write('\nclass B test B\n')

  def testC(self):
    sys.stderr.write('\nclass B test C\n')

  def testD(self):
    sys.stderr.write('\nclass B test D\n')

  def testE(self):
    sys.stderr.write('\nclass B test E\n')
    # Deliberate failure: the only failing test across all shards.
    self.fail('Force failure')
# Run the helper tests when executed as a script.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_sharding_test_helper_no_tests.py 0000664 0000000 0000000 00000001417 14551576331 0027542 0 ustar 00root root 0000000 0000000 # Copyright 2023 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A helper test program with no tests ran for absltest_sharding_test."""
from absl.testing import absltest
class MyTest(absltest.TestCase):
  """Test case that intentionally defines no test methods."""
  pass
# Running this helper executes zero tests.
if __name__ == "__main__":
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_test.py 0000664 0000000 0000000 00000255121 14551576331 0022411 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for absltest."""
import collections
import contextlib
import dataclasses
import io
import os
import pathlib
import re
import stat
import string
import subprocess
import sys
import tempfile
import textwrap
from typing import Optional
import unittest
from absl.testing import _bazelize_command
from absl.testing import absltest
from absl.testing import parameterized
from absl.testing.tests import absltest_env
class BaseTestCase(absltest.TestCase):
  """Base class with a helper for running absltest helper binaries."""

  def _get_helper_exec_path(self, helper_name):
    """Returns the executable path of `helper_name` in absl/testing/tests."""
    return _bazelize_command.get_executable_path(
        'absl/testing/tests/' + helper_name)

  def run_helper(
      self,
      test_id,
      args,
      env_overrides,
      expect_success,
      helper_name=None,
  ):
    """Runs a helper test binary in a subprocess and checks its exit status.

    Args:
      test_id: passed to the helper as `--test_id=<test_id>` unless None.
      args: list of extra command-line arguments, appended after --test_id.
      env_overrides: dict of environment changes applied on top of the
        inherited environment; a value of None removes that variable.
      expect_success: if True, assert the exit code is 0; otherwise assert it
        is greater than 0.
      helper_name: helper binary name; 'absltest_test_helper' when None.

    Returns:
      (stdout, stderr, returncode) tuple of (str, str, int).
    """
    env = absltest_env.inherited_env()
    for key, value in env_overrides.items():
      if value is None:
        env.pop(key, None)
      else:
        env[key] = value

    name = 'absltest_test_helper' if helper_name is None else helper_name
    command = [self._get_helper_exec_path(name)]
    if test_id is not None:
      command.append('--test_id={}'.format(test_id))
    command.extend(args)

    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env,
        universal_newlines=True)
    stdout, stderr = process.communicate()
    returncode = process.returncode

    if not expect_success:
      self.assertGreater(
          returncode,
          0,
          'Expected failure, but succeeded with '
          'stdout:\n{}\nstderr:\n{}\n'.format(stdout, stderr),
      )
    else:
      self.assertEqual(
          0,
          returncode,
          'Expected success, but failed with exit code {},'
          ' stdout:\n{}\nstderr:\n{}\n'.format(returncode, stdout, stderr),
      )
    return stdout, stderr, returncode
class TestCaseTest(BaseTestCase):
# unittest: append custom assertion messages to the standard failure message
# instead of replacing it.
longMessage = True
def run_helper(
    self, test_id, args, env_overrides, expect_success, helper_name=None
):
  """Runs the helper binary like BaseTestCase.run_helper.

  'HelperTest' is appended to `args` as a positional argument, presumably
  selecting that test class inside the helper.
  """
  helper_args = args + ['HelperTest']
  return super().run_helper(
      test_id,
      helper_args,
      env_overrides,
      expect_success,
      helper_name,
  )
def test_flags_no_env_var_no_flags(self):
  # None values tell run_helper to remove these variables from the env.
  cleared = dict.fromkeys(('TEST_RANDOM_SEED', 'TEST_SRCDIR', 'TEST_TMPDIR'))
  self.run_helper(1, [], cleared, expect_success=True)
def test_flags_env_var_no_flags(self):
  tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  srcdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  # Values come only from the environment; the helper checks what it sees.
  overrides = {
      'TEST_RANDOM_SEED': '321',
      'TEST_SRCDIR': srcdir,
      'TEST_TMPDIR': tmpdir,
      'ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR': srcdir,
      'ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR': tmpdir,
  }
  self.run_helper(2, [], overrides, expect_success=True)
def test_flags_no_env_var_flags(self):
  tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  srcdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  cli_args = [
      '--test_random_seed=123',
      '--test_srcdir={}'.format(srcdir),
      '--test_tmpdir={}'.format(tmpdir),
  ]
  # Clear the env vars so only the flags can supply the values.
  overrides = dict.fromkeys(('TEST_RANDOM_SEED', 'TEST_SRCDIR', 'TEST_TMPDIR'))
  overrides['ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR'] = srcdir
  overrides['ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR'] = tmpdir
  self.run_helper(3, cli_args, overrides, expect_success=True)
def test_flags_env_var_flags(self):
  tmpdir_from_flag = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  srcdir_from_flag = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  tmpdir_from_env_var = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  srcdir_from_env_var = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  cli_args = [
      '--test_random_seed=221',
      '--test_srcdir={}'.format(srcdir_from_flag),
      '--test_tmpdir={}'.format(tmpdir_from_flag),
  ]
  # Both sources are set; the expected values are the flag ones.
  overrides = {
      'TEST_RANDOM_SEED': '123',
      'TEST_SRCDIR': srcdir_from_env_var,
      'TEST_TMPDIR': tmpdir_from_env_var,
      'ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR': srcdir_from_flag,
      'ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR': tmpdir_from_flag,
  }
  self.run_helper(4, cli_args, overrides, expect_success=True)
def test_xml_output_file_from_xml_output_file_env(self):
  xml_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  xml_output_file_env = os.path.join(xml_dir, 'xml_output_file.xml')
  random_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  # All three XML sources are set; XML_OUTPUT_FILE is the expected result.
  overrides = {
      'XML_OUTPUT_FILE': xml_output_file_env,
      'RUNNING_UNDER_TEST_DAEMON': '1',
      'TEST_XMLOUTPUTDIR': random_dir,
      'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': xml_output_file_env,
  }
  self.run_helper(6, [], overrides, expect_success=True)
def test_xml_output_file_from_daemon(self):
  parent_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  tmpdir = os.path.join(parent_dir, 'sub_dir')
  random_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  # Without XML_OUTPUT_FILE, the daemon case expects test_detail.xml next to
  # the test tmpdir.
  expected_xml = os.path.join(os.path.dirname(tmpdir), 'test_detail.xml')
  overrides = {
      'XML_OUTPUT_FILE': None,
      'RUNNING_UNDER_TEST_DAEMON': '1',
      'TEST_XMLOUTPUTDIR': random_dir,
      'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': expected_xml,
  }
  self.run_helper(6, ['--test_tmpdir', tmpdir], overrides, expect_success=True)
def test_xml_output_file_from_test_xmloutputdir_env(self):
  xml_output_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  # Only TEST_XMLOUTPUTDIR is set; the file is named after the helper.
  expected_xml = os.path.join(xml_output_dir, 'absltest_test_helper.xml')
  overrides = {
      'XML_OUTPUT_FILE': None,
      'RUNNING_UNDER_TEST_DAEMON': None,
      'TEST_XMLOUTPUTDIR': xml_output_dir,
      'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': expected_xml,
  }
  self.run_helper(6, [], overrides, expect_success=True)
def test_xml_output_file_from_flag(self):
  random_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  flag_dir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
  flag_file = os.path.join(flag_dir, 'output.xml')
  # Every env source is set too, but the --xml_output_file flag is expected.
  overrides = {
      'XML_OUTPUT_FILE': os.path.join(random_dir, 'output.xml'),
      'RUNNING_UNDER_TEST_DAEMON': '1',
      'TEST_XMLOUTPUTDIR': random_dir,
      'ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE': flag_file,
  }
  self.run_helper(
      6, ['--xml_output_file', flag_file], overrides, expect_success=True)
def test_app_run(self):
  result = self.run_helper(
      7,
      ['--name=cat', '--name=dog'],
      {'ABSLTEST_TEST_HELPER_USE_APP_RUN': '1'},
      expect_success=True,
  )
  stdout = result[0]
  # The multi-value flag must be visible both in main() and inside the test.
  for expected_line in ('Names in main() are: cat dog',
                        'Names in test_name_flag() are: cat dog'):
    self.assertIn(expected_line, stdout)
def test_assert_in(self):
  animals = {'monkey': 'banana', 'cow': 'grass', 'seal': 'fish'}

  for needle, haystack in (('a', 'abc'), (2, [1, 2, 3]), ('monkey', animals)):
    self.assertIn(needle, haystack)
  for needle, haystack in (('d', 'abc'), (0, [1, 2, 3]), ('otter', animals)):
    self.assertNotIn(needle, haystack)

  # The opposite assertions must raise.
  for needle, haystack in (
      ('x', 'abc'), (4, [1, 2, 3]), ('elephant', animals)):
    self.assertRaises(AssertionError, self.assertIn, needle, haystack)
  for needle, haystack in (('c', 'abc'), (1, [1, 2, 3]), ('cow', animals)):
    self.assertRaises(AssertionError, self.assertNotIn, needle, haystack)
@absltest.expectedFailure
def test_expected_failure(self):
  """Fails on purpose; the decorator marks the failure as expected."""
  self.assertEqual(1, 2)  # the expected failure
@absltest.expectedFailureIf(True, 'always true')
def test_expected_failure_if(self):
self.assertEqual(1, 2) # the expected failure
def test_expected_failure_success(self):
_, stderr, _ = self.run_helper(5, ['--', '-v'], {}, expect_success=False)
self.assertRegex(stderr, r'FAILED \(.*unexpected successes=1\)')
def test_assert_equal(self):
self.assertListEqual([], [])
self.assertTupleEqual((), ())
self.assertSequenceEqual([], ())
a = [0, 'a', []]
b = []
self.assertRaises(absltest.TestCase.failureException,
self.assertListEqual, a, b)
self.assertRaises(absltest.TestCase.failureException,
self.assertListEqual, tuple(a), tuple(b))
self.assertRaises(absltest.TestCase.failureException,
self.assertSequenceEqual, a, tuple(b))
b.extend(a)
self.assertListEqual(a, b)
self.assertTupleEqual(tuple(a), tuple(b))
self.assertSequenceEqual(a, tuple(b))
self.assertSequenceEqual(tuple(a), b)
self.assertRaises(AssertionError, self.assertListEqual, a, tuple(b))
self.assertRaises(AssertionError, self.assertTupleEqual, tuple(a), b)
self.assertRaises(AssertionError, self.assertListEqual, None, b)
self.assertRaises(AssertionError, self.assertTupleEqual, None, tuple(b))
self.assertRaises(AssertionError, self.assertSequenceEqual, None, tuple(b))
self.assertRaises(AssertionError, self.assertListEqual, 1, 1)
self.assertRaises(AssertionError, self.assertTupleEqual, 1, 1)
self.assertRaises(AssertionError, self.assertSequenceEqual, 1, 1)
self.assertSameElements([1, 2, 3], [3, 2, 1])
self.assertSameElements([1, 2] + [3] * 100, [1] * 100 + [2, 3])
self.assertSameElements(['foo', 'bar', 'baz'], ['bar', 'baz', 'foo'])
self.assertRaises(AssertionError, self.assertSameElements, [10], [10, 11])
self.assertRaises(AssertionError, self.assertSameElements, [10, 11], [10])
# Test that sequences of unhashable objects can be tested for sameness:
self.assertSameElements([[1, 2], [3, 4]], [[3, 4], [1, 2]])
self.assertRaises(AssertionError, self.assertSameElements, [[1]], [[2]])
def test_assert_items_equal_hotfix(self):
"""Confirm that http://bugs.python.org/issue14832 - b/10038517 is gone."""
for assert_items_method in (self.assertItemsEqual, self.assertCountEqual):
with self.assertRaises(self.failureException) as error_context:
assert_items_method([4], [2])
error_message = str(error_context.exception)
# Confirm that the bug is either no longer present in Python or that our
# assertItemsEqual patching version of the method in absltest.TestCase
# doesn't get used.
self.assertIn('First has 1, Second has 0: 4', error_message)
self.assertIn('First has 0, Second has 1: 2', error_message)
def test_assert_dict_equal(self):
self.assertDictEqual({}, {})
c = {'x': 1}
d = {}
self.assertRaises(absltest.TestCase.failureException,
self.assertDictEqual, c, d)
d.update(c)
self.assertDictEqual(c, d)
d['x'] = 0
self.assertRaises(absltest.TestCase.failureException,
self.assertDictEqual, c, d, 'These are unequal')
self.assertRaises(AssertionError, self.assertDictEqual, None, d)
self.assertRaises(AssertionError, self.assertDictEqual, [], d)
self.assertRaises(AssertionError, self.assertDictEqual, 1, 1)
try:
# Ensure we use equality as the sole measure of elements, not type, since
# that is consistent with dict equality.
self.assertDictEqual({1: 1.0, 2: 2}, {1: 1, 2: 3})
except AssertionError as e:
self.assertMultiLineEqual('{1: 1.0, 2: 2} != {1: 1, 2: 3}\n'
'repr() of differing entries:\n2: 2 != 3\n',
str(e))
try:
self.assertDictEqual({}, {'x': 1})
except AssertionError as e:
self.assertMultiLineEqual("{} != {'x': 1}\n"
"Unexpected, but present entries:\n'x': 1\n",
str(e))
else:
self.fail('Expecting AssertionError')
try:
self.assertDictEqual({}, {'x': 1}, 'a message')
except AssertionError as e:
self.assertIn('a message', str(e))
else:
self.fail('Expecting AssertionError')
expected = {'a': 1, 'b': 2, 'c': 3}
seen = {'a': 2, 'c': 3, 'd': 4}
try:
self.assertDictEqual(expected, seen)
except AssertionError as e:
self.assertMultiLineEqual("""\
{'a': 1, 'b': 2, 'c': 3} != {'a': 2, 'c': 3, 'd': 4}
Unexpected, but present entries:
'd': 4
repr() of differing entries:
'a': 1 != 2
Missing entries:
'b': 2
""", str(e))
else:
self.fail('Expecting AssertionError')
self.assertRaises(AssertionError, self.assertDictEqual, (1, 2), {})
self.assertRaises(AssertionError, self.assertDictEqual, {}, (1, 2))
# Ensure deterministic output of keys in dictionaries whose sort order
# doesn't match the lexical ordering of repr -- this is most Python objects,
# which are keyed by memory address.
class Obj(object):
def __init__(self, name):
self.name = name
def __repr__(self):
return self.name
try:
self.assertDictEqual(
{'a': Obj('A'), Obj('b'): Obj('B'), Obj('c'): Obj('C')},
{'a': Obj('A'), Obj('d'): Obj('D'), Obj('e'): Obj('E')})
except AssertionError as e:
# Do as best we can not to be misleading when objects have the same repr
# but aren't equal.
err_str = str(e)
self.assertStartsWith(err_str,
"{'a': A, b: B, c: C} != {'a': A, d: D, e: E}\n")
self.assertRegex(
err_str, r'(?ms).*^Unexpected, but present entries:\s+'
r'^(d: D$\s+^e: E|e: E$\s+^d: D)$')
self.assertRegex(
err_str, r'(?ms).*^repr\(\) of differing entries:\s+'
r'^.a.: A != A$', err_str)
self.assertRegex(
err_str, r'(?ms).*^Missing entries:\s+'
r'^(b: B$\s+^c: C|c: C$\s+^b: B)$')
else:
self.fail('Expecting AssertionError')
# Confirm that safe_repr, not repr, is being used.
class RaisesOnRepr(object):
def __repr__(self):
return 1/0 # Intentionally broken __repr__ implementation.
try:
self.assertDictEqual(
{RaisesOnRepr(): RaisesOnRepr()},
{RaisesOnRepr(): RaisesOnRepr()}
)
self.fail('Expected dicts not to match')
except AssertionError as e:
# Depending on the testing environment, the object may get a __main__
# prefix or a absltest_test prefix, so strip that for comparison.
error_msg = re.sub(
r'( at 0x[^>]+)|__main__\.|absltest_test\.', '', str(e))
self.assertRegex(error_msg, """(?m)\
{<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>} != \
{<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>}
Unexpected, but present entries:
<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>
Missing entries:
<.*RaisesOnRepr object.*>: <.*RaisesOnRepr object.*>
""")
# Confirm that safe_repr, not repr, is being used.
class RaisesOnLt(object):
def __lt__(self, unused_other):
raise TypeError('Object is unordered.')
def __repr__(self):
return ''
try:
self.assertDictEqual(
{RaisesOnLt(): RaisesOnLt()},
{RaisesOnLt(): RaisesOnLt()})
except AssertionError as e:
self.assertIn('Unexpected, but present entries:\n other.x
except AttributeError:
return NotImplemented
def __ge__(self, other):
try:
return self.x >= other.x
except AttributeError:
return NotImplemented
class B(A):
"""Like A, but not hashable."""
__hash__ = None
self.assertTotallyOrdered(
[A(1, 'a')],
[A(2, 'b')], # 2 is after 1.
[
A(3, 'c'),
B(3, 'd'),
B(3, 'e') # The second argument is irrelevant.
],
[A(4, 'z')])
# Invalid.
msg = 'This is a useful message'
whole_msg = '2 not less than 1 : This is a useful message'
self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
self.assertTotallyOrdered, [2], [1],
msg=msg)
self.assertRaises(AssertionError, self.assertTotallyOrdered, [2], [1])
self.assertRaises(AssertionError, self.assertTotallyOrdered, [2], [1], [3])
self.assertRaises(AssertionError, self.assertTotallyOrdered, [1, 2])
def test_short_description_without_docstring(self):
self.assertEqual(
self.shortDescription(),
'TestCaseTest.test_short_description_without_docstring',
)
def test_short_description_with_one_line_docstring(self):
"""Tests shortDescription() for a method with a docstring."""
self.assertEqual(
self.shortDescription(),
'TestCaseTest.test_short_description_with_one_line_docstring\n'
'Tests shortDescription() for a method with a docstring.',
)
def test_short_description_with_multi_line_docstring(self):
"""Tests shortDescription() for a method with a longer docstring.
This method ensures that only the first line of a docstring is
returned used in the short description, no matter how long the
whole thing is.
"""
self.assertEqual(
self.shortDescription(),
'TestCaseTest.test_short_description_with_multi_line_docstring\n'
'Tests shortDescription() for a method with a longer docstring.',
)
def test_assert_url_equal_same(self):
self.assertUrlEqual('http://a', 'http://a')
self.assertUrlEqual('http://a/path/test', 'http://a/path/test')
self.assertUrlEqual('#fragment', '#fragment')
self.assertUrlEqual('http://a/?q=1', 'http://a/?q=1')
self.assertUrlEqual('http://a/?q=1&v=5', 'http://a/?v=5&q=1')
self.assertUrlEqual('/logs?v=1&a=2&t=labels&f=path%3A%22foo%22',
'/logs?a=2&f=path%3A%22foo%22&v=1&t=labels')
self.assertUrlEqual('http://a/path;p1', 'http://a/path;p1')
self.assertUrlEqual('http://a/path;p2;p3;p1', 'http://a/path;p1;p2;p3')
self.assertUrlEqual('sip:alice@atlanta.com;maddr=239.255.255.1;ttl=15',
'sip:alice@atlanta.com;ttl=15;maddr=239.255.255.1')
self.assertUrlEqual('http://nyan/cat?p=1&b=', 'http://nyan/cat?b=&p=1')
def test_assert_url_equal_different(self):
msg = 'This is a useful message'
whole_msg = 'This is a useful message:\n- a\n+ b\n'
self.assertRaisesWithLiteralMatch(AssertionError, whole_msg,
self.assertUrlEqual,
'http://a', 'http://b', msg=msg)
self.assertRaises(AssertionError, self.assertUrlEqual,
'http://a/x', 'http://a:8080/x')
self.assertRaises(AssertionError, self.assertUrlEqual,
'http://a/x', 'http://a/y')
self.assertRaises(AssertionError, self.assertUrlEqual,
'http://a/?q=2', 'http://a/?q=1')
self.assertRaises(AssertionError, self.assertUrlEqual,
'http://a/?q=1&v=5', 'http://a/?v=2&q=1')
self.assertRaises(AssertionError, self.assertUrlEqual,
'http://a', 'sip://b')
self.assertRaises(AssertionError, self.assertUrlEqual,
'http://a#g', 'sip://a#f')
self.assertRaises(AssertionError, self.assertUrlEqual,
'http://a/path;p1;p3;p1', 'http://a/path;p1;p2;p3')
self.assertRaises(AssertionError, self.assertUrlEqual,
'http://nyan/cat?p=1&b=', 'http://nyan/cat?p=1')
def test_same_structure_same(self):
self.assertSameStructure(0, 0)
self.assertSameStructure(1, 1)
self.assertSameStructure('', '')
self.assertSameStructure('hello', 'hello', msg='This Should not fail')
self.assertSameStructure(set(), set())
self.assertSameStructure(set([1, 2]), set([1, 2]))
self.assertSameStructure(set(), frozenset())
self.assertSameStructure(set([1, 2]), frozenset([1, 2]))
self.assertSameStructure([], [])
self.assertSameStructure(['a'], ['a'])
self.assertSameStructure([], ())
self.assertSameStructure(['a'], ('a',))
self.assertSameStructure({}, {})
self.assertSameStructure({'one': 1}, {'one': 1})
self.assertSameStructure(collections.defaultdict(None, {'one': 1}),
{'one': 1})
self.assertSameStructure(collections.OrderedDict({'one': 1}),
collections.defaultdict(None, {'one': 1}))
def test_same_structure_different(self):
# Different type
with self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'int'> but b is a <(type|class) 'str'>"):
self.assertSameStructure(0, 'hello')
with self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'int'> but b is a <(type|class) 'list'>"):
self.assertSameStructure(0, [])
with self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'int'> but b is a <(type|class) 'float'>"):
self.assertSameStructure(2, 2.0)
with self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'list'> but b is a <(type|class) 'dict'>"):
self.assertSameStructure([], {})
with self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'list'> but b is a <(type|class) 'set'>"):
self.assertSameStructure([], set())
with self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'dict'> but b is a <(type|class) 'set'>"):
self.assertSameStructure({}, set())
# Different scalar values
self.assertRaisesWithLiteralMatch(
AssertionError, 'a is 0 but b is 1',
self.assertSameStructure, 0, 1)
self.assertRaisesWithLiteralMatch(
AssertionError, "a is 'hello' but b is 'goodbye' : This was expected",
self.assertSameStructure, 'hello', 'goodbye', msg='This was expected')
# Different sets
self.assertRaisesWithLiteralMatch(
AssertionError,
r'AA has 2 but BB does not',
self.assertSameStructure,
set([1, 2]),
set([1]),
aname='AA',
bname='BB')
self.assertRaisesWithLiteralMatch(
AssertionError,
r'AA lacks 2 but BB has it',
self.assertSameStructure,
set([1]),
set([1, 2]),
aname='AA',
bname='BB')
# Different lists
self.assertRaisesWithLiteralMatch(
AssertionError, "a has [2] with value 'z' but b does not",
self.assertSameStructure, ['x', 'y', 'z'], ['x', 'y'])
self.assertRaisesWithLiteralMatch(
AssertionError, "a lacks [2] but b has it with value 'z'",
self.assertSameStructure, ['x', 'y'], ['x', 'y', 'z'])
self.assertRaisesWithLiteralMatch(
AssertionError, "a[2] is 'z' but b[2] is 'Z'",
self.assertSameStructure, ['x', 'y', 'z'], ['x', 'y', 'Z'])
# Different dicts
self.assertRaisesWithLiteralMatch(
AssertionError, "a has ['two'] with value 2 but it's missing in b",
self.assertSameStructure, {'one': 1, 'two': 2}, {'one': 1})
self.assertRaisesWithLiteralMatch(
AssertionError, "a lacks ['two'] but b has it with value 2",
self.assertSameStructure, {'one': 1}, {'one': 1, 'two': 2})
self.assertRaisesWithLiteralMatch(
AssertionError, "a['two'] is 2 but b['two'] is 3",
self.assertSameStructure, {'one': 1, 'two': 2}, {'one': 1, 'two': 3})
# String and byte types should not be considered equivalent to other
# sequences
self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'list'> but b is a <(type|class) 'str'>",
self.assertSameStructure, [], '')
self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'str'> but b is a <(type|class) 'tuple'>",
self.assertSameStructure, '', ())
self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'list'> but b is a <(type|class) 'str'>",
self.assertSameStructure, ['a', 'b', 'c'], 'abc')
self.assertRaisesRegex(
AssertionError,
r"a is a <(type|class) 'str'> but b is a <(type|class) 'tuple'>",
self.assertSameStructure, 'abc', ('a', 'b', 'c'))
# Deep key generation
self.assertRaisesWithLiteralMatch(
AssertionError,
"a[0][0]['x']['y']['z'][0] is 1 but b[0][0]['x']['y']['z'][0] is 2",
self.assertSameStructure,
[[{'x': {'y': {'z': [1]}}}]], [[{'x': {'y': {'z': [2]}}}]])
# Multiple problems
self.assertRaisesWithLiteralMatch(
AssertionError,
'a[0] is 1 but b[0] is 3; a[1] is 2 but b[1] is 4',
self.assertSameStructure, [1, 2], [3, 4])
with self.assertRaisesRegex(
AssertionError,
re.compile(r"^a\[0] is 'a' but b\[0] is 'A'; .*"
r"a\[18] is 's' but b\[18] is 'S'; \.\.\.$")):
self.assertSameStructure(
list(string.ascii_lowercase), list(string.ascii_uppercase))
# Verify same behavior with self.maxDiff = None
self.maxDiff = None
self.assertRaisesWithLiteralMatch(
AssertionError,
'a[0] is 1 but b[0] is 3; a[1] is 2 but b[1] is 4',
self.assertSameStructure, [1, 2], [3, 4])
def test_same_structure_mapping_unchanged(self):
default_a = collections.defaultdict(lambda: 'BAD MODIFICATION', {})
dict_b = {'one': 'z'}
self.assertRaisesWithLiteralMatch(
AssertionError,
r"a lacks ['one'] but b has it with value 'z'",
self.assertSameStructure, default_a, dict_b)
self.assertEmpty(default_a)
dict_a = {'one': 'z'}
default_b = collections.defaultdict(lambda: 'BAD MODIFICATION', {})
self.assertRaisesWithLiteralMatch(
AssertionError,
r"a has ['one'] with value 'z' but it's missing in b",
self.assertSameStructure, dict_a, default_b)
self.assertEmpty(default_b)
def test_same_structure_uses_type_equality_func_for_leaves(self):
class CustomLeaf(object):
def __init__(self, n):
self.n = n
def __repr__(self):
return f'CustomLeaf({self.n})'
def assert_custom_leaf_equal(a, b, msg):
del msg
assert a.n % 5 == b.n % 5
self.addTypeEqualityFunc(CustomLeaf, assert_custom_leaf_equal)
self.assertSameStructure(CustomLeaf(4), CustomLeaf(9))
self.assertRaisesWithLiteralMatch(
AssertionError,
r'a is CustomLeaf(4) but b is CustomLeaf(8)',
self.assertSameStructure, CustomLeaf(4), CustomLeaf(8),
)
def test_assert_json_equal_same(self):
self.assertJsonEqual('{"success": true}', '{"success": true}')
self.assertJsonEqual('{"success": true}', '{"success":true}')
self.assertJsonEqual('true', 'true')
self.assertJsonEqual('null', 'null')
self.assertJsonEqual('false', 'false')
self.assertJsonEqual('34', '34')
self.assertJsonEqual('[1, 2, 3]', '[1,2,3]', msg='please PASS')
self.assertJsonEqual('{"sequence": [1, 2, 3], "float": 23.42}',
'{"float": 23.42, "sequence": [1,2,3]}')
self.assertJsonEqual('{"nest": {"spam": "eggs"}, "float": 23.42}',
'{"float": 23.42, "nest": {"spam":"eggs"}}')
def test_assert_json_equal_different(self):
with self.assertRaises(AssertionError):
self.assertJsonEqual('{"success": true}', '{"success": false}')
with self.assertRaises(AssertionError):
self.assertJsonEqual('{"success": false}', '{"Success": false}')
with self.assertRaises(AssertionError):
self.assertJsonEqual('false', 'true')
with self.assertRaises(AssertionError) as error_context:
self.assertJsonEqual('null', '0', msg='I demand FAILURE')
self.assertIn('I demand FAILURE', error_context.exception.args[0])
self.assertIn('None', error_context.exception.args[0])
with self.assertRaises(AssertionError):
self.assertJsonEqual('[1, 0, 3]', '[1,2,3]')
with self.assertRaises(AssertionError):
self.assertJsonEqual('{"sequence": [1, 2, 3], "float": 23.42}',
'{"float": 23.42, "sequence": [1,0,3]}')
with self.assertRaises(AssertionError):
self.assertJsonEqual('{"nest": {"spam": "eggs"}, "float": 23.42}',
'{"float": 23.42, "nest": {"Spam":"beans"}}')
def test_assert_json_equal_bad_json(self):
with self.assertRaises(ValueError) as error_context:
self.assertJsonEqual("alhg'2;#", '{"a": true}')
self.assertIn('first', error_context.exception.args[0])
self.assertIn('alhg', error_context.exception.args[0])
with self.assertRaises(ValueError) as error_context:
self.assertJsonEqual('{"a": true}', "alhg'2;#")
self.assertIn('second', error_context.exception.args[0])
self.assertIn('alhg', error_context.exception.args[0])
with self.assertRaises(ValueError) as error_context:
self.assertJsonEqual('', '')
class GetCommandStderrTestCase(absltest.TestCase):
def test_return_status(self):
tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
returncode = (
absltest.get_command_stderr(
['cat', os.path.join(tmpdir, 'file.txt')],
env=_env_for_command_tests())[0])
self.assertEqual(1, returncode)
def test_stderr(self):
tmpdir = tempfile.mkdtemp(dir=absltest.TEST_TMPDIR.value)
stderr = (
absltest.get_command_stderr(
['cat', os.path.join(tmpdir, 'file.txt')],
env=_env_for_command_tests())[1])
stderr = stderr.decode('utf-8')
self.assertRegex(stderr, 'No such file or directory')
@contextlib.contextmanager
def cm_for_test(obj):
try:
obj.cm_state = 'yielded'
yield 'value'
finally:
obj.cm_state = 'exited'
class EnterContextTest(absltest.TestCase):
def setUp(self):
self.cm_state = 'unset'
self.cm_value = 'unset'
def assert_cm_exited():
self.assertEqual(self.cm_state, 'exited')
# Because cleanup functions are run in reverse order, we have to add
# our assert-cleanup before the exit stack registers its own cleanup.
# This ensures we see state after the stack cleanup runs.
self.addCleanup(assert_cm_exited)
super(EnterContextTest, self).setUp()
self.cm_value = self.enter_context(cm_for_test(self))
def test_enter_context(self):
self.assertEqual(self.cm_value, 'value')
self.assertEqual(self.cm_state, 'yielded')
@absltest.skipIf(not hasattr(absltest.TestCase, 'addClassCleanup'),
'Python 3.8 required for class-level enter_context')
class EnterContextClassmethodTest(absltest.TestCase):
cm_state = 'unset'
cm_value = 'unset'
@classmethod
def setUpClass(cls):
def assert_cm_exited():
assert cls.cm_state == 'exited'
# Because cleanup functions are run in reverse order, we have to add
# our assert-cleanup before the exit stack registers its own cleanup.
# This ensures we see state after the stack cleanup runs.
cls.addClassCleanup(assert_cm_exited)
super(EnterContextClassmethodTest, cls).setUpClass()
cls.cm_value = cls.enter_context(cm_for_test(cls))
def test_enter_context(self):
self.assertEqual(self.cm_value, 'value')
self.assertEqual(self.cm_state, 'yielded')
class EqualityAssertionTest(absltest.TestCase):
"""This test verifies that absltest.failIfEqual actually tests __ne__.
If a user class implements __eq__, unittest.assertEqual will call it
via first == second. However, failIfEqual also calls
first == second. This means that while the caller may believe
their __ne__ method is being tested, it is not.
"""
class NeverEqual(object):
"""Objects of this class behave like NaNs."""
def __eq__(self, unused_other):
return False
def __ne__(self, unused_other):
return False
class AllSame(object):
"""All objects of this class compare as equal."""
def __eq__(self, unused_other):
return True
def __ne__(self, unused_other):
return False
class EqualityTestsWithEq(object):
"""Performs all equality and inequality tests with __eq__."""
def __init__(self, value):
self._value = value
def __eq__(self, other):
return self._value == other._value
def __ne__(self, other):
return not self.__eq__(other)
class EqualityTestsWithNe(object):
"""Performs all equality and inequality tests with __ne__."""
def __init__(self, value):
self._value = value
def __eq__(self, other):
return not self.__ne__(other)
def __ne__(self, other):
return self._value != other._value
class EqualityTestsWithCmp(object):
def __init__(self, value):
self._value = value
def __cmp__(self, other):
return cmp(self._value, other._value)
class EqualityTestsWithLtEq(object):
def __init__(self, value):
self._value = value
def __eq__(self, other):
return self._value == other._value
def __lt__(self, other):
return self._value < other._value
def test_all_comparisons_fail(self):
i1 = self.NeverEqual()
i2 = self.NeverEqual()
self.assertFalse(i1 == i2)
self.assertFalse(i1 != i2)
# Compare two distinct objects
self.assertFalse(i1 is i2)
self.assertRaises(AssertionError, self.assertEqual, i1, i2)
self.assertRaises(AssertionError, self.assertNotEqual, i1, i2)
# A NeverEqual object should not compare equal to itself either.
i2 = i1
self.assertTrue(i1 is i2)
self.assertFalse(i1 == i2)
self.assertFalse(i1 != i2)
self.assertRaises(AssertionError, self.assertEqual, i1, i2)
self.assertRaises(AssertionError, self.assertNotEqual, i1, i2)
def test_all_comparisons_succeed(self):
a = self.AllSame()
b = self.AllSame()
self.assertFalse(a is b)
self.assertTrue(a == b)
self.assertFalse(a != b)
self.assertEqual(a, b)
self.assertRaises(AssertionError, self.assertNotEqual, a, b)
def _perform_apple_apple_orange_checks(self, same_a, same_b, different):
"""Perform consistency checks with two apples and an orange.
The two apples should always compare as being the same (and inequality
checks should fail). The orange should always compare as being different
to each of the apples.
Args:
same_a: the first apple
same_b: the second apple
different: the orange
"""
self.assertTrue(same_a == same_b)
self.assertFalse(same_a != same_b)
self.assertEqual(same_a, same_b)
self.assertFalse(same_a == different)
self.assertTrue(same_a != different)
self.assertNotEqual(same_a, different)
self.assertFalse(same_b == different)
self.assertTrue(same_b != different)
self.assertNotEqual(same_b, different)
def test_comparison_with_eq(self):
same_a = self.EqualityTestsWithEq(42)
same_b = self.EqualityTestsWithEq(42)
different = self.EqualityTestsWithEq(1769)
self._perform_apple_apple_orange_checks(same_a, same_b, different)
def test_comparison_with_ne(self):
same_a = self.EqualityTestsWithNe(42)
same_b = self.EqualityTestsWithNe(42)
different = self.EqualityTestsWithNe(1769)
self._perform_apple_apple_orange_checks(same_a, same_b, different)
def test_comparison_with_cmp_or_lt_eq(self):
same_a = self.EqualityTestsWithLtEq(42)
same_b = self.EqualityTestsWithLtEq(42)
different = self.EqualityTestsWithLtEq(1769)
self._perform_apple_apple_orange_checks(same_a, same_b, different)
class AssertSequenceStartsWithTest(parameterized.TestCase):
def setUp(self):
super().setUp()
self.a = [5, 'foo', {'c': 'd'}, None]
def test_empty_sequence_starts_with_empty_prefix(self):
self.assertSequenceStartsWith([], ())
def test_sequence_prefix_is_an_empty_list(self):
self.assertSequenceStartsWith([[]], ([], 'foo'))
def test_raise_if_empty_prefix_with_non_empty_whole(self):
with self.assertRaisesRegex(
AssertionError, 'Prefix length is 0 but whole length is %d: %s' % (len(
self.a), r"\[5, 'foo', \{'c': 'd'\}, None\]")):
self.assertSequenceStartsWith([], self.a)
def test_single_element_prefix(self):
self.assertSequenceStartsWith([5], self.a)
def test_two_element_prefix(self):
self.assertSequenceStartsWith((5, 'foo'), self.a)
def test_prefix_is_full_sequence(self):
self.assertSequenceStartsWith([5, 'foo', {'c': 'd'}, None], self.a)
def test_string_prefix(self):
self.assertSequenceStartsWith('abc', 'abc123')
def test_convert_non_sequence_prefix_to_sequence_and_try_again(self):
self.assertSequenceStartsWith(5, self.a)
def test_whole_not_asequence(self):
msg = (r'For whole: len\(5\) is not supported, it appears to be type: '
'<(type|class) \'int\'>')
with self.assertRaisesRegex(AssertionError, msg):
self.assertSequenceStartsWith(self.a, 5)
def test_raise_if_sequence_does_not_start_with_prefix(self):
msg = (r"prefix: \['foo', \{'c': 'd'\}\] not found at start of whole: "
r"\[5, 'foo', \{'c': 'd'\}, None\].")
with self.assertRaisesRegex(AssertionError, msg):
self.assertSequenceStartsWith(['foo', {'c': 'd'}], self.a)
@parameterized.named_parameters(
('dict', {'a': 1, 2: 'b'}, {'a': 1, 2: 'b', 'c': '3'}),
('set', {1, 2}, {1, 2, 3}),
)
def test_raise_if_set_or_dict(self, prefix, whole):
with self.assertRaisesRegex(
AssertionError, 'For whole: Mapping or Set objects are not supported'
):
self.assertSequenceStartsWith(prefix, whole)
class TestAssertEmpty(absltest.TestCase):
longMessage = True
def test_raises_if_not_asized_object(self):
msg = "Expected a Sized object, got: 'int'"
with self.assertRaisesRegex(AssertionError, msg):
self.assertEmpty(1)
def test_calls_len_not_bool(self):
class BadList(list):
def __bool__(self):
return False
__nonzero__ = __bool__
bad_list = BadList()
self.assertEmpty(bad_list)
self.assertFalse(bad_list)
def test_passes_when_empty(self):
empty_containers = [
list(),
tuple(),
dict(),
set(),
frozenset(),
b'',
u'',
bytearray(),
]
for container in empty_containers:
self.assertEmpty(container)
def test_raises_with_not_empty_containers(self):
not_empty_containers = [
[1],
(1,),
{'foo': 'bar'},
{1},
frozenset([1]),
b'a',
u'a',
bytearray(b'a'),
]
regexp = r'.* has length of 1\.$'
for container in not_empty_containers:
with self.assertRaisesRegex(AssertionError, regexp):
self.assertEmpty(container)
def test_user_message_added_to_default(self):
msg = 'This is a useful message'
whole_msg = re.escape('[1] has length of 1. : This is a useful message')
with self.assertRaisesRegex(AssertionError, whole_msg):
self.assertEmpty([1], msg=msg)
class TestAssertNotEmpty(absltest.TestCase):
longMessage = True
def test_raises_if_not_asized_object(self):
msg = "Expected a Sized object, got: 'int'"
with self.assertRaisesRegex(AssertionError, msg):
self.assertNotEmpty(1)
def test_calls_len_not_bool(self):
class BadList(list):
def __bool__(self):
return False
__nonzero__ = __bool__
bad_list = BadList([1])
self.assertNotEmpty(bad_list)
self.assertFalse(bad_list)
def test_passes_when_not_empty(self):
not_empty_containers = [
[1],
(1,),
{'foo': 'bar'},
{1},
frozenset([1]),
b'a',
u'a',
bytearray(b'a'),
]
for container in not_empty_containers:
self.assertNotEmpty(container)
def test_raises_with_empty_containers(self):
empty_containers = [
list(),
tuple(),
dict(),
set(),
frozenset(),
b'',
u'',
bytearray(),
]
regexp = r'.* has length of 0\.$'
for container in empty_containers:
with self.assertRaisesRegex(AssertionError, regexp):
self.assertNotEmpty(container)
def test_user_message_added_to_default(self):
msg = 'This is a useful message'
whole_msg = re.escape('[] has length of 0. : This is a useful message')
with self.assertRaisesRegex(AssertionError, whole_msg):
self.assertNotEmpty([], msg=msg)
class TestAssertLen(absltest.TestCase):
longMessage = True
def test_raises_if_not_asized_object(self):
msg = "Expected a Sized object, got: 'int'"
with self.assertRaisesRegex(AssertionError, msg):
self.assertLen(1, 1)
def test_passes_when_expected_len(self):
containers = [
[[1], 1],
[(1, 2), 2],
[{'a': 1, 'b': 2, 'c': 3}, 3],
[{1, 2, 3, 4}, 4],
[frozenset([1]), 1],
[b'abc', 3],
[u'def', 3],
[bytearray(b'ghij'), 4],
]
for container, expected_len in containers:
self.assertLen(container, expected_len)
def test_raises_when_unexpected_len(self):
containers = [
[1],
(1, 2),
{'a': 1, 'b': 2, 'c': 3},
{1, 2, 3, 4},
frozenset([1]),
b'abc',
u'def',
bytearray(b'ghij'),
]
for container in containers:
regexp = r'.* has length of %d, expected 100\.$' % len(container)
with self.assertRaisesRegex(AssertionError, regexp):
self.assertLen(container, 100)
def test_user_message_added_to_default(self):
msg = 'This is a useful message'
whole_msg = (
r'\[1\] has length of 1, expected 100. : This is a useful message')
with self.assertRaisesRegex(AssertionError, whole_msg):
self.assertLen([1], 100, msg)
class TestLoaderTest(absltest.TestCase):
"""Tests that the TestLoader bans methods named TestFoo."""
# pylint: disable=invalid-name
class Valid(absltest.TestCase):
"""Test case containing a variety of valid names."""
test_property = 1
TestProperty = 2
@staticmethod
def TestStaticMethod():
pass
@staticmethod
def TestStaticMethodWithArg(foo):
pass
@classmethod
def TestClassMethod(cls):
pass
def Test(self):
pass
def TestingHelper(self):
pass
def testMethod(self):
pass
def TestHelperWithParams(self, a, b):
pass
def TestHelperWithVarargs(self, *args, **kwargs):
pass
def TestHelperWithDefaults(self, a=5):
pass
def TestHelperWithKeywordOnly(self, *, arg):
pass
class Invalid(absltest.TestCase):
"""Test case containing a suspicious method."""
def testMethod(self):
pass
def TestSuspiciousMethod(self):
pass
# pylint: enable=invalid-name
def setUp(self):
self.loader = absltest.TestLoader()
def test_valid(self):
suite = self.loader.loadTestsFromTestCase(TestLoaderTest.Valid)
self.assertEqual(1, suite.countTestCases())
def testInvalid(self):
with self.assertRaisesRegex(TypeError, 'TestSuspiciousMethod'):
self.loader.loadTestsFromTestCase(TestLoaderTest.Invalid)
class InitNotNecessaryForAssertsTest(absltest.TestCase):
"""TestCase assertions should work even if __init__ wasn't correctly called.
This is a workaround, see comment in
absltest.TestCase._getAssertEqualityFunc. We know that not calling
__init__ of a superclass is a bad thing, but people keep doing them,
and this (even if a little bit dirty) saves them from shooting
themselves in the foot.
"""
def test_subclass(self):
class Subclass(absltest.TestCase):
def __init__(self): # pylint: disable=super-init-not-called
pass
Subclass().assertEqual({}, {})
def test_multiple_inheritance(self):
class Foo(object):
def __init__(self, *args, **kwargs):
pass
class Subclass(Foo, absltest.TestCase):
pass
Subclass().assertEqual({}, {})
@dataclasses.dataclass
class _ExampleDataclass:
comparable: str
not_comparable: str = dataclasses.field(compare=False)
comparable2: str = 'comparable2'
@dataclasses.dataclass
class _ExampleCustomEqualDataclass:
value: str
def __eq__(self, other):
return False
class TestAssertDataclassEqual(absltest.TestCase):
  """Tests for absltest.TestCase.assertDataclassEqual and its failure messages."""

  def test_assert_dataclass_equal_checks_a_for_dataclass(self):
    b = _ExampleDataclass('a', 'b')

    message = 'First argument is not a dataclass instance.'
    with self.assertRaisesWithLiteralMatch(AssertionError, message):
      self.assertDataclassEqual('a', b)

  def test_assert_dataclass_equal_checks_b_for_dataclass(self):
    a = _ExampleDataclass('a', 'b')

    message = 'Second argument is not a dataclass instance.'
    with self.assertRaisesWithLiteralMatch(AssertionError, message):
      self.assertDataclassEqual(a, 'b')

  def test_assert_dataclass_equal_different_dataclasses(self):
    a = _ExampleDataclass('a', 'b')
    b = _ExampleCustomEqualDataclass('c')

    # The message embeds the repr of both types (e.g.
    # "<class '__main__._ExampleDataclass'>"). Build it from the actual types
    # so the expectation is correct whatever module name this file is
    # imported under.
    message = f'Found different dataclass types: {type(a)} != {type(b)}'
    with self.assertRaisesWithLiteralMatch(AssertionError, message):
      self.assertDataclassEqual(a, b)

  def test_assert_dataclass_equal(self):
    a = _ExampleDataclass(comparable='a', not_comparable='b')
    b = _ExampleDataclass(comparable='a', not_comparable='c')

    # not_comparable is declared with compare=False, so a and b are equal.
    self.assertDataclassEqual(a, a)
    self.assertDataclassEqual(a, b)
    self.assertDataclassEqual(b, a)

  def test_assert_dataclass_fails_non_equal_classes_assert_dict_passes(self):
    # The custom __eq__ always returns False, but every field matches, so no
    # per-field difference can be reported.
    a = _ExampleCustomEqualDataclass(value='a')
    b = _ExampleCustomEqualDataclass(value='a')

    message = textwrap.dedent("""\
        _ExampleCustomEqualDataclass(value='a') != _ExampleCustomEqualDataclass(value='a')
        Cannot detect difference by examining the fields of the dataclass.""")
    with self.assertRaisesWithLiteralMatch(AssertionError, message):
      self.assertDataclassEqual(a, b)

  def test_assert_dataclass_fails_assert_dict_fails_one_field(self):
    # Only compared fields are listed; not_comparable differs but is omitted.
    a = _ExampleDataclass(comparable='a', not_comparable='b')
    b = _ExampleDataclass(comparable='c', not_comparable='d')

    message = textwrap.dedent("""\
        _ExampleDataclass(comparable='a', not_comparable='b', comparable2='comparable2') != _ExampleDataclass(comparable='c', not_comparable='d', comparable2='comparable2')
        Fields that differ:
        comparable: 'a' != 'c'""")
    with self.assertRaisesWithLiteralMatch(AssertionError, message):
      self.assertDataclassEqual(a, b)

  def test_assert_dataclass_fails_assert_dict_fails_multiple_fields(self):
    a = _ExampleDataclass(comparable='a', not_comparable='b', comparable2='c')
    b = _ExampleDataclass(comparable='c', not_comparable='d', comparable2='e')

    message = textwrap.dedent("""\
        _ExampleDataclass(comparable='a', not_comparable='b', comparable2='c') != _ExampleDataclass(comparable='c', not_comparable='d', comparable2='e')
        Fields that differ:
        comparable: 'a' != 'c'
        comparable2: 'c' != 'e'""")
    with self.assertRaisesWithLiteralMatch(AssertionError, message):
      self.assertDataclassEqual(a, b)
class GetCommandStringTest(parameterized.TestCase):
  """Checks absltest.get_command_string quoting on Windows and elsewhere."""

  @parameterized.parameters(
      # Each case: (command, expected off Windows, expected on Windows).
      ([], '', ''),
      ([''], "''", ''),
      (['command', 'arg-0'], "'command' 'arg-0'", 'command arg-0'),
      ([u'command', u'arg-0'], "'command' 'arg-0'", u'command arg-0'),
      (["foo'bar"], "'foo'\"'\"'bar'", "foo'bar"),
      (['foo"bar'], "'foo\"bar'", 'foo"bar'),
      # A plain string command is returned unchanged on every platform.
      ('command arg-0', 'command arg-0', 'command arg-0'),
      (u'command arg-0', 'command arg-0', 'command arg-0'),
  )
  def test_get_command_string(
      self, command, expected_non_windows, expected_windows):
    if os.name == 'nt':
      expected = expected_windows
    else:
      expected = expected_non_windows
    actual = absltest.get_command_string(command)
    self.assertEqual(expected, actual)
class TempFileTest(BaseTestCase):
  """Tests for create_tempdir()/create_tempfile() and temp file cleanup."""

  def assert_dir_exists(self, temp_dir):
    """Asserts that temp_dir.full_path exists and is a directory."""
    path = temp_dir.full_path
    self.assertTrue(os.path.exists(path), 'Dir {} does not exist'.format(path))
    self.assertTrue(os.path.isdir(path),
                    'Path {} exists, but is not a directory'.format(path))

  def assert_file_exists(self, temp_file, expected_content=b''):
    """Asserts that temp_file.full_path is a file holding expected_content.

    Args:
      temp_file: object with a full_path attribute.
      expected_content: bytes (compared in binary mode) or str (compared in
        text mode).
    """
    path = temp_file.full_path
    self.assertTrue(os.path.exists(path), 'File {} does not exist'.format(path))
    self.assertTrue(os.path.isfile(path),
                    'Path {} exists, but is not a file'.format(path))

    mode = 'rb' if isinstance(expected_content, bytes) else 'rt'
    with io.open(path, mode) as fp:
      actual = fp.read()
    self.assertEqual(expected_content, actual)

  def run_tempfile_helper(self, cleanup, expected_paths):
    """Runs TempFileHelperTest and checks which temp paths survive.

    Args:
      cleanup: TempFileCleanup member name passed to the helper via the
        ABSLTEST_TEST_HELPER_TEMPFILE_CLEANUP environment variable.
      expected_paths: '/'-separated paths, relative to the helper's
        TEST_TMPDIR, that must remain after the helper exits.
    """
    tmpdir = self.create_tempdir('helper-test-temp-dir')
    env = {
        'ABSLTEST_TEST_HELPER_TEMPFILE_CLEANUP': cleanup,
        'TEST_TMPDIR': tmpdir.full_path,
    }
    # The helper contains deliberately failing tests, so it must exit non-zero.
    stdout, stderr, _ = self.run_helper(
        0, ['TempFileHelperTest'], env, expect_success=False
    )
    output = ('\n=== Helper output ===\n'
              '----- stdout -----\n{}\n'
              '----- end stdout -----\n'
              '----- stderr -----\n{}\n'
              '----- end stderr -----\n'
              '===== end helper output =====').format(stdout, stderr)
    self.assertIn('test_failure', stderr, output)

    # Adjust paths to match on Windows
    expected_paths = {path.replace('/', os.sep) for path in expected_paths}

    actual = {
        os.path.relpath(f, tmpdir.full_path)
        for f in _listdir_recursive(tmpdir.full_path)
        if f != tmpdir.full_path
    }
    self.assertEqual(expected_paths, actual, output)

  def test_create_file_pre_existing_readonly(self):
    first = self.create_tempfile('foo', content='first')
    os.chmod(first.full_path, 0o444)
    # Re-creating the same path must succeed even though it is read-only.
    second = self.create_tempfile('foo', content='second')
    self.assertEqual('second', first.read_text())
    self.assertEqual('second', second.read_text())

  def test_create_file_fails_cleanup(self):
    path = self.create_tempfile().full_path
    # Removing the write bit from the file makes it undeletable on Windows.
    os.chmod(path, 0)
    # Removing the write bit from the whole directory makes all contained files
    # undeletable on unix. We also need it to be exec so that os.path.isfile
    # returns true, and we reach the buggy branch.
    os.chmod(os.path.dirname(path), stat.S_IEXEC)
    # The test should pass, even though that file cannot be deleted in teardown.

  def test_temp_file_path_like(self):
    tempdir = self.create_tempdir('foo')
    tempfile_ = tempdir.create_file('bar')

    self.assertEqual(tempfile_.read_text(), pathlib.Path(tempfile_).read_text())
    # assertIsInstance causes the types to be narrowed, so calling create_file
    # and read_text() must be done before these assertions to avoid type errors.
    self.assertIsInstance(tempdir, os.PathLike)
    self.assertIsInstance(tempfile_, os.PathLike)

  def test_unnamed(self):
    td = self.create_tempdir()
    self.assert_dir_exists(td)

    tdf = td.create_file()
    self.assert_file_exists(tdf)

    tdd = td.mkdir()
    self.assert_dir_exists(tdd)

    tf = self.create_tempfile()
    self.assert_file_exists(tf)

  def test_named(self):
    td = self.create_tempdir('d')
    self.assert_dir_exists(td)

    tdf = td.create_file('df')
    self.assert_file_exists(tdf)

    tdd = td.mkdir('dd')
    self.assert_dir_exists(tdd)

    tf = self.create_tempfile('f')
    self.assert_file_exists(tf)

  def test_nested_paths(self):
    td = self.create_tempdir('d1/d2')
    self.assert_dir_exists(td)

    tdf = td.create_file('df1/df2')
    self.assert_file_exists(tdf)

    tdd = td.mkdir('dd1/dd2')
    self.assert_dir_exists(tdd)

    tf = self.create_tempfile('f1/f2')
    self.assert_file_exists(tf)

  def test_tempdir_create_file(self):
    td = self.create_tempdir()
    tdf = td.create_file(content='text')
    # Previously this test asserted nothing; verify the file and its content.
    self.assert_file_exists(tdf, 'text')
    self.assertEqual('text', tdf.read_text())

  def test_tempfile_text(self):
    tf = self.create_tempfile(content='text')
    self.assert_file_exists(tf, 'text')
    self.assertEqual('text', tf.read_text())

    with tf.open_text() as fp:
      self.assertEqual('text', fp.read())

    with tf.open_text('w') as fp:
      fp.write(u'text-from-open-write')
    self.assertEqual('text-from-open-write', tf.read_text())

    tf.write_text('text-from-write-text')
    self.assertEqual('text-from-write-text', tf.read_text())

  def test_tempfile_bytes(self):
    tf = self.create_tempfile(content=b'\x00\x01\x02')
    self.assert_file_exists(tf, b'\x00\x01\x02')
    self.assertEqual(b'\x00\x01\x02', tf.read_bytes())

    with tf.open_bytes() as fp:
      self.assertEqual(b'\x00\x01\x02', fp.read())

    with tf.open_bytes('wb') as fp:
      fp.write(b'\x03')
    self.assertEqual(b'\x03', tf.read_bytes())

    tf.write_bytes(b'\x04')
    self.assertEqual(b'\x04', tf.read_bytes())

  def test_tempdir_same_name(self):
    """Make sure the same directory name can be used."""
    td1 = self.create_tempdir('foo')
    td2 = self.create_tempdir('foo')
    self.assert_dir_exists(td1)
    self.assert_dir_exists(td2)

  def test_tempfile_cleanup_success(self):
    # SUCCESS: files of passing tests/subtests are removed; failures keep theirs.
    expected = {
        'TempFileHelperTest',
        'TempFileHelperTest/test_failure',
        'TempFileHelperTest/test_failure/failure',
        'TempFileHelperTest/test_success',
        'TempFileHelperTest/test_subtest_failure',
        'TempFileHelperTest/test_subtest_failure/parent',
        'TempFileHelperTest/test_subtest_failure/successful_child',
        'TempFileHelperTest/test_subtest_failure/failed_child',
        'TempFileHelperTest/test_subtest_success',
    }
    self.run_tempfile_helper('SUCCESS', expected)

  def test_tempfile_cleanup_always(self):
    # ALWAYS: every created file is removed; only the directories remain.
    expected = {
        'TempFileHelperTest',
        'TempFileHelperTest/test_failure',
        'TempFileHelperTest/test_success',
        'TempFileHelperTest/test_subtest_failure',
        'TempFileHelperTest/test_subtest_success',
    }
    self.run_tempfile_helper('ALWAYS', expected)

  def test_tempfile_cleanup_off(self):
    # OFF: nothing is removed.
    expected = {
        'TempFileHelperTest',
        'TempFileHelperTest/test_failure',
        'TempFileHelperTest/test_failure/failure',
        'TempFileHelperTest/test_success',
        'TempFileHelperTest/test_success/success',
        'TempFileHelperTest/test_subtest_failure',
        'TempFileHelperTest/test_subtest_failure/parent',
        'TempFileHelperTest/test_subtest_failure/successful_child',
        'TempFileHelperTest/test_subtest_failure/failed_child',
        'TempFileHelperTest/test_subtest_success',
        'TempFileHelperTest/test_subtest_success/parent',
        'TempFileHelperTest/test_subtest_success/child0',
        'TempFileHelperTest/test_subtest_success/child1',
    }
    self.run_tempfile_helper('OFF', expected)
class SkipClassTest(absltest.TestCase):
  """Tests for the absltest.skipThisClass class decorator.

  Suites are built with unittest.TestLoader().loadTestsFromTestCase rather
  than unittest.makeSuite. makeSuite is deprecated since Python 3.11 (it warns
  on 3.12) and removed in 3.13.
  """

  def test_incorrect_decorator_call(self):
    with self.assertRaises(TypeError):

      # Disabling type checking because pytype correctly picks up that
      # @absltest.skipThisClass is being used incorrectly.
      # pytype: disable=wrong-arg-types
      @absltest.skipThisClass
      class Test(absltest.TestCase):  # pylint: disable=unused-variable
        pass
      # pytype: enable=wrong-arg-types

  def test_incorrect_decorator_subclass(self):
    # Decorating a plain function (not a TestCase class) must be rejected.
    with self.assertRaises(TypeError):

      @absltest.skipThisClass('reason')
      def test_method():  # pylint: disable=unused-variable
        pass

  def test_correct_decorator_class(self):

    @absltest.skipThisClass('reason')
    class Test(absltest.TestCase):
      pass

    with self.assertRaises(absltest.SkipTest):
      Test.setUpClass()

  def test_correct_decorator_subclass(self):

    @absltest.skipThisClass('reason')
    class Test(absltest.TestCase):
      pass

    class Subclass(Test):
      pass

    with self.subTest('Base class should be skipped'):
      with self.assertRaises(absltest.SkipTest):
        Test.setUpClass()

    with self.subTest('Subclass should not be skipped'):
      Subclass.setUpClass()  # should not raise.

  def test_setup(self):

    @absltest.skipThisClass('reason')
    class Test(absltest.TestCase):

      @classmethod
      def setUpClass(cls):
        super(Test, cls).setUpClass()
        cls.foo = 1

    class Subclass(Test):
      pass

    Subclass.setUpClass()
    self.assertEqual(Subclass.foo, 1)

  def test_setup_chain(self):

    @absltest.skipThisClass('reason')
    class BaseTest(absltest.TestCase):
      foo: int

      @classmethod
      def setUpClass(cls):
        super(BaseTest, cls).setUpClass()
        cls.foo = 1

    @absltest.skipThisClass('reason')
    class SecondBaseTest(BaseTest):

      @classmethod
      def setUpClass(cls):
        super(SecondBaseTest, cls).setUpClass()
        cls.bar = 2

    class Subclass(SecondBaseTest):
      pass

    Subclass.setUpClass()
    self.assertEqual(Subclass.foo, 1)
    self.assertEqual(Subclass.bar, 2)

  def test_setup_args(self):

    @absltest.skipThisClass('reason')
    class Test(absltest.TestCase):
      foo: str
      bar: Optional[str]

      @classmethod
      def setUpClass(cls, foo, bar=None):
        super(Test, cls).setUpClass()
        cls.foo = foo
        cls.bar = bar

    class Subclass(Test):

      @classmethod
      def setUpClass(cls):
        super(Subclass, cls).setUpClass('foo', bar='baz')

    Subclass.setUpClass()
    self.assertEqual(Subclass.foo, 'foo')
    self.assertEqual(Subclass.bar, 'baz')

  def test_setup_multiple_inheritance(self):
    # Test that skipping this class doesn't break the MRO chain and stop
    # RequiredBase.setUpClass from running.
    @absltest.skipThisClass('reason')
    class Left(absltest.TestCase):
      pass

    class RequiredBase(absltest.TestCase):
      foo: str

      @classmethod
      def setUpClass(cls):
        super(RequiredBase, cls).setUpClass()
        cls.foo = 'foo'

    class Right(RequiredBase):

      @classmethod
      def setUpClass(cls):
        super(Right, cls).setUpClass()

    # Test will fail unless Left.setUpClass() follows mro properly
    # Right.setUpClass()
    class Subclass(Left, Right):

      @classmethod
      def setUpClass(cls):
        super(Subclass, cls).setUpClass()

    class Test(Subclass):
      pass

    Test.setUpClass()
    self.assertEqual(Test.foo, 'foo')

  def test_skip_class(self):

    @absltest.skipThisClass('reason')
    class BaseTest(absltest.TestCase):

      def test_foo(self):
        _ = 1 / 0

    class Test(BaseTest):

      def test_foo(self):
        self.assertEqual(1, 1)

    with self.subTest('base class'):
      ts = unittest.TestLoader().loadTestsFromTestCase(BaseTest)
      self.assertEqual(1, ts.countTestCases())

      res = unittest.TestResult()
      ts.run(res)
      # The whole class is skipped in setUpClass, so the failing test_foo
      # never runs.
      self.assertTrue(res.wasSuccessful())
      self.assertLen(res.skipped, 1)
      self.assertEqual(0, res.testsRun)
      self.assertEmpty(res.failures)
      self.assertEmpty(res.errors)

    with self.subTest('real test'):
      ts = unittest.TestLoader().loadTestsFromTestCase(Test)
      self.assertEqual(1, ts.countTestCases())

      res = unittest.TestResult()
      ts.run(res)
      self.assertTrue(res.wasSuccessful())
      self.assertEqual(1, res.testsRun)
      self.assertEmpty(res.skipped)
      self.assertEmpty(res.failures)
      self.assertEmpty(res.errors)

  def test_skip_class_unittest(self):

    @absltest.skipThisClass('reason')
    class Test(unittest.TestCase):  # note: unittest not absltest

      def test_foo(self):
        _ = 1 / 0

    ts = unittest.TestLoader().loadTestsFromTestCase(Test)
    self.assertEqual(1, ts.countTestCases())

    res = unittest.TestResult()
    ts.run(res)
    self.assertTrue(res.wasSuccessful())
    self.assertLen(res.skipped, 1)
    self.assertEqual(0, res.testsRun)
    self.assertEmpty(res.failures)
    self.assertEmpty(res.errors)
class ExitCodeTest(BaseTestCase):
  """Checks the exit status of absltest.main() for empty and all-skipped runs."""

  def test_exits_5_when_no_tests(self):
    # On Python 3.12+ a run with no tests exits with status 5; older versions
    # exit successfully.
    expect_success = sys.version_info < (3, 12)
    _, _, exit_code = self.run_helper(
        None,
        [],
        {},
        expect_success=expect_success,
        helper_name='absltest_test_helper_skipped',
    )
    if not expect_success:
      self.assertEqual(exit_code, 5)

  def test_exits_0_when_all_skipped(self):
    # Renamed from test_exits_5_when_all_skipped: this test asserts a
    # *successful* run (expect_success=True). Unlike an empty suite, a suite
    # whose only test is skipped is not reported as "no tests ran".
    self.run_helper(
        None,
        [],
        {'ABSLTEST_TEST_HELPER_DEFINE_CLASS': '1'},
        expect_success=True,
        helper_name='absltest_test_helper_skipped',
    )
def _listdir_recursive(path):
for dirname, _, filenames in os.walk(path):
yield dirname
for filename in filenames:
yield os.path.join(dirname, filename)
def _env_for_command_tests():
if os.name == 'nt' and 'PATH' in os.environ:
# get_command_stderr and assertCommandXXX don't inherit environment
# variables by default. This makes sure msys commands can be found on
# Windows.
return {'PATH': os.environ['PATH']}
else:
return None
# Run the tests in this module when executed as a script.
if __name__ == '__main__':
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_test_helper.py 0000664 0000000 0000000 00000011327 14551576331 0023746 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper binary for absltest_test.py."""
import os
import tempfile
import unittest
from absl import app
from absl import flags
from absl.testing import absltest
# Global flag registry, used to read flags such as test_random_seed.
FLAGS = flags.FLAGS
# HelperTest methods branch on this id; several skip themselves when it does
# not match the case they handle.
_TEST_ID = flags.DEFINE_integer('test_id', 0, 'Which test to run.')
# Names printed by main() and by HelperTest.test_name_flag.
_NAME = flags.DEFINE_multi_string('name', [], 'List of names to print.')
@flags.validator('name')
def validate_name(value):
  """Rejects more than two --name values.

  This validator makes sure that the second FLAGS(sys.argv) inside
  absltest.main() won't actually trigger side effects of the flag parsing.
  """
  count = len(value)
  if count <= 2:
    return True
  raise flags.ValidationError(
      f'No more than two names should be specified, found {count} names')
class HelperTest(absltest.TestCase):
  """Tests driven by absltest_test.py through --test_id and the environment."""

  def test_flags(self):
    test_id = _TEST_ID.value
    if test_id == 1:
      self.assertEqual(FLAGS.test_random_seed, 301)
      if os.name == 'nt':
        # On Windows, it's always in the temp dir, which doesn't start with '/'.
        expected_prefix = tempfile.gettempdir()
      else:
        expected_prefix = '/'
      tmpdir = absltest.TEST_TMPDIR.value
      self.assertTrue(
          tmpdir.startswith(expected_prefix),
          '--test_tmpdir={} does not start with {}'.format(
              tmpdir, expected_prefix))
      self.assertTrue(os.access(tmpdir, os.W_OK))
      return

    # Ids 2-4 differ only in the expected random seed; the src/tmp dirs must
    # match the values the caller put in the environment.
    expected_seed_by_id = {2: 321, 3: 123, 4: 221}
    if test_id not in expected_seed_by_id:
      raise unittest.SkipTest(
          'Not asked to run: --test_id={}'.format(test_id))
    self.assertEqual(FLAGS.test_random_seed, expected_seed_by_id[test_id])
    self.assertEqual(
        absltest.TEST_SRCDIR.value,
        os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_SRCDIR'])
    self.assertEqual(
        absltest.TEST_TMPDIR.value,
        os.environ['ABSLTEST_TEST_HELPER_EXPECTED_TEST_TMPDIR'])

  @unittest.expectedFailure
  def test_expected_failure(self):
    # With --test_id=5 the assertion passes, i.e. an unexpected success;
    # otherwise it fails as expected.
    observed = 1 if _TEST_ID.value == 5 else 2
    self.assertEqual(1, observed)

  def test_xml_env_vars(self):
    if _TEST_ID.value != 6:
      raise unittest.SkipTest(
          'Not asked to run: --test_id={}'.format(_TEST_ID.value))
    self.assertEqual(
        FLAGS.xml_output_file,
        os.environ['ABSLTEST_TEST_HELPER_EXPECTED_XML_OUTPUT_FILE'])

  def test_name_flag(self):
    if _TEST_ID.value != 7:
      raise unittest.SkipTest(
          'Not asked to run: --test_id={}'.format(_TEST_ID.value))
    print('Names in test_name_flag() are:', ' '.join(_NAME.value))
class TempFileHelperTest(absltest.TestCase):
  """Helper test case for tempfile cleanup tests.

  The cleanup policy is the absltest.TempFileCleanup member named by the
  ABSLTEST_TEST_HELPER_TEMPFILE_CLEANUP environment variable ('SUCCESS' when
  unset). Each test creates named temp files so the caller can check which
  ones survive cleanup.
  """

  tempfile_cleanup = absltest.TempFileCleanup[
      os.environ.get('ABSLTEST_TEST_HELPER_TEMPFILE_CLEANUP', 'SUCCESS')
  ]

  def test_failure(self):
    self.create_tempfile('failure')
    self.fail('expected failure')

  def test_success(self):
    self.create_tempfile('success')

  def test_subtest_failure(self):
    self.create_tempfile('parent')
    with self.subTest('success'):
      self.create_tempfile('successful_child')
    with self.subTest('failure'):
      self.create_tempfile('failed_child')
      self.fail('expected failure')

  def test_subtest_success(self):
    self.create_tempfile('parent')
    # Two passing subtests, each creating its own child file.
    for child_index in (0, 1):
      with self.subTest(f'success{child_index}'):
        self.create_tempfile(f'child{child_index}')
def main(argv):
  """app.run() entry point: prints the --name values, then runs the tests."""
  del argv  # Unused.
  joined_names = ' '.join(_NAME.value)
  print('Names in main() are:', joined_names)
  absltest.main()
if __name__ == '__main__':
  # When ABSLTEST_TEST_HELPER_USE_APP_RUN is set to a non-empty value, start
  # through app.run(main), which calls absltest.main() from inside main().
  # Otherwise call absltest.main() directly.
  if os.environ.get('ABSLTEST_TEST_HELPER_USE_APP_RUN'):
    app.run(main)
  else:
    absltest.main()
abseil-py-2.1.0/absl/testing/tests/absltest_test_helper_skipped.py 0000664 0000000 0000000 00000000536 14551576331 0025465 0 ustar 00root root 0000000 0000000 """Test helper for ExitCodeTest in absltest_test.py."""
import os
from absl.testing import absltest
# Define the test class only on request. Without it this module has no tests
# at all; with it, its only test is skipped. ExitCodeTest checks the exit code
# in each case.
if os.environ.get("ABSLTEST_TEST_HELPER_DEFINE_CLASS") == "1":

  class MyTest(absltest.TestCase):
    """Holds a single, always-skipped test."""

    @absltest.skip("Skipped for testing the exit code behavior")
    def test_foo(self):
      pass
# Run whatever tests this module defined (possibly none).
if __name__ == "__main__":
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/flagsaver_test.py 0000664 0000000 0000000 00000055646 14551576331 0022554 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for flagsaver."""
from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
# Plain string flags the tests override and then expect to be restored.
flags.DEFINE_string('flagsaver_test_flag0', 'unchanged0', 'flag to test with')
flags.DEFINE_string('flagsaver_test_flag1', 'unchanged1', 'flag to test with')
# Only falsy values satisfy this flag's validator, so any non-empty override
# fails validation.
flags.DEFINE_string('flagsaver_test_validated_flag', None, 'flag to test with')
flags.register_validator('flagsaver_test_validated_flag', lambda x: not x)
# Cross-validated pair; validate_test_flags requires both values to be equal.
flags.DEFINE_string('flagsaver_test_validated_flag1', None, 'flag to test with')
flags.DEFINE_string('flagsaver_test_validated_flag2', None, 'flag to test with')
# Flags that tests reach through their FlagHolder objects.
INT_FLAG = flags.DEFINE_integer(
    'flagsaver_test_int_flag', default=1, help='help')
STR_FLAG = flags.DEFINE_string(
    'flagsaver_test_str_flag', default='str default', help='help')
MULTI_INT_FLAG = flags.DEFINE_multi_integer('flagsaver_test_multi_int_flag',
                                            None, 'flag to test with')
@flags.multi_flags_validator(
    ('flagsaver_test_validated_flag1', 'flagsaver_test_validated_flag2'))
def validate_test_flags(flag_dict):
  """Passes only when both cross-validated flags hold the same value."""
  first = flag_dict['flagsaver_test_validated_flag1']
  second = flag_dict['flagsaver_test_validated_flag2']
  return first == second
FLAGS = flags.FLAGS


@flags.validator('flagsaver_test_flag0')
def check_no_upper_case(value):
  """Accepts flagsaver_test_flag0 values that str.lower() leaves unchanged."""
  return value.lower() == value
class _TestError(Exception):
"""Exception class for use in these tests."""
class CommonUsageTest(absltest.TestCase):
  """These test cases cover the most common usages of flagsaver."""

  def test_as_parsed_context_manager(self):
    # This test deliberately modifies flagsaver_test_flag1 outside any
    # flagsaver scope; restore every flag afterwards so the change does not
    # leak into other tests or repeated runs.
    self.addCleanup(flagsaver.restore_flag_values,
                    flagsaver.save_flag_values())

    # Precondition check, we expect all the flags to start as their default.
    self.assertEqual('str default', STR_FLAG.value)
    self.assertFalse(STR_FLAG.present)
    self.assertEqual(1, INT_FLAG.value)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('unchanged1', FLAGS.flagsaver_test_flag1)

    # Flagsaver will also save the state of flags that have been modified.
    FLAGS.flagsaver_test_flag1 = 'outside flagsaver'

    # Save all existing flag state, and set some flags as if they were parsed on
    # the command line. Because of this, the new values must be provided as str,
    # even if the flag type is something other than string.
    with flagsaver.as_parsed(
        (STR_FLAG, 'new string value'),  # Override using flagholder object.
        (INT_FLAG, '123'),  # Override an int flag (NOTE: must specify as str).
        flagsaver_test_flag0='new value',  # Override using flag name.
    ):
      # All the flags have their overridden values.
      self.assertEqual('new string value', STR_FLAG.value)
      self.assertTrue(STR_FLAG.present)
      self.assertEqual(123, INT_FLAG.value)
      self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
      # Even if we change other flags, they will reset on context exit.
      FLAGS.flagsaver_test_flag1 = 'new value 1'

    # The flags have all reset to their pre-flagsaver values.
    self.assertEqual('str default', STR_FLAG.value)
    self.assertFalse(STR_FLAG.present)
    self.assertEqual(1, INT_FLAG.value)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertEqual('outside flagsaver', FLAGS.flagsaver_test_flag1)

  def test_as_parsed_decorator(self):
    # flagsaver.as_parsed can also be used as a decorator.
    @flagsaver.as_parsed((INT_FLAG, '123'))
    def do_something_with_flags():
      self.assertEqual(123, INT_FLAG.value)
      self.assertTrue(INT_FLAG.present)

    do_something_with_flags()
    self.assertEqual(1, INT_FLAG.value)
    self.assertFalse(INT_FLAG.present)

  def test_flagsaver_flagsaver(self):
    # If you don't want the flags to go through parsing, you can instead use
    # flagsaver.flagsaver(). With this method, you provide the native python
    # value you'd like the flags to take on. Otherwise it functions similar to
    # flagsaver.as_parsed().
    @flagsaver.flagsaver((INT_FLAG, 345))
    def do_something_with_flags():
      self.assertEqual(345, INT_FLAG.value)
      # Note that because this flag was never parsed, it will not register as
      # .present unless you manually set that attribute.
      self.assertFalse(INT_FLAG.present)
      # If you do choose to modify things about the flag (such as .present)
      # those changes will still be cleaned up when flagsaver.flagsaver()
      # exits. .present is set on the underlying Flag object, reached via
      # FLAGS[name]; the FlagHolder's .present is only read here.
      FLAGS[INT_FLAG.name].present = True
      self.assertTrue(INT_FLAG.present)

    # The decorated function must actually run; previously it was never
    # called, so none of the assertions above executed.
    do_something_with_flags()

    self.assertEqual(1, INT_FLAG.value)
    # flagsaver.flagsaver() restored INT_FLAG.present to the state it was in
    # before entering the context.
    self.assertFalse(INT_FLAG.present)
class SaveFlagValuesTest(absltest.TestCase):
  """Test flagsaver.save_flag_values() and flagsaver.restore_flag_values().

  In this test, we ensure that *all* properties of flags get restored. In other
  tests we only try changing the flag value.
  """

  def test_assign_value(self):
    # First save the flag values.
    saved_flag_values = flagsaver.save_flag_values()

    # Now mutate the flag's value field and check that it changed.
    FLAGS.flagsaver_test_flag0 = 'new value'
    self.assertEqual('new value', FLAGS.flagsaver_test_flag0)

    # Now restore the flag to its original value.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_set_default(self):
    # First save the flag.
    saved_flag_values = flagsaver.save_flag_values()

    # Now mutate the flag's default field and check that it changed.
    FLAGS.set_default('flagsaver_test_flag0', 'new_default')
    self.assertEqual('new_default', FLAGS['flagsaver_test_flag0'].default)

    # Now restore the flag's default field.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].default)

  def test_parse(self):
    # First save the flag.
    saved_flag_values = flagsaver.save_flag_values()

    # Sanity check (would fail if called with --flagsaver_test_flag0).
    self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)
    # Now populate the flag and check that it changed.
    FLAGS['flagsaver_test_flag0'].parse('new value')
    self.assertEqual('new value', FLAGS['flagsaver_test_flag0'].value)
    self.assertEqual(1, FLAGS['flagsaver_test_flag0'].present)

    # Now restore the flag to its original value and un-parsed state.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual('unchanged0', FLAGS['flagsaver_test_flag0'].value)
    self.assertEqual(0, FLAGS['flagsaver_test_flag0'].present)

  def test_assign_validators(self):
    # First save the flag.
    saved_flag_values = flagsaver.save_flag_values()

    # Sanity check that a validator already exists (check_no_upper_case,
    # registered at module level).
    self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 1)
    original_validators = list(FLAGS['flagsaver_test_flag0'].validators)

    def no_space(value):
      return ' ' not in value

    # Add a new validator.
    flags.register_validator('flagsaver_test_flag0', no_space)
    self.assertLen(FLAGS['flagsaver_test_flag0'].validators, 2)

    # Now restore the flag; its validator list should match the saved one.
    flagsaver.restore_flag_values(saved_flag_values)
    self.assertEqual(
        original_validators, FLAGS['flagsaver_test_flag0'].validators
    )
# Run every test once with flagsaver.flagsaver and once with flagsaver.as_parsed.
@parameterized.named_parameters(
    dict(
        testcase_name='flagsaver.flagsaver',
        flagsaver_method=flagsaver.flagsaver,
    ),
    dict(
        testcase_name='flagsaver.as_parsed',
        flagsaver_method=flagsaver.as_parsed,
    ),
)
class NoOverridesTest(parameterized.TestCase):
  """Test flagsaver.flagsaver and flagsaver.as_parsed without overrides.

  Each test mutates flagsaver_test_flag0 inside a no-argument flagsaver scope
  and checks that the mutation is undone when the scope exits.
  """

  def test_context_manager_with_call(self, flagsaver_method):
    with flagsaver_method():
      FLAGS.flagsaver_test_flag0 = 'new value'
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_context_manager_with_exception(self, flagsaver_method):
    # The flag must be restored even when the managed block raises.
    with self.assertRaises(_TestError):
      with flagsaver_method():
        FLAGS.flagsaver_test_flag0 = 'new value'
        # Simulate a failed test.
        raise _TestError('something happened')
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_without_call(self, flagsaver_method):
    # Bare decorator form: @flagsaver_method with no parentheses.
    @flagsaver_method
    def mutate_flags():
      FLAGS.flagsaver_test_flag0 = 'new value'

    mutate_flags()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_call(self, flagsaver_method):
    # Called decorator form: @flagsaver_method().
    @flagsaver_method()
    def mutate_flags():
      FLAGS.flagsaver_test_flag0 = 'new value'

    mutate_flags()
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  def test_decorator_with_exception(self, flagsaver_method):
    @flagsaver_method()
    def raise_exception():
      FLAGS.flagsaver_test_flag0 = 'new value'
      # Simulate a failed test.
      raise _TestError('something happened')

    # Decorating alone must not change the flag; calling it restores on raise.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertRaises(_TestError, raise_exception)
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
@parameterized.named_parameters(
dict(
testcase_name='flagsaver.flagsaver',
flagsaver_method=flagsaver.flagsaver,
),
dict(
testcase_name='flagsaver.as_parsed',
flagsaver_method=flagsaver.as_parsed,
),
)
class TestStringFlagOverrides(parameterized.TestCase):
"""Test flagsaver.flagsaver and flagsaver.as_parsed with string overrides.
Note that these tests can be parameterized because both .flagsaver and
.as_parsed expect a str input when overriding a string flag. For non-string
flags these two flagsaver methods have separate tests elsewhere in this file.
Each test is one class of overrides, executed twice. Once as a context
manager, and once as a decorator on a mutate_flags() method.
"""
def test_keyword_overrides(self, flagsaver_method):
# Context manager:
with flagsaver_method(flagsaver_test_flag0='new value'):
self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
# Decorator:
@flagsaver_method(flagsaver_test_flag0='new value')
def mutate_flags():
self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
mutate_flags()
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
def test_flagholder_overrides(self, flagsaver_method):
with flagsaver_method((STR_FLAG, 'new value')):
self.assertEqual('new value', STR_FLAG.value)
self.assertEqual('str default', STR_FLAG.value)
@flagsaver_method((STR_FLAG, 'new value'))
def mutate_flags():
self.assertEqual('new value', STR_FLAG.value)
mutate_flags()
self.assertEqual('str default', STR_FLAG.value)
def test_keyword_and_flagholder_overrides(self, flagsaver_method):
with flagsaver_method(
(STR_FLAG, 'another value'), flagsaver_test_flag0='new value'
):
self.assertEqual('another value', STR_FLAG.value)
self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
self.assertEqual('str default', STR_FLAG.value)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
@flagsaver_method(
(STR_FLAG, 'another value'), flagsaver_test_flag0='new value'
)
def mutate_flags():
self.assertEqual('another value', STR_FLAG.value)
self.assertEqual('new value', FLAGS.flagsaver_test_flag0)
mutate_flags()
self.assertEqual('str default', STR_FLAG.value)
self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
def test_cross_validated_overrides_set_together(self, flagsaver_method):
# When the flags are set in the same flagsaver call their validators will
# be triggered only once the setting is done.
with flagsaver_method(
flagsaver_test_validated_flag1='new_value',
flagsaver_test_validated_flag2='new_value',
):
self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
@flagsaver_method(
flagsaver_test_validated_flag1='new_value',
flagsaver_test_validated_flag2='new_value',
)
def mutate_flags():
self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag1)
self.assertEqual('new_value', FLAGS.flagsaver_test_validated_flag2)
mutate_flags()
self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
def test_cross_validated_overrides_set_badly(self, flagsaver_method):
# Different values should violate the validator.
with self.assertRaisesRegex(
flags.IllegalFlagValueError, 'Flag validation failed'
):
with flagsaver_method(
flagsaver_test_validated_flag1='new_value',
flagsaver_test_validated_flag2='other_value',
):
pass
self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
@flagsaver_method(
flagsaver_test_validated_flag1='new_value',
flagsaver_test_validated_flag2='other_value',
)
def mutate_flags():
pass
self.assertRaisesRegex(
flags.IllegalFlagValueError, 'Flag validation failed', mutate_flags
)
self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
def test_cross_validated_overrides_set_separately(self, flagsaver_method):
# Setting just one flag will trip the validator as well.
with self.assertRaisesRegex(
flags.IllegalFlagValueError, 'Flag validation failed'
):
with flagsaver_method(flagsaver_test_validated_flag1='new_value'):
pass
self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
@flagsaver_method(flagsaver_test_validated_flag1='new_value')
def mutate_flags():
pass
self.assertRaisesRegex(
flags.IllegalFlagValueError, 'Flag validation failed', mutate_flags
)
self.assertIsNone(FLAGS.flagsaver_test_validated_flag1)
self.assertIsNone(FLAGS.flagsaver_test_validated_flag2)
def test_validation_exception(self, flagsaver_method):
  # The validated flag rejects 'new value'. The unrelated override of
  # flagsaver_test_flag0 passed in the same call must not survive.
  overrides = dict(
      flagsaver_test_flag0='new value',
      flagsaver_test_validated_flag='new value',
  )

  def check_untouched():
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    self.assertIsNone(FLAGS.flagsaver_test_validated_flag)

  with self.assertRaises(flags.IllegalFlagValueError):
    with flagsaver_method(**overrides):
      pass
  check_untouched()

  @flagsaver_method(**overrides)
  def mutate_flags():
    pass

  with self.assertRaises(flags.IllegalFlagValueError):
    mutate_flags()
  check_untouched()
def test_unknown_flag_raises_exception(self, flagsaver_method):
  missing_name = 'this_flag_does_not_exist'
  self.assertNotIn(missing_name, FLAGS)
  # Overriding a flag that was never defined is an error, and the valid
  # override passed alongside it must not leak out.
  overrides = {'flagsaver_test_flag0': 'new value', missing_name: 'new value'}

  with self.assertRaises(flags.UnrecognizedFlagError):
    with flagsaver_method(**overrides):
      pass
  self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  @flagsaver_method(**overrides)
  def mutate_flags():
    pass

  with self.assertRaises(flags.UnrecognizedFlagError):
    mutate_flags()
  self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)

  # Neither form may have registered the unknown flag as a side effect.
  self.assertNotIn(missing_name, FLAGS)
class AsParsedTest(absltest.TestCase):

  def _check_flags_unparsed(self):
    # Both flags report "not present" and "using default". Note that
    # .using_default_value isn't available on the FlagHolder directly, so it
    # is read from the Flag object via FLAGS[name].
    self.assertFalse(INT_FLAG.present)
    self.assertFalse(STR_FLAG.present)
    self.assertTrue(FLAGS[INT_FLAG.name].using_default_value)
    self.assertTrue(FLAGS[STR_FLAG.name].using_default_value)

  def _check_flags_parsed(self):
    # Inverse of _check_flags_unparsed: both flags look command-line parsed.
    self.assertTrue(INT_FLAG.present)
    self.assertTrue(STR_FLAG.present)
    self.assertFalse(FLAGS[INT_FLAG.name].using_default_value)
    self.assertFalse(FLAGS[STR_FLAG.name].using_default_value)

  def test_parse_context_manager_sets_present_and_using_default(self):
    self._check_flags_unparsed()
    with flagsaver.as_parsed((INT_FLAG, '123'),
                             flagsaver_test_str_flag='new value'):
      self._check_flags_parsed()
    self._check_flags_unparsed()

  def test_parse_decorator_sets_present_and_using_default(self):
    self._check_flags_unparsed()

    @flagsaver.as_parsed((INT_FLAG, '123'), flagsaver_test_str_flag='new value')
    def some_func():
      self._check_flags_parsed()

    some_func()
    self._check_flags_unparsed()

  def test_parse_decorator_with_multi_int_flag(self):
    def check_multi_int_unset():
      self.assertFalse(MULTI_INT_FLAG.present)
      self.assertIsNone(MULTI_INT_FLAG.value)

    check_multi_int_unset()

    @flagsaver.as_parsed((MULTI_INT_FLAG, ['123', '456']))
    def assert_flags_updated():
      self.assertTrue(MULTI_INT_FLAG.present)
      # Every string is parsed to an int; element order is not asserted.
      self.assertCountEqual([123, 456], MULTI_INT_FLAG.value)

    assert_flags_updated()
    check_multi_int_unset()

  def test_parse_raises_type_error(self):
    expected_message = (
        r'flagsaver\.as_parsed\(\) cannot parse flagsaver_test_int_flag\. '
        r'Expected a single string or sequence of strings but .*int.* was '
        r'provided\.'
    )
    # as_parsed() takes string values only; a raw int is rejected as soon as
    # the saver is built, before it is ever entered.
    with self.assertRaisesRegex(TypeError, expected_message):
      flagsaver.as_parsed(flagsaver_test_int_flag=123)  # pytype: disable=wrong-arg-types
class SetUpTearDownTest(absltest.TestCase):
  """Example of snapshotting flags in setUp and restoring them in tearDown."""

  def setUp(self):
    super().setUp()
    # One snapshot of every flag per test; tearDown puts it back.
    self.saved_flag_values = flagsaver.save_flag_values()

  def tearDown(self):
    super().tearDown()
    flagsaver.restore_flag_values(self.saved_flag_values)

  def _expect_pristine_then_mutate(self):
    # Whichever test runs second only sees 'unchanged0' if tearDown restored
    # the flag after the first test changed it.
    self.assertEqual('unchanged0', FLAGS.flagsaver_test_flag0)
    FLAGS.flagsaver_test_flag0 = 'changed0'

  def test_mutate1(self):
    self._expect_pristine_then_mutate()

  def test_mutate2(self):
    self._expect_pristine_then_mutate()
@parameterized.named_parameters(
    dict(
        testcase_name='flagsaver.flagsaver',
        flagsaver_method=flagsaver.flagsaver,
    ),
    dict(
        testcase_name='flagsaver.as_parsed',
        flagsaver_method=flagsaver.as_parsed,
    ),
)
class BadUsageTest(parameterized.TestCase):
  """Tests that improper usage (such as decorating a class) raises errors.

  Every test runs twice: once with flagsaver.flagsaver and once with
  flagsaver.as_parsed passed in as `flagsaver_method`.
  """

  def test_flag_saver_on_class(self, flagsaver_method):
    with self.assertRaises(TypeError):
      # WRONG. Don't do this.
      # Consider the correct setUp/tearDown usage example in
      # SetUpTearDownTest.
      @flagsaver_method
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_flag_saver_call_on_class(self, flagsaver_method):
    with self.assertRaises(TypeError):
      # WRONG. Don't do this.
      # Consider the correct setUp/tearDown usage example in
      # SetUpTearDownTest.
      @flagsaver_method()
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_flag_saver_with_overrides_on_class(self, flagsaver_method):
    with self.assertRaises(TypeError):
      # WRONG. Don't do this.
      # Consider the correct setUp/tearDown usage example in
      # SetUpTearDownTest.
      @flagsaver_method(foo='bar')
      class FooTest(absltest.TestCase):

        def test_tautology(self):
          pass

      del FooTest

  def test_multiple_positional_parameters(self, flagsaver_method):
    # Fixtures are built outside the assertion so only the call under test
    # can satisfy assertRaises.
    func_a = lambda: None
    func_b = lambda: None
    with self.assertRaises(ValueError):
      flagsaver_method(func_a, func_b)

  def test_both_positional_and_keyword_parameters(self, flagsaver_method):
    func_a = lambda: None
    with self.assertRaises(ValueError):
      flagsaver_method(func_a, flagsaver_test_flag0='new value')

  def test_duplicate_holder_parameters(self, flagsaver_method):
    with self.assertRaises(ValueError):
      flagsaver_method((INT_FLAG, 45), (INT_FLAG, 45))

  def test_duplicate_holder_and_kw_parameter(self, flagsaver_method):
    with self.assertRaises(ValueError):
      flagsaver_method((INT_FLAG, 45), **{INT_FLAG.name: 45})

  def test_both_positional_and_holder_parameters(self, flagsaver_method):
    func_a = lambda: None
    with self.assertRaises(ValueError):
      flagsaver_method(func_a, (INT_FLAG, 45))

  def test_holder_parameters_wrong_shape(self, flagsaver_method):
    # A bare FlagHolder is rejected; it has to be wrapped in a
    # (holder, value) tuple.
    with self.assertRaises(ValueError):
      flagsaver_method(INT_FLAG)

  def test_holder_parameters_tuple_too_long(self, flagsaver_method):
    # A holder override must be exactly a (holder, value) pair; extra
    # elements are rejected.
    with self.assertRaises(ValueError):
      flagsaver_method((INT_FLAG, 4, 5))

  def test_holder_parameters_tuple_wrong_type(self, flagsaver_method):
    # The FlagHolder must be the first element of the pair.
    with self.assertRaises(ValueError):
      flagsaver_method((4, INT_FLAG))

  def test_both_wrong_positional_parameters(self, flagsaver_method):
    func_a = lambda: None
    with self.assertRaises(ValueError):
      flagsaver_method(func_a, STR_FLAG, '45')

  def test_context_manager_no_call(self, flagsaver_method):
    # The exception type depends on the Python version: using a non context
    # manager in a `with` statement raises AttributeError before Python 3.11
    # and TypeError from 3.11 on.
    with self.assertRaises((AttributeError, TypeError)):
      # Wrong. You must call the flagsaver method before using it as a CM.
      with flagsaver_method:
        # We don't expect to get here. The error is raised when attempting
        # to enter the context manager.
        pass
if __name__ == '__main__':
  # Entry point when this test module is executed directly.
  absltest.main()
abseil-py-2.1.0/absl/testing/tests/parameterized_test.py 0000664 0000000 0000000 00000105353 14551576331 0023425 0 ustar 00root root 0000000 0000000 # Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for absl.testing.parameterized."""
from collections import abc
import os
import sys
import unittest
from absl.testing import absltest
from absl.testing import parameterized
class MyOwnClass:
  """Empty user-defined class; its default repr embeds the object's address."""
def dummy_decorator(method):
  """Hides `method` behind a plain forwarding function.

  No metadata is copied from `method` (there is no functools.wraps), so the
  returned object is an ordinary function named 'decorated' whatever it
  wraps. Presumably the OtherDecorator* fixtures rely on this to hide a
  parameterized iterator; confirm before changing.
  """

  def decorated(*positional, **keyword):
    # Forward every argument untouched and pass the result back.
    outcome = method(*positional, **keyword)
    return outcome

  return decorated
def dict_decorator(key, value):
  """Sample implementation of a chained decorator.

  Sets a single field in a dict on a test with a dict parameter. Uses the
  exposed '_ParameterizedTestIter.testcases' field to modify arguments from
  previous decorators, which allows these decorators to be chained.

  Args:
    key: key to map to
    value: value to set

  Returns:
    The test decorator
  """

  def decorator(test_method):
    if not isinstance(test_method, abc.Iterable):
      # Innermost decorator: 'test_method' is the original test method, so
      # turn it into a single named case carrying {key: value}.
      suffix = '_%s_%s' % (key, value)
      return parameterized.named_parameters((suffix, {key: value}))(
          test_method
      )

    # Already parameterized by an earlier dict_decorator: rebuild each
    # ('test_suffix', dict) case with this key appended, copying the dict so
    # the previous case's mapping is left unmodified.
    rebuilt = []
    for case in test_method.testcases:
      params = case[1].copy()
      params[key] = value
      rebuilt.append(('%s_%s_%s' % (case[0], key, value), params))
    test_method.testcases = rebuilt
    return test_method

  return decorator
class ParameterizedTestsTest(absltest.TestCase):
# The test testcases are nested so they're not
# picked up by the normal test case loader code.
class GoodAdditionParams(parameterized.TestCase):
@parameterized.parameters((1, 2, 3), (4, 5, 9))
def test_addition(self, op1, op2, result):
self.arguments = (op1, op2, result)
self.assertEqual(result, op1 + op2)
# This class does not inherit from TestCase.
class BadAdditionParams(absltest.TestCase):
@parameterized.parameters((1, 2, 3), (4, 5, 9))
def test_addition(self, op1, op2, result):
pass # Always passes, but not called w/out TestCase.
class MixedAdditionParams(parameterized.TestCase):
@parameterized.parameters((1, 2, 1), (4, 5, 9))
def test_addition(self, op1, op2, result):
self.arguments = (op1, op2, result)
self.assertEqual(result, op1 + op2)
class DictionaryArguments(parameterized.TestCase):
@parameterized.parameters(
{'op1': 1, 'op2': 2, 'result': 3}, {'op1': 4, 'op2': 5, 'result': 9}
)
def test_addition(self, op1, op2, result):
self.assertEqual(result, op1 + op2)
class NoParameterizedTests(parameterized.TestCase):
# iterable member with non-matching name
a = 'BCD'
# member with matching name, but not a generator
testInstanceMember = None # pylint: disable=invalid-name
test_instance_member = None
# member with a matching name and iterator, but not a generator
testString = 'foo' # pylint: disable=invalid-name
test_string = 'foo'
# generator, but no matching name
def someGenerator(self): # pylint: disable=invalid-name
yield
yield
yield
def some_generator(self):
yield
yield
yield
# Generator function, but not a generator instance.
def testGenerator(self):
yield
yield
yield
def test_generator(self):
yield
yield
yield
def testNormal(self):
self.assertEqual(3, 1 + 2)
def test_normal(self):
self.assertEqual(3, 1 + 2)
class ArgumentsWithAddresses(parameterized.TestCase):
@parameterized.parameters(
(object(),),
(MyOwnClass(),),
)
def test_something(self, case):
pass
class CamelCaseNamedTests(parameterized.TestCase):
@parameterized.named_parameters(
('Interesting', 0),
)
def testSingle(self, case):
pass
@parameterized.named_parameters(
{'testcase_name': 'Interesting', 'case': 0},
)
def testDictSingle(self, case):
pass
@parameterized.named_parameters(
('Interesting', 0),
('Boring', 1),
)
def testSomething(self, case):
pass
@parameterized.named_parameters(
{'testcase_name': 'Interesting', 'case': 0},
{'testcase_name': 'Boring', 'case': 1},
)
def testDictSomething(self, case):
pass
@parameterized.named_parameters(
{'testcase_name': 'Interesting', 'case': 0},
('Boring', 1),
)
def testMixedSomething(self, case):
pass
def testWithoutParameters(self):
pass
class NamedTests(parameterized.TestCase):
"""Example tests using PEP-8 style names instead of camel-case."""
@parameterized.named_parameters(
('interesting', 0),
)
def test_single(self, case):
pass
@parameterized.named_parameters(
{'testcase_name': 'interesting', 'case': 0},
)
def test_dict_single(self, case):
pass
@parameterized.named_parameters(
('interesting', 0),
('boring', 1),
)
def test_something(self, case):
pass
@parameterized.named_parameters(
{'testcase_name': 'interesting', 'case': 0},
{'testcase_name': 'boring', 'case': 1},
)
def test_dict_something(self, case):
pass
@parameterized.named_parameters(
{'testcase_name': 'interesting', 'case': 0},
('boring', 1),
)
def test_mixed_something(self, case):
pass
def test_without_parameters(self):
pass
class ChainedTests(parameterized.TestCase):
@dict_decorator('cone', 'waffle')
@dict_decorator('flavor', 'strawberry')
def test_chained(self, dictionary):
self.assertDictEqual(
dictionary, {'cone': 'waffle', 'flavor': 'strawberry'}
)
class SingletonListExtraction(parameterized.TestCase):
@parameterized.parameters((i, i * 2) for i in range(10))
def test_something(self, unused_1, unused_2):
pass
class SingletonArgumentExtraction(parameterized.TestCase):
@parameterized.parameters(1, 2, 3, 4, 5, 6)
def test_numbers(self, unused_1):
pass
@parameterized.parameters('foo', 'bar', 'baz')
def test_strings(self, unused_1):
pass
class SingletonDictArgument(parameterized.TestCase):
@parameterized.parameters({'op1': 1, 'op2': 2})
def test_something(self, op1, op2):
del op1, op2
@parameterized.parameters((1, 2, 3), (4, 5, 9))
class DecoratedClass(parameterized.TestCase):
def test_add(self, arg1, arg2, arg3):
self.assertEqual(arg1 + arg2, arg3)
def test_subtract_fail(self, arg1, arg2, arg3):
self.assertEqual(arg3 + arg2, arg1)
@parameterized.parameters(
(a, b, a + b) for a in range(1, 5) for b in range(1, 5)
)
class GeneratorDecoratedClass(parameterized.TestCase):
def test_add(self, arg1, arg2, arg3):
self.assertEqual(arg1 + arg2, arg3)
def test_subtract_fail(self, arg1, arg2, arg3):
self.assertEqual(arg3 + arg2, arg1)
@parameterized.parameters(
(1, 2, 3),
(4, 5, 9),
)
class DecoratedBareClass(absltest.TestCase):
def test_add(self, arg1, arg2, arg3):
self.assertEqual(arg1 + arg2, arg3)
class OtherDecoratorUnnamed(parameterized.TestCase):
@dummy_decorator
@parameterized.parameters((1), (2))
def test_other_then_parameterized(self, arg1):
pass
@parameterized.parameters((1), (2))
@dummy_decorator
def test_parameterized_then_other(self, arg1):
pass
class OtherDecoratorNamed(parameterized.TestCase):
@dummy_decorator
@parameterized.named_parameters(('a', 1), ('b', 2))
def test_other_then_parameterized(self, arg1):
pass
@parameterized.named_parameters(('a', 1), ('b', 2))
@dummy_decorator
def test_parameterized_then_other(self, arg1):
pass
class OtherDecoratorNamedWithDict(parameterized.TestCase):
@dummy_decorator
@parameterized.named_parameters(
{'testcase_name': 'a', 'arg1': 1}, {'testcase_name': 'b', 'arg1': 2}
)
def test_other_then_parameterized(self, arg1):
pass
@parameterized.named_parameters(
{'testcase_name': 'a', 'arg1': 1}, {'testcase_name': 'b', 'arg1': 2}
)
@dummy_decorator
def test_parameterized_then_other(self, arg1):
pass
class UniqueDescriptiveNamesTest(parameterized.TestCase):
@parameterized.parameters(13, 13)
def test_normal(self, number):
del number
class MultiGeneratorsTestCase(parameterized.TestCase):
@parameterized.parameters((i for i in (1, 2, 3)), (i for i in (3, 2, 1)))
def test_sum(self, a, b, c):
self.assertEqual(6, sum([a, b, c]))
class NamedParametersReusableTestCase(parameterized.TestCase):
named_params_a = (
{'testcase_name': 'dict_a', 'unused_obj': 0},
('list_a', 1),
)
named_params_b = (
{'testcase_name': 'dict_b', 'unused_obj': 2},
('list_b', 3),
)
named_params_c = (
{'testcase_name': 'dict_c', 'unused_obj': 4},
('list_b', 5),
)
@parameterized.named_parameters(*(named_params_a + named_params_b))
def testSomething(self, unused_obj):
pass
@parameterized.named_parameters(*(named_params_a + named_params_c))
def testSomethingElse(self, unused_obj):
pass
class SuperclassTestCase(parameterized.TestCase):
@parameterized.parameters('foo', 'bar')
def test_name(self, name):
del name
class SubclassTestCase(SuperclassTestCase):
pass
@unittest.skipIf(
(sys.version_info[:2] == (3, 7) and sys.version_info[2] in {0, 1, 2}),
'Python 3.7.0 to 3.7.2 have a bug that breaks this test, see '
'https://bugs.python.org/issue35767',
)
def test_missing_inheritance(self):
ts = unittest.makeSuite(self.BadAdditionParams)
self.assertEqual(1, ts.countTestCases())
res = unittest.TestResult()
ts.run(res)
self.assertEqual(1, res.testsRun)
self.assertFalse(res.wasSuccessful())
self.assertIn('without having inherited', str(res.errors[0]))
def test_correct_extraction_numbers(self):
ts = unittest.makeSuite(self.GoodAdditionParams)
self.assertEqual(2, ts.countTestCases())
def test_successful_execution(self):
ts = unittest.makeSuite(self.GoodAdditionParams)
res = unittest.TestResult()
ts.run(res)
self.assertEqual(2, res.testsRun)
self.assertTrue(res.wasSuccessful())
def test_correct_arguments(self):
ts = unittest.makeSuite(self.GoodAdditionParams)
res = unittest.TestResult()
params = set([(1, 2, 3), (4, 5, 9)])
for test in ts:
test(res)
self.assertIn(test.arguments, params)
params.remove(test.arguments)
self.assertEmpty(params)
def test_recorded_failures(self):
ts = unittest.makeSuite(self.MixedAdditionParams)
self.assertEqual(2, ts.countTestCases())
res = unittest.TestResult()
ts.run(res)
self.assertEqual(2, res.testsRun)
self.assertFalse(res.wasSuccessful())
self.assertLen(res.failures, 1)
self.assertEmpty(res.errors)
def test_short_description(self):
ts = unittest.makeSuite(self.GoodAdditionParams)
short_desc = list(ts)[0].shortDescription()
location = unittest.util.strclass(self.GoodAdditionParams).replace(
'__main__.', ''
)
expected = (
'{}.test_addition0 (1, 2, 3)\n'.format(location)
+ 'test_addition(1, 2, 3)'
)
self.assertEqual(expected, short_desc)
def test_short_description_addresses_removed(self):
ts = unittest.makeSuite(self.ArgumentsWithAddresses)
short_desc = list(ts)[0].shortDescription().split('\n')
self.assertEqual('test_something(