ml_collections-0.1.1/README.md

# ML Collections

ML Collections is a library of Python Collections designed for ML use cases.

[![Documentation Status](https://readthedocs.org/projects/ml-collections/badge/?version=latest)](https://ml-collections.readthedocs.io/en/latest/?badge=latest)
[![PyPI version](https://badge.fury.io/py/ml-collections.svg)](https://badge.fury.io/py/ml-collections)
[![Build Status](https://github.com/google/ml_collections/workflows/Python%20package/badge.svg)](https://github.com/google/ml_collections/actions?query=workflow%3A%22Python+package%22)

## ConfigDict

The two classes called `ConfigDict` and `FrozenConfigDict` are "dict-like" data structures with dot access to nested elements. Together, they are intended as the main way of expressing configurations of experiments and models.

This document describes example usage of `ConfigDict`, `FrozenConfigDict`, and `FieldReference`.

### Features

* Dot-based access to fields.
* Locking mechanism to prevent spelling mistakes.
* Lazy computation.
* `FrozenConfigDict()` class which is immutable and hashable.
* Type safety.
* "Did you mean" functionality.
* Human readable printing (with valid references and cycles), using valid YAML format.
* Fields can be passed as keyword arguments using the `**` operator.
* There are two exceptions to the strong type-safety of the `ConfigDict`. `int` values can be passed in to fields of type `float`. In such a case, the value is type-converted to a `float` before being stored. Similarly, all string types (including Unicode strings) can be stored in fields of type `str` or `unicode`.
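The semantics behind several of these features (dot access, locking, type safety with the `int` to `float` exception, and `**` unpacking) can be sketched in a few lines of plain Python. The `TinyConfig` class below is a hypothetical toy written only for illustration; it is not how `ConfigDict` is implemented and it covers none of the lazy-computation or YAML features. The field names `learning_rate` and `num_steps` are likewise made up.

```python
class TinyConfig:
  """Toy stand-in for illustration only; NOT ml_collections' implementation."""

  def __init__(self):
    # Bypass our own __setattr__ while creating internal state.
    object.__setattr__(self, '_fields', {})
    object.__setattr__(self, '_locked', False)

  def lock(self):
    object.__setattr__(self, '_locked', True)

  def __getattr__(self, name):
    # Only called when normal lookup fails, i.e. for config fields.
    try:
      return self._fields[name]
    except KeyError:
      raise AttributeError(name) from None

  def __setattr__(self, name, value):
    fields = self._fields
    if name not in fields:
      if self._locked:
        raise AttributeError('locked: cannot add new field %r' % name)
      fields[name] = value
      return
    expected = type(fields[name])
    if expected is float and type(value) is int:
      value = float(value)  # The int -> float exception to type safety.
    if type(value) is not expected:
      raise TypeError('%r expects %s' % (name, expected.__name__))
    fields[name] = value

  # keys() plus __getitem__ is what the ** operator needs.
  def keys(self):
    return self._fields.keys()

  def __getitem__(self, name):
    return self._fields[name]


def train(learning_rate, num_steps):
  return learning_rate * num_steps


cfg = TinyConfig()
cfg.learning_rate = 0.5
cfg.num_steps = 10
cfg.learning_rate = 1  # Accepted and stored as the float 1.0.
cfg.lock()             # From here on, misspelled field names are rejected.
result = train(**cfg)  # Fields are passed as keyword arguments.
```

Storing each field's type at first assignment is one simple way to get this behaviour; the real library may track types differently.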

### Basic Usage

```python
import ml_collections

cfg = ml_collections.ConfigDict()
cfg.float_field = 12.6
cfg.integer_field = 123
cfg.another_integer_field = 234
cfg.nested = ml_collections.ConfigDict()
cfg.nested.string_field = 'tom'

print(cfg.integer_field)  # Prints 123.
print(cfg['integer_field'])  # Prints 123 as well.

try:
  cfg.integer_field = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)

cfg.float_field = 12  # Works: `Int` types can be assigned to `Float`.
cfg.nested.string_field = u'bob'  # `String` fields can store Unicode strings.

print(cfg)
```

### FrozenConfigDict

A `FrozenConfigDict` is an immutable, hashable type of `ConfigDict`:

```python
import ml_collections

initial_dictionary = {
    'int': 1,
    'list': [1, 2],
    'tuple': (1, 2, 3),
    'set': {1, 2, 3, 4},
    'dict_tuple_list': {'tuple_list': ([1, 2], 3)}
}

cfg = ml_collections.ConfigDict(initial_dictionary)
frozen_dict = ml_collections.FrozenConfigDict(initial_dictionary)

print(frozen_dict.tuple)  # Prints tuple (1, 2, 3)
print(frozen_dict.list)  # Prints tuple (1, 2)
print(frozen_dict.set)  # Prints frozenset {1, 2, 3, 4}
print(frozen_dict.dict_tuple_list.tuple_list[0])  # Prints tuple (1, 2)

frozen_cfg = ml_collections.FrozenConfigDict(cfg)
print(frozen_cfg == frozen_dict)  # True
print(hash(frozen_cfg) == hash(frozen_dict))  # True

try:
  frozen_dict.int = 2  # Raises AttributeError as FrozenConfigDict is immutable.
except AttributeError as e:
  print(e)

# Converting between `FrozenConfigDict` and `ConfigDict`:
thawed_frozen_cfg = ml_collections.ConfigDict(frozen_dict)
print(thawed_frozen_cfg == cfg)  # True
frozen_cfg_to_cfg = frozen_dict.as_configdict()
print(frozen_cfg_to_cfg == cfg)  # True
```

### FieldReferences and placeholders

A `FieldReference` is useful for having multiple fields use the same value. It can also be used for [lazy computation](#lazy-computation).

You can use `placeholder()` as a shortcut to create a `FieldReference` (field) with a `None` default value.
This is useful if a program uses optional configuration fields.

```python
import ml_collections
from ml_collections.config_dict import config_dict

placeholder = ml_collections.FieldReference(0)
cfg = ml_collections.ConfigDict()
cfg.placeholder = placeholder
cfg.optional = config_dict.placeholder(int)
cfg.nested = ml_collections.ConfigDict()
cfg.nested.placeholder = placeholder

try:
  cfg.optional = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)

cfg.optional = 1555  # Works fine.
cfg.placeholder = 1  # Changes the value of both placeholder and
                     # nested.placeholder fields.

print(cfg)
```

Note that the indirection provided by `FieldReference`s will be lost if accessed through a `ConfigDict`.

```python
import ml_collections

placeholder = ml_collections.FieldReference(0)
cfg = ml_collections.ConfigDict()
cfg.field1 = placeholder
cfg.field2 = placeholder  # This field will be tied to cfg.field1.
cfg.field3 = cfg.field1  # This will just be an int field initialized to 0.
```

### Lazy computation

Using a `FieldReference` in a standard operation (addition, subtraction, multiplication, etc.) will return another `FieldReference` that points to the original's value. You can use `FieldReference.get()` to execute the operations and get the reference's computed value, and `FieldReference.set()` to change the original reference's value.

```python
import ml_collections

ref = ml_collections.FieldReference(1)
print(ref.get())  # Prints 1

add_ten = ref.get() + 10  # ref.get() is an integer and so is add_ten
add_ten_lazy = ref + 10  # add_ten_lazy is a FieldReference - NOT an integer

print(add_ten)  # Prints 11
print(add_ten_lazy.get())  # Prints 11 because ref's value is 1

# Addition is lazily computed for FieldReferences so changing ref will change
# the value that is used to compute add_ten_lazy.
ref.set(5)
print(add_ten)  # Prints 11
print(add_ten_lazy.get())  # Prints 15 because ref's value is 5
```

If a `FieldReference` has `None` as its original value, or any operation has an argument of `None`, then the lazy computation will evaluate to `None`.

We can also use fields in a `ConfigDict` in lazy computation. In this case a field will only be lazily evaluated if `ConfigDict.get_ref()` is used to get it.

```python
import ml_collections

config = ml_collections.ConfigDict()
config.reference_field = ml_collections.FieldReference(1)
config.integer_field = 2
config.float_field = 2.5

# No lazy evaluations because we didn't use get_ref()
config.no_lazy = config.integer_field * config.float_field

# This will lazily evaluate ONLY config.integer_field
config.lazy_integer = config.get_ref('integer_field') * config.float_field

# This will lazily evaluate ONLY config.float_field
config.lazy_float = config.integer_field * config.get_ref('float_field')

# This will lazily evaluate BOTH config.integer_field and config.float_field
config.lazy_both = (config.get_ref('integer_field') *
                    config.get_ref('float_field'))

config.integer_field = 3
print(config.no_lazy)  # Prints 5.0 - It uses integer_field's original value

print(config.lazy_integer)  # Prints 7.5

config.float_field = 3.5
print(config.lazy_float)  # Prints 7.0
print(config.lazy_both)  # Prints 10.5
```

#### Changing lazily computed values

Lazily computed values in a `ConfigDict` can be overridden in the same way as regular values. The reference to the `FieldReference` used for the lazy computation will be lost and all computations downstream in the reference graph will use the new value.

```python
import ml_collections

config = ml_collections.ConfigDict()
config.reference = 1
config.reference_0 = config.get_ref('reference') + 10
config.reference_1 = config.get_ref('reference') + 20
config.reference_1_0 = config.get_ref('reference_1') + 100

print(config.reference)  # Prints 1.
print(config.reference_0)  # Prints 11.
print(config.reference_1)  # Prints 21.
print(config.reference_1_0)  # Prints 121.

config.reference_1 = 30

print(config.reference)  # Prints 1 (unchanged).
print(config.reference_0)  # Prints 11 (unchanged).
print(config.reference_1)  # Prints 30.
print(config.reference_1_0)  # Prints 130.
```

#### Cycles

You cannot create cycles using references. Fortunately [the only way](#changing-lazily-computed-values) to create a cycle is by assigning a computed field to one that *is not* the result of computation. This is forbidden:

```python
import ml_collections
from ml_collections.config_dict import config_dict

config = ml_collections.ConfigDict()
config.integer_field = 1
config.bigger_integer_field = config.get_ref('integer_field') + 10

try:
  # Raises a MutabilityError because setting config.integer_field would
  # cause a cycle.
  config.integer_field = config.get_ref('bigger_integer_field') + 2
except config_dict.MutabilityError as e:
  print(e)
```

### Advanced usage

Here are some more advanced examples showing lazy computation with different operators and data types.

```python
import ml_collections

config = ml_collections.ConfigDict()
config.float_field = 12.6
config.integer_field = 123
config.list_field = [0, 1, 2]

config.float_multiply_field = config.get_ref('float_field') * 3
print(config.float_multiply_field)  # Prints 37.8

config.float_field = 10.0
print(config.float_multiply_field)  # Prints 30.0

config.longer_list_field = config.get_ref('list_field') + [3, 4, 5]
print(config.longer_list_field)  # Prints [0, 1, 2, 3, 4, 5]

config.list_field = [-1]
print(config.longer_list_field)  # Prints [-1, 3, 4, 5]

# Both operands can be references
config.ref_subtraction = (
    config.get_ref('float_field') - config.get_ref('integer_field'))
print(config.ref_subtraction)  # Prints -113.0

config.integer_field = 10
print(config.ref_subtraction)  # Prints 0.0
```

### Equality checking

You can use `==` and `.eq_as_configdict()` to check equality among `ConfigDict` and `FrozenConfigDict` objects.

```python
import ml_collections

dict_1 = {'list': [1, 2]}
dict_2 = {'list': (1, 2)}
cfg_1 = ml_collections.ConfigDict(dict_1)
frozen_cfg_1 = ml_collections.FrozenConfigDict(dict_1)
frozen_cfg_2 = ml_collections.FrozenConfigDict(dict_2)

# True because FrozenConfigDict converts lists to tuples
print(frozen_cfg_1.items() == frozen_cfg_2.items())
# False because == distinguishes the underlying difference
print(frozen_cfg_1 == frozen_cfg_2)

# False because == distinguishes these types
print(frozen_cfg_1 == cfg_1)
# But eq_as_configdict() treats both as ConfigDict, so these are True:
print(frozen_cfg_1.eq_as_configdict(cfg_1))
print(cfg_1.eq_as_configdict(frozen_cfg_1))
```

### Equality checking with lazy computation

Equality checks compare the computed values. Two configurations are equal even if their computations differ, as long as those computations result in the same values.

```python
import ml_collections

cfg_1 = ml_collections.ConfigDict()
cfg_1.a = 1
cfg_1.b = cfg_1.get_ref('a') + 2

cfg_2 = ml_collections.ConfigDict()
cfg_2.a = 1
cfg_2.b = cfg_2.get_ref('a') * 3

# True because all computed values are the same
print(cfg_1 == cfg_2)
```

### Locking and copying

Here is an example with `lock()` and `deepcopy()`:

```python
import copy
import ml_collections

cfg = ml_collections.ConfigDict()
cfg.integer_field = 123

# Locking prohibits the addition and deletion of new fields but allows
# modification of existing values.
cfg.lock()
try:
  # Misspelled on purpose: this would add a new field to a locked config.
  cfg.intager_field = 124  # Raises AttributeError and suggests valid field.
except AttributeError as e:
  print(e)
with cfg.unlocked():
  cfg.integer_field = 1555  # Works fine too.

# Get a copy of the config dict.
new_cfg = copy.deepcopy(cfg)
new_cfg.integer_field = -123  # Works fine.

print(cfg)
```

### Dictionary attributes and initialization

```python
import ml_collections

referenced_dict = {'inner_float': 3.14}
d = {
    'referenced_dict_1': referenced_dict,
    'referenced_dict_2': referenced_dict,
    'list_containing_dict': [{'key': 'value'}],
}

# We can initialize on a dictionary
cfg = ml_collections.ConfigDict(d)

# Reference structure is preserved
print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2))  # True

# And the dict attributes have been converted to ConfigDict
print(type(cfg.referenced_dict_1))  # ConfigDict

# However, the initialization does not look inside of lists, so dicts inside
# lists are not converted to ConfigDict
print(type(cfg.list_containing_dict[0]))  # dict
```

### More Examples

For more examples, take a look at [`ml_collections/config_dict/examples/`](https://github.com/google/ml_collections/tree/master/ml_collections/config_dict/examples).

For examples and gotchas specifically about initializing a ConfigDict, see [`ml_collections/config_dict/examples/config_dict_initialization.py`](https://github.com/google/ml_collections/blob/master/ml_collections/config_dict/examples/config_dict_initialization.py).

## Config Flags

This library adds flag definitions to `absl.flags` to handle config files. It does not wrap `absl.flags` so if using any standard flag definitions alongside config file flags, users must also import `absl.flags`.

Currently, this module adds two new flag types, namely `DEFINE_config_file` which accepts a path to a Python file that generates a configuration, and `DEFINE_config_dict` which accepts a configuration directly. Configurations are dict-like structures (see [ConfigDict](#configdict)) whose nested elements can be overridden using special command-line flags. See the examples below for more details.

### Usage

Use `ml_collections.config_flags` alongside `absl.flags`.
For example:

`script.py`:

```python
from absl import app
from absl import flags

from ml_collections.config_flags import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('my_config')

def main(_):
  print(FLAGS.my_config)

if __name__ == '__main__':
  app.run(main)
```

`config.py`:

```python
# Note that this is a valid Python script.
# get_config() can return an arbitrary dict-like object. However, it is advised
# to use ml_collections.ConfigDict.
# See ml_collections/config_dict/examples/config_dict_basic.py

import ml_collections

def get_config():
  config = ml_collections.ConfigDict()
  config.field1 = 1
  config.field2 = 'tom'
  config.nested = ml_collections.ConfigDict()
  config.nested.field = 2.23
  config.tuple = (1, 2, 3)
  return config
```

Now, after running:

```bash
python script.py --my_config=config.py \
                 --my_config.field1=8 \
                 --my_config.nested.field=2.1 \
                 --my_config.tuple='(1, 2, (1, 2))'
```

we get:

```
field1: 8
field2: tom
nested:
  field: 2.1
tuple: !!python/tuple
- 1
- 2
- !!python/tuple
  - 1
  - 2
```

Usage of `DEFINE_config_dict` is similar to `DEFINE_config_file`; the main difference is that the configuration is defined in `script.py` instead of in a separate file.

`script.py`:

```python
from absl import app
from absl import flags

import ml_collections
from ml_collections.config_flags import config_flags

config = ml_collections.ConfigDict()
config.field1 = 1
config.field2 = 'tom'
config.nested = ml_collections.ConfigDict()
config.nested.field = 2.23
config.tuple = (1, 2, 3)

FLAGS = flags.FLAGS
config_flags.DEFINE_config_dict('my_config', config)

def main(_):
  print(FLAGS.my_config)

if __name__ == '__main__':
  app.run(main)
```

`config_file` flags are compatible with the command-line flag syntax. All the following options are supported for non-boolean values in configurations:

* `-(-)config.field=value`
* `-(-)config.field value`

Options for boolean values are slightly different:

* `-(-)config.boolean_field`: set boolean value to True.
* `-(-)noconfig.boolean_field`: set boolean value to False.
* `-(-)config.boolean_field=value`: `value` is `true`, `false`, `True` or `False`.

Note that `-(-)config.boolean_field value` is not supported.

### Parameterising the get_config() function

It's sometimes useful to be able to pass parameters into `get_config`, and change what is returned based on this configuration. One example is if you are grid searching over parameters which have a different hierarchical structure - the flag needs to be present in the resulting ConfigDict. It would be possible to include the union of all possible leaf values in your ConfigDict, but this produces a confusing config result as you have to remember which parameters will actually have an effect and which won't.

A better system is to pass some configuration, indicating which structure of ConfigDict should be returned. An example is the following config file:

```python
import ml_collections

def get_config(config_string):
  possible_structures = {
      'linear': ml_collections.ConfigDict({
          'model_constructor': 'snt.Linear',
          'model_config': ml_collections.ConfigDict({
              'output_size': 42,
          })
      }),
      'lstm': ml_collections.ConfigDict({
          'model_constructor': 'snt.LSTM',
          'model_config': ml_collections.ConfigDict({
              'hidden_size': 108,
          })
      })
  }

  return possible_structures[config_string]
```

The value of `config_string` will be anything that is to the right of the first colon in the config file path, if one exists. If no colon exists, no value is passed to `get_config` (producing a `TypeError` if `get_config` expects a value).

The above example can be run like:

```bash
python script.py -- --config=path_to_config.py:linear \
                    --config.model_config.output_size=256
```

or like:

```bash
python script.py -- --config=path_to_config.py:lstm \
                    --config.model_config.hidden_size=512
```

### Additional features

* Loads any valid Python script which defines a `get_config()` function returning any Python object.
* Automatic locking of the loaded object, if the loaded object defines a callable `.lock()` method. * Supports command-line overriding of arbitrarily nested values in dict-like objects (with key/attribute based getters/setters) of the following types: * `types.IntType` (integer) * `types.FloatType` (float) * `types.BooleanType` (bool) * `types.StringType` (string) * `types.TupleType` (tuple) * Overriding is type safe. * Overriding of `TupleType` can be done by passing in the `tuple` as a string (see the example in the [Usage](#usage) section). * The overriding `tuple` object can be of a different size and have different types than the original. Nested tuples are also supported. ## Authors * Sergio Gómez Colmenarejo - sergomez@google.com * Wojciech Marian Czarnecki - lejlot@google.com * Nicholas Watters * Mohit Reddy - mohitreddy@google.com ml_collections-0.1.1/PKG-INFO0000640000175000017500000004546714174510450015157 0ustar nileshnileshMetadata-Version: 2.1 Name: ml_collections Version: 0.1.1 Summary: ML Collections is a library of Python collections designed for ML usecases. Home-page: https://github.com/google/ml_collections Author: ML Collections Authors Author-email: ml-collections@google.com License: Apache 2.0 Platform: UNKNOWN Classifier: Development Status :: 4 - Beta Classifier: Intended Audience :: Developers Classifier: Intended Audience :: Science/Research Classifier: License :: OSI Approved :: Apache Software License Classifier: Programming Language :: Python Classifier: Topic :: Scientific/Engineering Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence Classifier: Topic :: Software Development :: Libraries Classifier: Topic :: Software Development :: Libraries :: Python Modules Requires-Python: >=2.6 Description-Content-Type: text/markdown License-File: LICENSE License-File: AUTHORS # ML Collections ML Collections is a library of Python Collections designed for ML use cases. 
[![Documentation Status](https://readthedocs.org/projects/ml-collections/badge/?version=latest)](https://ml-collections.readthedocs.io/en/latest/?badge=latest) [![PyPI version](https://badge.fury.io/py/ml-collections.svg)](https://badge.fury.io/py/ml-collections) [![Build Status](https://github.com/google/ml_collections/workflows/Python%20package/badge.svg)](https://github.com/google/ml_collections/actions?query=workflow%3A%22Python+package%22) ## ConfigDict The two classes called `ConfigDict` and `FrozenConfigDict` are "dict-like" data structures with dot access to nested elements. Together, they are supposed to be used as a main way of expressing configurations of experiments and models. This document describes example usage of `ConfigDict`, `FrozenConfigDict`, `FieldReference`. ### Features * Dot-based access to fields. * Locking mechanism to prevent spelling mistakes. * Lazy computation. * FrozenConfigDict() class which is immutable and hashable. * Type safety. * "Did you mean" functionality. * Human readable printing (with valid references and cycles), using valid YAML format. * Fields can be passed as keyword arguments using the `**` operator. * There are two exceptions to the strong type-safety of the ConfigDict. `int` values can be passed in to fields of type `float`. In such a case, the value is type-converted to a `float` before being stored. Similarly, all string types (including Unicode strings) can be stored in fields of type `str` or `unicode`. ### Basic Usage ```python import ml_collections cfg = ml_collections.ConfigDict() cfg.float_field = 12.6 cfg.integer_field = 123 cfg.another_integer_field = 234 cfg.nested = ml_collections.ConfigDict() cfg.nested.string_field = 'tom' print(cfg.integer_field) # Prints 123. print(cfg['integer_field']) # Prints 123 as well. try: cfg.integer_field = 'tom' # Raises TypeError as this field is an integer. except TypeError as e: print(e) cfg.float_field = 12 # Works: `Int` types can be assigned to `Float`. 
cfg.nested.string_field = u'bob' # `String` fields can store Unicode strings. print(cfg) ``` ### FrozenConfigDict A `FrozenConfigDict`is an immutable, hashable type of `ConfigDict`: ```python import ml_collections initial_dictionary = { 'int': 1, 'list': [1, 2], 'tuple': (1, 2, 3), 'set': {1, 2, 3, 4}, 'dict_tuple_list': {'tuple_list': ([1, 2], 3)} } cfg = ml_collections.ConfigDict(initial_dictionary) frozen_dict = ml_collections.FrozenConfigDict(initial_dictionary) print(frozen_dict.tuple) # Prints tuple (1, 2, 3) print(frozen_dict.list) # Prints tuple (1, 2) print(frozen_dict.set) # Prints frozenset {1, 2, 3, 4} print(frozen_dict.dict_tuple_list.tuple_list[0]) # Prints tuple (1, 2) frozen_cfg = ml_collections.FrozenConfigDict(cfg) print(frozen_cfg == frozen_dict) # True print(hash(frozen_cfg) == hash(frozen_dict)) # True try: frozen_dict.int = 2 # Raises TypeError as FrozenConfigDict is immutable. except AttributeError as e: print(e) # Converting between `FrozenConfigDict` and `ConfigDict`: thawed_frozen_cfg = ml_collections.ConfigDict(frozen_dict) print(thawed_frozen_cfg == cfg) # True frozen_cfg_to_cfg = frozen_dict.as_configdict() print(frozen_cfg_to_cfg == cfg) # True ``` ### FieldReferences and placeholders A `FieldReference` is useful for having multiple fields use the same value. It can also be used for [lazy computation](#lazy-computation). You can use `placeholder()` as a shortcut to create a `FieldReference` (field) with a `None` default value. This is useful if a program uses optional configuration fields. ```python import ml_collections from ml_collections.config_dict import config_dict placeholder = ml_collections.FieldReference(0) cfg = ml_collections.ConfigDict() cfg.placeholder = placeholder cfg.optional = config_dict.placeholder(int) cfg.nested = ml_collections.ConfigDict() cfg.nested.placeholder = placeholder try: cfg.optional = 'tom' # Raises Type error as this field is an integer. 
except TypeError as e: print(e) cfg.optional = 1555 # Works fine. cfg.placeholder = 1 # Changes the value of both placeholder and # nested.placeholder fields. print(cfg) ``` Note that the indirection provided by `FieldReference`s will be lost if accessed through a `ConfigDict`. ```python import ml_collections placeholder = ml_collections.FieldReference(0) cfg.field1 = placeholder cfg.field2 = placeholder # This field will be tied to cfg.field1. cfg.field3 = cfg.field1 # This will just be an int field initialized to 0. ``` ### Lazy computation Using a `FieldReference` in a standard operation (addition, subtraction, multiplication, etc...) will return another `FieldReference` that points to the original's value. You can use `FieldReference.get()` to execute the operations and get the reference's computed value, and `FieldReference.set()` to change the original reference's value. ```python import ml_collections ref = ml_collections.FieldReference(1) print(ref.get()) # Prints 1 add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 11 because ref's value is 1 # Addition is lazily computed for FieldReferences so changing ref will change # the value that is used to compute add_ten. ref.set(5) print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 15 because ref's value is 5 ``` If a `FieldReference` has `None` as its original value, or any operation has an argument of `None`, then the lazy computation will evaluate to `None`. We can also use fields in a `ConfigDict` in lazy computation. In this case a field will only be lazily evaluated if `ConfigDict.get_ref()` is used to get it. 
```python import ml_collections config = ml_collections.ConfigDict() config.reference_field = ml_collections.FieldReference(1) config.integer_field = 2 config.float_field = 2.5 # No lazy evaluatuations because we didn't use get_ref() config.no_lazy = config.integer_field * config.float_field # This will lazily evaluate ONLY config.integer_field config.lazy_integer = config.get_ref('integer_field') * config.float_field # This will lazily evaluate ONLY config.float_field config.lazy_float = config.integer_field * config.get_ref('float_field') # This will lazily evaluate BOTH config.integer_field and config.float_Field config.lazy_both = (config.get_ref('integer_field') * config.get_ref('float_field')) config.integer_field = 3 print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value print(config.lazy_integer) # Prints 7.5 config.float_field = 3.5 print(config.lazy_float) # Prints 7.0 print(config.lazy_both) # Prints 10.5 ``` #### Changing lazily computed values Lazily computed values in a ConfigDict can be overridden in the same way as regular values. The reference to the `FieldReference` used for the lazy computation will be lost and all computations downstream in the reference graph will use the new value. ```python import ml_collections config = ml_collections.ConfigDict() config.reference = 1 config.reference_0 = config.get_ref('reference') + 10 config.reference_1 = config.get_ref('reference') + 20 config.reference_1_0 = config.get_ref('reference_1') + 100 print(config.reference) # Prints 1. print(config.reference_0) # Prints 11. print(config.reference_1) # Prints 21. print(config.reference_1_0) # Prints 121. config.reference_1 = 30 print(config.reference) # Prints 1 (unchanged). print(config.reference_0) # Prints 11 (unchanged). print(config.reference_1) # Prints 30. print(config.reference_1_0) # Prints 130. ``` #### Cycles You cannot create cycles using references. 
Fortunately [the only way](#changing-lazily-computed-values) to create a cycle is by assigning a computed field to one that *is not* the result of computation. This is forbidden: ```python import ml_collections from ml_collections.config_dict import config_dict config = ml_collections.ConfigDict() config.integer_field = 1 config.bigger_integer_field = config.get_ref('integer_field') + 10 try: # Raises a MutabilityError because setting config.integer_field would # cause a cycle. config.integer_field = config.get_ref('bigger_integer_field') + 2 except config_dict.MutabilityError as e: print(e) ``` ### Advanced usage Here are some more advanced examples showing lazy computation with different operators and data types. ```python import ml_collections config = ml_collections.ConfigDict() config.float_field = 12.6 config.integer_field = 123 config.list_field = [0, 1, 2] config.float_multiply_field = config.get_ref('float_field') * 3 print(config.float_multiply_field) # Prints 37.8 config.float_field = 10.0 print(config.float_multiply_field) # Prints 30.0 config.longer_list_field = config.get_ref('list_field') + [3, 4, 5] print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5] config.list_field = [-1] print(config.longer_list_field) # Prints [-1, 3, 4, 5] # Both operands can be references config.ref_subtraction = ( config.get_ref('float_field') - config.get_ref('integer_field')) print(config.ref_subtraction) # Prints -113.0 config.integer_field = 10 print(config.ref_subtraction) # Prints 0.0 ``` ### Equality checking You can use `==` and `.eq_as_configdict()` to check equality among `ConfigDict` and `FrozenConfigDict` objects. 
```python import ml_collections dict_1 = {'list': [1, 2]} dict_2 = {'list': (1, 2)} cfg_1 = ml_collections.ConfigDict(dict_1) frozen_cfg_1 = ml_collections.FrozenConfigDict(dict_1) frozen_cfg_2 = ml_collections.FrozenConfigDict(dict_2) # True because FrozenConfigDict converts lists to tuples print(frozen_cfg_1.items() == frozen_cfg_2.items()) # False because == distinguishes the underlying difference print(frozen_cfg_1 == frozen_cfg_2) # False because == distinguishes these types print(frozen_cfg_1 == cfg_1) # But eq_as_configdict() treats both as ConfigDict, so these are True: print(frozen_cfg_1.eq_as_configdict(cfg_1)) print(cfg_1.eq_as_configdict(frozen_cfg_1)) ``` ### Equality checking with lazy computation Equality checks see if the computed values are the same. Equality is satisfied if two sets of computations are different as long as they result in the same value. ```python import ml_collections cfg_1 = ml_collections.ConfigDict() cfg_1.a = 1 cfg_1.b = cfg_1.get_ref('a') + 2 cfg_2 = ml_collections.ConfigDict() cfg_2.a = 1 cfg_2.b = cfg_2.get_ref('a') * 3 # True because all computed values are the same print(cfg_1 == cfg_2) ``` ### Locking and copying Here is an example with `lock()` and `deepcopy()`: ```python import copy import ml_collections cfg = ml_collections.ConfigDict() cfg.integer_field = 123 # Locking prohibits the addition and deletion of new fields but allows # modification of existing values. cfg.lock() try: cfg.integer_field = 124 # Raises AttributeError and suggests valid field. except AttributeError as e: print(e) with cfg.unlocked(): cfg.integer_field = 1555 # Works fine too. # Get a copy of the config dict. new_cfg = copy.deepcopy(cfg) new_cfg.integer_field = -123 # Works fine. 
print(cfg) ``` ### Dictionary attributes and initialization ```python import ml_collections referenced_dict = {'inner_float': 3.14} d = { 'referenced_dict_1': referenced_dict, 'referenced_dict_2': referenced_dict, 'list_containing_dict': [{'key': 'value'}], } # We can initialize on a dictionary cfg = ml_collections.ConfigDict(d) # Reference structure is preserved print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2)) # True # And the dict attributes have been converted to ConfigDict print(type(cfg.referenced_dict_1)) # ConfigDict # However, the initialization does not look inside of lists, so dicts inside # lists are not converted to ConfigDict print(type(cfg.list_containing_dict[0])) # dict ``` ### More Examples For more examples, take a look at [`ml_collections/config_dict/examples/`](https://github.com/google/ml_collections/tree/master/ml_collections/config_dict/examples) For examples and gotchas specifically about initializing a ConfigDict, see [`ml_collections/config_dict/examples/config_dict_initialization.py`](https://github.com/google/ml_collections/blob/master/ml_collections/config_dict/examples/config_dict_initialization.py). ## Config Flags This library adds flag definitions to `absl.flags` to handle config files. It does not wrap `absl.flags` so if using any standard flag definitions alongside config file flags, users must also import `absl.flags`. Currently, this module adds two new flag types, namely `DEFINE_config_file` which accepts a path to a Python file that generates a configuration, and `DEFINE_config_dict` which accepts a configuration directly. Configurations are dict-like structures (see [ConfigDict](#configdict)) whose nested elements can be overridden using special command-line flags. See the examples below for more details. ### Usage Use `ml_collections.config_flags` alongside `absl.flags`. 
For example: `script.py`: ```python from absl import app from absl import flags from ml_collections.config_flags import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file('my_config') def main(_): print(FLAGS.my_config) if __name__ == '__main__': app.run(main) ``` `config.py`: ```python # Note that this is a valid Python script. # get_config() can return an arbitrary dict-like object. However, it is advised # to use ml_collections.ConfigDict. # See ml_collections/config_dict/examples/config_dict_basic.py import ml_collections def get_config(): config = ml_collections.ConfigDict() config.field1 = 1 config.field2 = 'tom' config.nested = ml_collections.ConfigDict() config.nested.field = 2.23 config.tuple = (1, 2, 3) return config ``` Now, after running: ```bash python script.py --my_config=config.py \ --my_config.field1=8 \ --my_config.nested.field=2.1 \ --my_config.tuple='(1, 2, (1, 2))' ``` we get: ``` field1: 8 field2: tom nested: field: 2.1 tuple: !!python/tuple - 1 - 2 - !!python/tuple - 1 - 2 ``` Usage of `DEFINE_config_dict` is similar to `DEFINE_config_file`, the main difference is the configuration is defined in `script.py` instead of in a separate file. `script.py`: ```python from absl import app from absl import flags import ml_collections from ml_collections.config_flags import config_flags config = ml_collections.ConfigDict() config.field1 = 1 config.field2 = 'tom' config.nested = ml_collections.ConfigDict() config.nested.field = 2.23 config.tuple = (1, 2, 3) FLAGS = flags.FLAGS config_flags.DEFINE_config_dict('my_config', config) def main(_): print(FLAGS.my_config) if __name__ == '__main__': app.run() ``` `config_file` flags are compatible with the command-line flag syntax. All the following options are supported for non-boolean values in configurations: * `-(-)config.field=value` * `-(-)config.field value` Options for boolean values are slightly different: * `-(-)config.boolean_field`: set boolean value to True. 
* `-(-)noconfig.boolean_field`: set boolean value to False.
* `-(-)config.boolean_field=value`: `value` is `true`, `false`, `True` or
  `False`.

Note that `-(-)config.boolean_field value` is not supported.

### Parameterising the get_config() function

It's sometimes useful to be able to pass parameters into `get_config`, and
change what is returned based on this configuration. One example is grid
searching over parameters which have a different hierarchical structure: a
field can only be overridden by a flag if it is present in the resulting
ConfigDict. It would be possible to include the union of all possible leaf
values in your ConfigDict, but this produces a confusing config result as you
have to remember which parameters will actually have an effect and which won't.

A better system is to pass some configuration, indicating which structure of
ConfigDict should be returned. An example is the following config file:

```python
import ml_collections


def get_config(config_string):
  possible_structures = {
      'linear': ml_collections.ConfigDict({
          'model_constructor': 'snt.Linear',
          'model_config': ml_collections.ConfigDict({
              'output_size': 42,
          })
      }),
      'lstm': ml_collections.ConfigDict({
          'model_constructor': 'snt.LSTM',
          'model_config': ml_collections.ConfigDict({
              'hidden_size': 108,
          })
      })
  }

  return possible_structures[config_string]
```

The value of `config_string` will be anything that is to the right of the first
colon in the config file path, if one exists. If no colon exists, no value is
passed to `get_config` (producing a TypeError if `get_config` expects a value).

The above example can be run like:

```bash
python script.py -- --config=path_to_config.py:linear \
                    --config.model_config.output_size=256
```

or like:

```bash
python script.py -- --config=path_to_config.py:lstm \
                    --config.model_config.hidden_size=512
```

### Additional features

* Loads any valid python script which defines `get_config()` function
  returning any python object.
* Automatic locking of the loaded object, if the loaded object defines a callable `.lock()` method. * Supports command-line overriding of arbitrarily nested values in dict-like objects (with key/attribute based getters/setters) of the following types: * `types.IntType` (integer) * `types.FloatType` (float) * `types.BooleanType` (bool) * `types.StringType` (string) * `types.TupleType` (tuple) * Overriding is type safe. * Overriding of `TupleType` can be done by passing in the `tuple` as a string (see the example in the [Usage](#usage) section). * The overriding `tuple` object can be of a different size and have different types than the original. Nested tuples are also supported. ## Authors * Sergio Gómez Colmenarejo - sergomez@google.com * Wojciech Marian Czarnecki - lejlot@google.com * Nicholas Watters * Mohit Reddy - mohitreddy@google.com ml_collections-0.1.1/requirements.txt0000640000175000017500000000013414174507605017334 0ustar nileshnileshabsl-py PyYAML six contextlib2 dataclasses;python_version<'3.7' typing;python_version<'3.5' ml_collections-0.1.1/setup.py0000640000175000017500000000410514174510431015553 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python 3 """Setup for pip package.""" from setuptools import find_namespace_packages from setuptools import setup def _parse_requirements(requirements_txt_path): with open(requirements_txt_path) as fp: return fp.read().splitlines() _VERSION = '0.1.1' setup( name='ml_collections', version=_VERSION, author='ML Collections Authors', author_email='ml-collections@google.com', description='ML Collections is a library of Python collections designed for ML usecases.', long_description=open('README.md').read(), long_description_content_type='text/markdown', url='https://github.com/google/ml_collections', license='Apache 2.0', # Contained modules and scripts. packages=find_namespace_packages(exclude=['*_test.py']), install_requires=_parse_requirements('requirements.txt'), tests_require=_parse_requirements('requirements-test.txt'), python_requires='>=2.6', include_package_data=True, zip_safe=False, # PyPI package information. classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', ], ) ml_collections-0.1.1/setup.cfg0000640000175000017500000000004614174510450015663 0ustar nileshnilesh[egg_info] tag_build = tag_date = 0 ml_collections-0.1.1/AUTHORS0000640000175000017500000000052014174507605015117 0ustar nileshnilesh# This is the list of ML Collections's significant contributors. # # This does not necessarily list everyone who has contributed code, # especially since many employees of one corporation may be contributing. # To see the full list of contributors, see the revision history in # source control. 
DeepMind Technologies Limited Google LLCml_collections-0.1.1/LICENSE0000640000175000017500000002613514174507605015066 0ustar nileshnilesh Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. 
For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.ml_collections-0.1.1/ml_collections.egg-info/0000750000175000017500000000000014174510450020541 5ustar nileshnileshml_collections-0.1.1/ml_collections.egg-info/not-zip-safe0000640000175000017500000000000114174507662023003 0ustar nileshnilesh ml_collections-0.1.1/ml_collections.egg-info/top_level.txt0000640000175000017500000000003714174510450023274 0ustar nileshnileshbuild dist docs ml_collections ml_collections-0.1.1/ml_collections.egg-info/PKG-INFO0000640000175000017500000004546714174510450021657 0ustar nileshnileshMetadata-Version: 2.1 Name: ml-collections Version: 0.1.1 Summary: ML Collections is a library of Python collections designed for ML usecases. 
Home-page: https://github.com/google/ml_collections Author: ML Collections Authors Author-email: ml-collections@google.com License: Apache 2.0 Platform: UNKNOWN Classifier: Development Status :: 4 - Beta Classifier: Intended Audience :: Developers Classifier: Intended Audience :: Science/Research Classifier: License :: OSI Approved :: Apache Software License Classifier: Programming Language :: Python Classifier: Topic :: Scientific/Engineering Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence Classifier: Topic :: Software Development :: Libraries Classifier: Topic :: Software Development :: Libraries :: Python Modules Requires-Python: >=2.6 Description-Content-Type: text/markdown License-File: LICENSE License-File: AUTHORS # ML Collections ML Collections is a library of Python Collections designed for ML use cases. [![Documentation Status](https://readthedocs.org/projects/ml-collections/badge/?version=latest)](https://ml-collections.readthedocs.io/en/latest/?badge=latest) [![PyPI version](https://badge.fury.io/py/ml-collections.svg)](https://badge.fury.io/py/ml-collections) [![Build Status](https://github.com/google/ml_collections/workflows/Python%20package/badge.svg)](https://github.com/google/ml_collections/actions?query=workflow%3A%22Python+package%22) ## ConfigDict The two classes called `ConfigDict` and `FrozenConfigDict` are "dict-like" data structures with dot access to nested elements. Together, they are supposed to be used as a main way of expressing configurations of experiments and models. This document describes example usage of `ConfigDict`, `FrozenConfigDict`, `FieldReference`. ### Features * Dot-based access to fields. * Locking mechanism to prevent spelling mistakes. * Lazy computation. * FrozenConfigDict() class which is immutable and hashable. * Type safety. * "Did you mean" functionality. * Human readable printing (with valid references and cycles), using valid YAML format. 
* Fields can be passed as keyword arguments using the `**` operator.
* There are two exceptions to the strong type-safety of the ConfigDict. `int`
  values can be passed in to fields of type `float`. In such a case, the value
  is type-converted to a `float` before being stored. Similarly, all string
  types (including Unicode strings) can be stored in fields of type `str` or
  `unicode`.

### Basic Usage

```python
import ml_collections

cfg = ml_collections.ConfigDict()
cfg.float_field = 12.6
cfg.integer_field = 123
cfg.another_integer_field = 234
cfg.nested = ml_collections.ConfigDict()
cfg.nested.string_field = 'tom'

print(cfg.integer_field)  # Prints 123.
print(cfg['integer_field'])  # Prints 123 as well.

try:
  cfg.integer_field = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)

cfg.float_field = 12  # Works: `Int` types can be assigned to `Float`.
cfg.nested.string_field = u'bob'  # `String` fields can store Unicode strings.

print(cfg)
```

### FrozenConfigDict

A `FrozenConfigDict` is an immutable, hashable type of `ConfigDict`:

```python
import ml_collections

initial_dictionary = {
    'int': 1,
    'list': [1, 2],
    'tuple': (1, 2, 3),
    'set': {1, 2, 3, 4},
    'dict_tuple_list': {'tuple_list': ([1, 2], 3)}
}

cfg = ml_collections.ConfigDict(initial_dictionary)
frozen_dict = ml_collections.FrozenConfigDict(initial_dictionary)

print(frozen_dict.tuple)  # Prints tuple (1, 2, 3)
print(frozen_dict.list)  # Prints tuple (1, 2)
print(frozen_dict.set)  # Prints frozenset {1, 2, 3, 4}
print(frozen_dict.dict_tuple_list.tuple_list[0])  # Prints tuple (1, 2)

frozen_cfg = ml_collections.FrozenConfigDict(cfg)
print(frozen_cfg == frozen_dict)  # True
print(hash(frozen_cfg) == hash(frozen_dict))  # True

try:
  frozen_dict.int = 2  # Raises AttributeError as FrozenConfigDict is immutable.
except AttributeError as e:
  print(e)

# Converting between `FrozenConfigDict` and `ConfigDict`:
thawed_frozen_cfg = ml_collections.ConfigDict(frozen_dict)
print(thawed_frozen_cfg == cfg)  # True
frozen_cfg_to_cfg = frozen_dict.as_configdict()
print(frozen_cfg_to_cfg == cfg)  # True
```

### FieldReferences and placeholders

A `FieldReference` is useful for having multiple fields use the same value. It
can also be used for [lazy computation](#lazy-computation).

You can use `placeholder()` as a shortcut to create a `FieldReference` (field)
with a `None` default value. This is useful if a program uses optional
configuration fields.

```python
import ml_collections
from ml_collections.config_dict import config_dict

placeholder = ml_collections.FieldReference(0)
cfg = ml_collections.ConfigDict()
cfg.placeholder = placeholder
cfg.optional = config_dict.placeholder(int)
cfg.nested = ml_collections.ConfigDict()
cfg.nested.placeholder = placeholder

try:
  cfg.optional = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)

cfg.optional = 1555  # Works fine.
cfg.placeholder = 1  # Changes the value of both placeholder and
                     # nested.placeholder fields.

print(cfg)
```

Note that the indirection provided by `FieldReference`s will be lost if
accessed through a `ConfigDict`.

```python
import ml_collections

placeholder = ml_collections.FieldReference(0)
cfg = ml_collections.ConfigDict()
cfg.field1 = placeholder
cfg.field2 = placeholder  # This field will be tied to cfg.field1.
cfg.field3 = cfg.field1  # This will just be an int field initialized to 0.
```

### Lazy computation

Using a `FieldReference` in a standard operation (addition, subtraction,
multiplication, etc.) will return another `FieldReference` that points to the
original's value. You can use `FieldReference.get()` to execute the operations
and get the reference's computed value, and `FieldReference.set()` to change the
original reference's value.
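To make the mechanism described above more concrete, here is a toy sketch in plain Python of how operator overloading can defer a computation until `get()` is called. The `LazyRef` class and its helper `_binary` are hypothetical names invented for illustration; this is not the `ml_collections` implementation, only a minimal model of the behaviour this README documents (including the rule that a `None` operand yields `None`).

```python
import operator


class LazyRef:
  """Toy lazily evaluated reference (hypothetical, not ml_collections code)."""

  def __init__(self, value=None, compute=None):
    self._value = value
    self._compute = compute  # Zero-argument callable for derived references.

  def get(self):
    # Derived references recompute from their parents on every call.
    if self._compute is not None:
      return self._compute()
    return self._value

  def set(self, value):
    # Overwrite the stored value and drop any derived computation.
    self._value = value
    self._compute = None

  def _binary(self, op, other):
    def compute():
      left = self.get()
      right = other.get() if isinstance(other, LazyRef) else other
      # Mirror the documented rule: a None operand makes the result None.
      if left is None or right is None:
        return None
      return op(left, right)
    return LazyRef(compute=compute)

  def __add__(self, other):
    return self._binary(operator.add, other)

  def __mul__(self, other):
    return self._binary(operator.mul, other)


ref = LazyRef(1)
eager = ref.get() + 10  # Plain int, computed once.
lazy = ref + 10  # LazyRef, recomputed on each get().
ref.set(5)
print(eager)  # 11
print(lazy.get())  # 15
```

The design choice worth noticing is that the derived reference stores a closure over its parent rather than a value, which is why changing the parent later is visible through `lazy.get()` but not through `eager`.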
```python
import ml_collections

ref = ml_collections.FieldReference(1)
print(ref.get())  # Prints 1

add_ten = ref.get() + 10  # ref.get() is an integer and so is add_ten
add_ten_lazy = ref + 10  # add_ten_lazy is a FieldReference - NOT an integer

print(add_ten)  # Prints 11
print(add_ten_lazy.get())  # Prints 11 because ref's value is 1

# Addition is lazily computed for FieldReferences so changing ref will change
# the value that is used to compute add_ten_lazy.
ref.set(5)
print(add_ten)  # Prints 11
print(add_ten_lazy.get())  # Prints 15 because ref's value is 5
```

If a `FieldReference` has `None` as its original value, or any operation has an
argument of `None`, then the lazy computation will evaluate to `None`.

We can also use fields in a `ConfigDict` in lazy computation. In this case a
field will only be lazily evaluated if `ConfigDict.get_ref()` is used to get it.

```python
import ml_collections

config = ml_collections.ConfigDict()
config.reference_field = ml_collections.FieldReference(1)
config.integer_field = 2
config.float_field = 2.5

# No lazy evaluations because we didn't use get_ref()
config.no_lazy = config.integer_field * config.float_field

# This will lazily evaluate ONLY config.integer_field
config.lazy_integer = config.get_ref('integer_field') * config.float_field

# This will lazily evaluate ONLY config.float_field
config.lazy_float = config.integer_field * config.get_ref('float_field')

# This will lazily evaluate BOTH config.integer_field and config.float_field
config.lazy_both = (config.get_ref('integer_field') *
                    config.get_ref('float_field'))

config.integer_field = 3
print(config.no_lazy)  # Prints 5.0 - It uses integer_field's original value

print(config.lazy_integer)  # Prints 7.5

config.float_field = 3.5
print(config.lazy_float)  # Prints 7.0
print(config.lazy_both)  # Prints 10.5
```

#### Changing lazily computed values

Lazily computed values in a ConfigDict can be overridden in the same way as
regular values.
The reference to the `FieldReference` used for the lazy computation will be lost and all computations downstream in the reference graph will use the new value. ```python import ml_collections config = ml_collections.ConfigDict() config.reference = 1 config.reference_0 = config.get_ref('reference') + 10 config.reference_1 = config.get_ref('reference') + 20 config.reference_1_0 = config.get_ref('reference_1') + 100 print(config.reference) # Prints 1. print(config.reference_0) # Prints 11. print(config.reference_1) # Prints 21. print(config.reference_1_0) # Prints 121. config.reference_1 = 30 print(config.reference) # Prints 1 (unchanged). print(config.reference_0) # Prints 11 (unchanged). print(config.reference_1) # Prints 30. print(config.reference_1_0) # Prints 130. ``` #### Cycles You cannot create cycles using references. Fortunately [the only way](#changing-lazily-computed-values) to create a cycle is by assigning a computed field to one that *is not* the result of computation. This is forbidden: ```python import ml_collections from ml_collections.config_dict import config_dict config = ml_collections.ConfigDict() config.integer_field = 1 config.bigger_integer_field = config.get_ref('integer_field') + 10 try: # Raises a MutabilityError because setting config.integer_field would # cause a cycle. config.integer_field = config.get_ref('bigger_integer_field') + 2 except config_dict.MutabilityError as e: print(e) ``` ### Advanced usage Here are some more advanced examples showing lazy computation with different operators and data types. 
```python import ml_collections config = ml_collections.ConfigDict() config.float_field = 12.6 config.integer_field = 123 config.list_field = [0, 1, 2] config.float_multiply_field = config.get_ref('float_field') * 3 print(config.float_multiply_field) # Prints 37.8 config.float_field = 10.0 print(config.float_multiply_field) # Prints 30.0 config.longer_list_field = config.get_ref('list_field') + [3, 4, 5] print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5] config.list_field = [-1] print(config.longer_list_field) # Prints [-1, 3, 4, 5] # Both operands can be references config.ref_subtraction = ( config.get_ref('float_field') - config.get_ref('integer_field')) print(config.ref_subtraction) # Prints -113.0 config.integer_field = 10 print(config.ref_subtraction) # Prints 0.0 ``` ### Equality checking You can use `==` and `.eq_as_configdict()` to check equality among `ConfigDict` and `FrozenConfigDict` objects. ```python import ml_collections dict_1 = {'list': [1, 2]} dict_2 = {'list': (1, 2)} cfg_1 = ml_collections.ConfigDict(dict_1) frozen_cfg_1 = ml_collections.FrozenConfigDict(dict_1) frozen_cfg_2 = ml_collections.FrozenConfigDict(dict_2) # True because FrozenConfigDict converts lists to tuples print(frozen_cfg_1.items() == frozen_cfg_2.items()) # False because == distinguishes the underlying difference print(frozen_cfg_1 == frozen_cfg_2) # False because == distinguishes these types print(frozen_cfg_1 == cfg_1) # But eq_as_configdict() treats both as ConfigDict, so these are True: print(frozen_cfg_1.eq_as_configdict(cfg_1)) print(cfg_1.eq_as_configdict(frozen_cfg_1)) ``` ### Equality checking with lazy computation Equality checks see if the computed values are the same. Equality is satisfied if two sets of computations are different as long as they result in the same value. 
```python
import ml_collections

cfg_1 = ml_collections.ConfigDict()
cfg_1.a = 1
cfg_1.b = cfg_1.get_ref('a') + 2

cfg_2 = ml_collections.ConfigDict()
cfg_2.a = 1
cfg_2.b = cfg_2.get_ref('a') * 3

# True because all computed values are the same
print(cfg_1 == cfg_2)
```

### Locking and copying

Here is an example with `lock()` and `deepcopy()`:

```python
import copy
import ml_collections

cfg = ml_collections.ConfigDict()
cfg.integer_field = 123

# Locking prohibits the addition and deletion of fields but allows
# modification of existing values.
cfg.lock()
try:
  # The field name is misspelled on purpose: it does not exist yet.
  cfg.intagar_field = 124  # Raises AttributeError and suggests valid field.
except AttributeError as e:
  print(e)
with cfg.unlocked():
  cfg.integer_field = 1555  # Works fine too.

# Get a copy of the config dict.
new_cfg = copy.deepcopy(cfg)
new_cfg.integer_field = -123  # Works fine.

print(cfg)
```

### Dictionary attributes and initialization

```python
import ml_collections

referenced_dict = {'inner_float': 3.14}
d = {
    'referenced_dict_1': referenced_dict,
    'referenced_dict_2': referenced_dict,
    'list_containing_dict': [{'key': 'value'}],
}

# We can initialize on a dictionary
cfg = ml_collections.ConfigDict(d)

# Reference structure is preserved
print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2))  # True

# And the dict attributes have been converted to ConfigDict
print(type(cfg.referenced_dict_1))  # ConfigDict

# However, the initialization does not look inside of lists, so dicts inside
# lists are not converted to ConfigDict
print(type(cfg.list_containing_dict[0]))  # dict
```

### More Examples

For more examples, take a look at
[`ml_collections/config_dict/examples/`](https://github.com/google/ml_collections/tree/master/ml_collections/config_dict/examples).

For examples and gotchas specifically about initializing a ConfigDict, see
[`ml_collections/config_dict/examples/config_dict_initialization.py`](https://github.com/google/ml_collections/blob/master/ml_collections/config_dict/examples/config_dict_initialization.py). ## Config Flags This library adds flag definitions to `absl.flags` to handle config files. It does not wrap `absl.flags` so if using any standard flag definitions alongside config file flags, users must also import `absl.flags`. Currently, this module adds two new flag types, namely `DEFINE_config_file` which accepts a path to a Python file that generates a configuration, and `DEFINE_config_dict` which accepts a configuration directly. Configurations are dict-like structures (see [ConfigDict](#configdict)) whose nested elements can be overridden using special command-line flags. See the examples below for more details. ### Usage Use `ml_collections.config_flags` alongside `absl.flags`. For example: `script.py`: ```python from absl import app from absl import flags from ml_collections.config_flags import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file('my_config') def main(_): print(FLAGS.my_config) if __name__ == '__main__': app.run(main) ``` `config.py`: ```python # Note that this is a valid Python script. # get_config() can return an arbitrary dict-like object. However, it is advised # to use ml_collections.ConfigDict. 
# See ml_collections/config_dict/examples/config_dict_basic.py

import ml_collections


def get_config():
  config = ml_collections.ConfigDict()
  config.field1 = 1
  config.field2 = 'tom'
  config.nested = ml_collections.ConfigDict()
  config.nested.field = 2.23
  config.tuple = (1, 2, 3)
  return config
```

Now, after running:

```bash
python script.py --my_config=config.py \
                 --my_config.field1=8 \
                 --my_config.nested.field=2.1 \
                 --my_config.tuple='(1, 2, (1, 2))'
```

we get:

```
field1: 8
field2: tom
nested:
  field: 2.1
tuple: !!python/tuple
- 1
- 2
- !!python/tuple
  - 1
  - 2
```

Usage of `DEFINE_config_dict` is similar to `DEFINE_config_file`; the main
difference is that the configuration is defined in `script.py` instead of in a
separate file.

`script.py`:

```python
from absl import app
from absl import flags

import ml_collections
from ml_collections.config_flags import config_flags

config = ml_collections.ConfigDict()
config.field1 = 1
config.field2 = 'tom'
config.nested = ml_collections.ConfigDict()
config.nested.field = 2.23
config.tuple = (1, 2, 3)

FLAGS = flags.FLAGS
config_flags.DEFINE_config_dict('my_config', config)


def main(_):
  print(FLAGS.my_config)


if __name__ == '__main__':
  app.run(main)
```

`config_file` flags are compatible with the command-line flag syntax. All the
following options are supported for non-boolean values in configurations:

* `-(-)config.field=value`
* `-(-)config.field value`

Options for boolean values are slightly different:

* `-(-)config.boolean_field`: set boolean value to True.
* `-(-)noconfig.boolean_field`: set boolean value to False.
* `-(-)config.boolean_field=value`: `value` is `true`, `false`, `True` or
  `False`.

Note that `-(-)config.boolean_field value` is not supported.

### Parameterising the get_config() function

It's sometimes useful to be able to pass parameters into `get_config`, and
change what is returned based on this configuration.
One example is if you are grid searching over parameters which have a different
hierarchical structure - the flag needs to be present in the resulting
ConfigDict. It would be possible to include the union of all possible leaf
values in your ConfigDict, but this produces a confusing config result as you
have to remember which parameters will actually have an effect and which won't.

A better system is to pass some configuration, indicating which structure of
ConfigDict should be returned. An example is the following config file:

```python
import ml_collections


def get_config(config_string):
  possible_structures = {
      'linear': ml_collections.ConfigDict({
          'model_constructor': 'snt.Linear',
          'model_config': ml_collections.ConfigDict({
              'output_size': 42,
          }),
      }),
      'lstm': ml_collections.ConfigDict({
          'model_constructor': 'snt.LSTM',
          'model_config': ml_collections.ConfigDict({
              'hidden_size': 108,
          }),
      }),
  }

  return possible_structures[config_string]
```

The value of `config_string` will be anything that is to the right of the first
colon in the config file path, if one exists. If no colon exists, no value is
passed to `get_config` (producing a TypeError if `get_config` expects a value).

The above example can be run like:

```bash
python script.py -- --config=path_to_config.py:linear \
                    --config.model_config.output_size=256
```

or like:

```bash
python script.py -- --config=path_to_config.py:lstm \
                    --config.model_config.hidden_size=512
```

### Additional features

* Loads any valid python script which defines `get_config()` function
  returning any python object.
* Automatic locking of the loaded object, if the loaded object defines a
  callable `.lock()` method.
* Supports command-line overriding of arbitrarily nested values in dict-like
  objects (with key/attribute based getters/setters) of the following types:
    * `types.IntType` (integer)
    * `types.FloatType` (float)
    * `types.BooleanType` (bool)
    * `types.StringType` (string)
    * `types.TupleType` (tuple)
* Overriding is type safe.
* Overriding of `TupleType` can be done by passing in the `tuple` as a string
  (see the example in the [Usage](#usage) section).
* The overriding `tuple` object can be of a different size and have different
  types than the original. Nested tuples are also supported.

## Authors

* Sergio Gómez Colmenarejo - sergomez@google.com
* Wojciech Marian Czarnecki - lejlot@google.com
* Nicholas Watters
* Mohit Reddy - mohitreddy@google.com

ml_collections-0.1.1/ml_collections.egg-info/dependency_links.txt

ml_collections-0.1.1/ml_collections.egg-info/SOURCES.txt

AUTHORS
LICENSE
MANIFEST.in
README.md
requirements-test.txt
requirements.txt
setup.py
docs/conf.py
ml_collections/__init__.py
ml_collections.egg-info/PKG-INFO
ml_collections.egg-info/SOURCES.txt
ml_collections.egg-info/dependency_links.txt
ml_collections.egg-info/not-zip-safe
ml_collections.egg-info/requires.txt
ml_collections.egg-info/top_level.txt
ml_collections/config_dict/__init__.py
ml_collections/config_dict/config_dict.py
ml_collections/config_dict/examples/config.py
ml_collections/config_dict/examples/config_dict_advanced.py
ml_collections/config_dict/examples/config_dict_basic.py
ml_collections/config_dict/examples/config_dict_initialization.py
ml_collections/config_dict/examples/config_dict_lock.py
ml_collections/config_dict/examples/config_dict_placeholder.py
ml_collections/config_dict/examples/examples_test.py
ml_collections/config_dict/examples/field_reference.py
ml_collections/config_dict/examples/frozen_config_dict.py
ml_collections/config_dict/tests/config_dict_test.py
ml_collections/config_dict/tests/field_reference_test.py
ml_collections/config_dict/tests/frozen_config_dict_test.py
ml_collections/config_flags/__init__.py
ml_collections/config_flags/config_flags.py
ml_collections/config_flags/tuple_parser.py
ml_collections/config_flags/examples/config.py
ml_collections/config_flags/examples/define_config_dataclass_basic.py
ml_collections/config_flags/examples/define_config_dict_basic.py
ml_collections/config_flags/examples/define_config_file_basic.py
ml_collections/config_flags/examples/examples_test.py
ml_collections/config_flags/examples/parameterised_config.py
ml_collections/config_flags/tests/config_overriding_test.py
ml_collections/config_flags/tests/configdict_config.py
ml_collections/config_flags/tests/dataclass_overriding_test.py
ml_collections/config_flags/tests/fieldreference_config.py
ml_collections/config_flags/tests/ioerror_config.py
ml_collections/config_flags/tests/mini_config.py
ml_collections/config_flags/tests/mock_config.py
ml_collections/config_flags/tests/parameterised_config.py
ml_collections/config_flags/tests/typeerror_config.py
ml_collections/config_flags/tests/valueerror_config.py

ml_collections-0.1.1/ml_collections.egg-info/requires.txt

PyYAML
absl-py
contextlib2
six

[:python_version < "3.5"]
typing

[:python_version < "3.7"]
dataclasses

ml_collections-0.1.1/requirements-test.txt

mock

ml_collections-0.1.1/MANIFEST.in

include requirements-test.txt
include requirements.txt

ml_collections-0.1.1/ml_collections/
ml_collections-0.1.1/ml_collections/config_dict/
ml_collections-0.1.1/ml_collections/config_dict/tests/

ml_collections-0.1.1/ml_collections/config_dict/tests/field_reference_test.py

# Copyright 2021 The ML Collections Authors.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Tests for ml_collections.FieldReference.""" import operator from absl.testing import absltest from absl.testing import parameterized import ml_collections from ml_collections.config_dict import config_dict class FieldReferenceTest(parameterized.TestCase): def _test_binary_operator(self, initial_value, other_value, op, true_value, new_initial_value, new_true_value, assert_fn=None): """Helper for testing binary operators. Generally speaking this checks that: 1. `op(initial_value, other_value) COMP true_value` 2. `op(new_initial_value, other_value) COMP new_true_value where `COMP` is the comparison function defined by `assert_fn`. Args: initial_value: Initial value for the `FieldReference`, this is the first argument for the binary operator. other_value: The second argument for the binary operator. op: The binary operator. true_value: The expected output of the binary operator. new_initial_value: The value that the `FieldReference` is changed to. new_true_value: The expected output of the binary operator after the `FieldReference` has changed. assert_fn: Function used to check the output values. 
""" if assert_fn is None: assert_fn = self.assertEqual ref = ml_collections.FieldReference(initial_value) new_ref = op(ref, other_value) assert_fn(new_ref.get(), true_value) config = ml_collections.ConfigDict() config.a = initial_value config.b = other_value config.result = op(config.get_ref('a'), config.b) assert_fn(config.result, true_value) config.a = new_initial_value assert_fn(config.result, new_true_value) def _test_unary_operator(self, initial_value, op, true_value, new_initial_value, new_true_value, assert_fn=None): """Helper for testing unary operators. Generally speaking this checks that: 1. `op(initial_value) COMP true_value` 2. `op(new_initial_value) COMP new_true_value where `COMP` is the comparison function defined by `assert_fn`. Args: initial_value: Initial value for the `FieldReference`, this is the first argument for the unary operator. op: The unary operator. true_value: The expected output of the unary operator. new_initial_value: The value that the `FieldReference` is changed to. new_true_value: The expected output of the unary operator after the `FieldReference` has changed. assert_fn: Function used to check the output values. """ if assert_fn is None: assert_fn = self.assertEqual ref = ml_collections.FieldReference(initial_value) new_ref = op(ref) assert_fn(new_ref.get(), true_value) config = ml_collections.ConfigDict() config.a = initial_value config.result = op(config.get_ref('a')) assert_fn(config.result, true_value) config.a = new_initial_value assert_fn(config.result, new_true_value) def testBasic(self): ref = ml_collections.FieldReference(1) self.assertEqual(ref.get(), 1) def testGetRef(self): config = ml_collections.ConfigDict() config.a = 1. 
config.b = config.get_ref('a') + 10 config.c = config.get_ref('b') + 10 self.assertEqual(config.c, 21.0) def testFunction(self): def fn(x): return x + 5 config = ml_collections.ConfigDict() config.a = 1 config.b = fn(config.get_ref('a')) config.c = fn(config.get_ref('b')) self.assertEqual(config.b, 6) self.assertEqual(config.c, 11) config.a = 2 self.assertEqual(config.b, 7) self.assertEqual(config.c, 12) def testCycles(self): config = ml_collections.ConfigDict() config.a = 1. config.b = config.get_ref('a') + 10 config.c = config.get_ref('b') + 10 self.assertEqual(config.b, 11.0) self.assertEqual(config.c, 21.0) # Introduce a cycle with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'): config.a = config.get_ref('c') - 1.0 # Introduce a cycle on second operand with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'): config.a = ml_collections.FieldReference(5.0) + config.get_ref('c') # We can create multiple FieldReferences that all point to the same object l = [0] config = ml_collections.ConfigDict() config.a = l config.b = l config.c = config.get_ref('a') + ['c'] config.d = config.get_ref('b') + ['d'] self.assertEqual(config.c, [0, 'c']) self.assertEqual(config.d, [0, 'd']) # Make sure nothing was mutated self.assertEqual(l, [0]) self.assertEqual(config.c, [0, 'c']) config.a = [1] config.b = [2] self.assertEqual(l, [0]) self.assertEqual(config.c, [1, 'c']) self.assertEqual(config.d, [2, 'd']) @parameterized.parameters( { 'initial_value': 1, 'other_value': 2, 'true_value': 3, 'new_initial_value': 10, 'new_true_value': 12 }, { 'initial_value': 2.0, 'other_value': 2.5, 'true_value': 4.5, 'new_initial_value': 3.7, 'new_true_value': 6.2 }, { 'initial_value': 'hello, ', 'other_value': 'world!', 'true_value': 'hello, world!', 'new_initial_value': 'foo, ', 'new_true_value': 'foo, world!' 
}, { 'initial_value': ['hello'], 'other_value': ['world'], 'true_value': ['hello', 'world'], 'new_initial_value': ['foo'], 'new_true_value': ['foo', 'world'] }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5.0), 'true_value': 15.0, 'new_initial_value': 12, 'new_true_value': 17.0 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 7.0, 'true_value': None, 'new_initial_value': 12, 'new_true_value': 19.0 }, { 'initial_value': 5.0, 'other_value': config_dict.placeholder(float), 'true_value': None, 'new_initial_value': 8.0, 'new_true_value': None }, { 'initial_value': config_dict.placeholder(str), 'other_value': 'tail', 'true_value': None, 'new_initial_value': 'head', 'new_true_value': 'headtail' }) def testAdd(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.add, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 5, 'other_value': 3, 'true_value': 2, 'new_initial_value': -1, 'new_true_value': -4 }, { 'initial_value': 2.0, 'other_value': 2.5, 'true_value': -0.5, 'new_initial_value': 12.3, 'new_true_value': 9.8 }, { 'initial_value': set(['hello', 123, 4.5]), 'other_value': set([123]), 'true_value': set(['hello', 4.5]), 'new_initial_value': set([123]), 'new_true_value': set([]) }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5.0), 'true_value': 5.0, 'new_initial_value': 12, 'new_true_value': 7.0 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 7.0, 'true_value': None, 'new_initial_value': 12, 'new_true_value': 5.0 }) def testSub(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.sub, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 1, 'other_value': 2, 
'true_value': 2, 'new_initial_value': 3, 'new_true_value': 6 }, { 'initial_value': 2.0, 'other_value': 2.5, 'true_value': 5.0, 'new_initial_value': 3.5, 'new_true_value': 8.75 }, { 'initial_value': ['hello'], 'other_value': 3, 'true_value': ['hello', 'hello', 'hello'], 'new_initial_value': ['foo'], 'new_true_value': ['foo', 'foo', 'foo'] }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5.0), 'true_value': 50.0, 'new_initial_value': 1, 'new_true_value': 5.0 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 7.0, 'true_value': None, 'new_initial_value': 12, 'new_true_value': 84.0 }) def testMul(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.mul, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'other_value': 2, 'true_value': 1.5, 'new_initial_value': 10, 'new_true_value': 5.0 }, { 'initial_value': 2.0, 'other_value': 2.5, 'true_value': 0.8, 'new_initial_value': 6.3, 'new_true_value': 2.52 }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5.0), 'true_value': 2.0, 'new_initial_value': 13, 'new_true_value': 2.6 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 7.0, 'true_value': None, 'new_initial_value': 17.5, 'new_true_value': 2.5 }) def testTrueDiv(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.truediv, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'other_value': 2, 'true_value': 1, 'new_initial_value': 7, 'new_true_value': 3 }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5), 'true_value': 2, 'new_initial_value': 28, 'new_true_value': 5 }, { 'initial_value': 
config_dict.placeholder(int), 'other_value': 7, 'true_value': None, 'new_initial_value': 25, 'new_true_value': 3 }) def testFloorDiv(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.floordiv, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'other_value': 2, 'true_value': 9, 'new_initial_value': 10, 'new_true_value': 100 }, { 'initial_value': 2.7, 'other_value': 3.2, 'true_value': 24.0084457245, 'new_initial_value': 6.5, 'new_true_value': 399.321543621 }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5), 'true_value': 1e5, 'new_initial_value': 2, 'new_true_value': 32 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 3.0, 'true_value': None, 'new_initial_value': 7.0, 'new_true_value': 343.0 }) def testPow(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator( initial_value, other_value, operator.pow, true_value, new_initial_value, new_true_value, assert_fn=self.assertAlmostEqual) @parameterized.parameters( { 'initial_value': 3, 'other_value': 2, 'true_value': 1, 'new_initial_value': 10, 'new_true_value': 0 }, { 'initial_value': 5.3, 'other_value': 3.2, 'true_value': 2.0999999999999996, 'new_initial_value': 77, 'new_true_value': 0.2 }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5), 'true_value': 0, 'new_initial_value': 32, 'new_true_value': 2 }, { 'initial_value': config_dict.placeholder(int), 'other_value': 7, 'true_value': None, 'new_initial_value': 25, 'new_true_value': 4 }) def testMod(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator( initial_value, other_value, operator.mod, true_value, new_initial_value, new_true_value, assert_fn=self.assertAlmostEqual) @parameterized.parameters( { 
'initial_value': True, 'other_value': True, 'true_value': True, 'new_initial_value': False, 'new_true_value': False }, { 'initial_value': ml_collections.FieldReference(False), 'other_value': ml_collections.FieldReference(False), 'true_value': False, 'new_initial_value': True, 'new_true_value': False }, { 'initial_value': config_dict.placeholder(bool), 'other_value': True, 'true_value': None, 'new_initial_value': False, 'new_true_value': False }) def testAnd(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.and_, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': False, 'other_value': False, 'true_value': False, 'new_initial_value': True, 'new_true_value': True }, { 'initial_value': ml_collections.FieldReference(True), 'other_value': ml_collections.FieldReference(True), 'true_value': True, 'new_initial_value': False, 'new_true_value': True }, { 'initial_value': config_dict.placeholder(bool), 'other_value': False, 'true_value': None, 'new_initial_value': True, 'new_true_value': True }) def testOr(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.or_, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': False, 'other_value': True, 'true_value': True, 'new_initial_value': True, 'new_true_value': False }, { 'initial_value': ml_collections.FieldReference(True), 'other_value': ml_collections.FieldReference(True), 'true_value': False, 'new_initial_value': False, 'new_true_value': True }, { 'initial_value': config_dict.placeholder(bool), 'other_value': True, 'true_value': None, 'new_initial_value': True, 'new_true_value': False }) def testXor(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.xor, true_value, 
new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'true_value': -3, 'new_initial_value': -22, 'new_true_value': 22 }, { 'initial_value': 15.3, 'true_value': -15.3, 'new_initial_value': -0.2, 'new_true_value': 0.2 }, { 'initial_value': ml_collections.FieldReference(7), 'true_value': ml_collections.FieldReference(-7), 'new_initial_value': 123, 'new_true_value': -123 }, { 'initial_value': config_dict.placeholder(int), 'true_value': None, 'new_initial_value': -6, 'new_true_value': 6 }) def testNeg(self, initial_value, true_value, new_initial_value, new_true_value): self._test_unary_operator(initial_value, operator.neg, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': config_dict.create(attribute=2), 'true_value': 2, 'new_initial_value': config_dict.create(attribute=3), 'new_true_value': 3, }, { 'initial_value': config_dict.create(attribute={'a': 1}), 'true_value': config_dict.create(a=1), 'new_initial_value': config_dict.create(attribute={'b': 1}), 'new_true_value': config_dict.create(b=1), }, { 'initial_value': ml_collections.FieldReference(config_dict.create(attribute=2)), 'true_value': ml_collections.FieldReference(2), 'new_initial_value': config_dict.create(attribute=3), 'new_true_value': 3, }, { 'initial_value': config_dict.placeholder(config_dict.ConfigDict), 'true_value': None, 'new_initial_value': config_dict.create(attribute=3), 'new_true_value': 3, }, ) def testAttr(self, initial_value, true_value, new_initial_value, new_true_value): self._test_unary_operator(initial_value, lambda x: x.attr('attribute'), true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'true_value': 3, 'new_initial_value': -101, 'new_true_value': 101 }, { 'initial_value': -15.3, 'true_value': 15.3, 'new_initial_value': 7.3, 'new_true_value': 7.3 }, { 'initial_value': ml_collections.FieldReference(-7), 'true_value': ml_collections.FieldReference(7), 'new_initial_value': 
3, 'new_true_value': 3 }, { 'initial_value': config_dict.placeholder(float), 'true_value': None, 'new_initial_value': -6.25, 'new_true_value': 6.25 }) def testAbs(self, initial_value, true_value, new_initial_value, new_true_value): self._test_unary_operator(initial_value, operator.abs, true_value, new_initial_value, new_true_value) def testToInt(self): self._test_unary_operator(25.3, lambda ref: ref.to_int(), 25, 27.9, 27) ref = ml_collections.FieldReference(64.7) ref = ref.to_int() self.assertEqual(ref.get(), 64) self.assertEqual(ref._field_type, int) def testToFloat(self): self._test_unary_operator(12, lambda ref: ref.to_float(), 12.0, 0, 0.0) ref = ml_collections.FieldReference(647) ref = ref.to_float() self.assertEqual(ref.get(), 647.0) self.assertEqual(ref._field_type, float) def testToString(self): self._test_unary_operator(12, lambda ref: ref.to_str(), '12', 0, '0') ref = ml_collections.FieldReference(647) ref = ref.to_str() self.assertEqual(ref.get(), '647') self.assertEqual(ref._field_type, str) def testSetValue(self): ref = ml_collections.FieldReference(1.0) other = ml_collections.FieldReference(3) ref_plus_other = ref + other self.assertEqual(ref_plus_other.get(), 4.0) ref.set(2.5) self.assertEqual(ref_plus_other.get(), 5.5) other.set(110) self.assertEqual(ref_plus_other.get(), 112.5) # Type checking with self.assertRaises(TypeError): other.set('this is a string') with self.assertRaises(TypeError): other.set(ml_collections.FieldReference('this is a string')) with self.assertRaises(TypeError): other.set(ml_collections.FieldReference(None, field_type=str)) def testSetResult(self): ref = ml_collections.FieldReference(1.0) result = ref + 1.0 second_result = result + 1.0 self.assertEqual(ref.get(), 1.0) self.assertEqual(result.get(), 2.0) self.assertEqual(second_result.get(), 3.0) ref.set(2.0) self.assertEqual(ref.get(), 2.0) self.assertEqual(result.get(), 3.0) self.assertEqual(second_result.get(), 4.0) result.set(4.0) self.assertEqual(ref.get(), 2.0) 
self.assertEqual(result.get(), 4.0) self.assertEqual(second_result.get(), 5.0) # All references are broken at this point. ref.set(1.0) self.assertEqual(ref.get(), 1.0) self.assertEqual(result.get(), 4.0) self.assertEqual(second_result.get(), 5.0) def testTypeChecking(self): ref = ml_collections.FieldReference(1) string_ref = ml_collections.FieldReference('a') x = ref + string_ref with self.assertRaises(TypeError): x.get() def testNoType(self): self.assertRaisesRegex(TypeError, 'field_type should be a type.*', ml_collections.FieldReference, None, 0) def testEqual(self): # Simple case ref1 = ml_collections.FieldReference(1) ref2 = ml_collections.FieldReference(1) ref3 = ml_collections.FieldReference(2) self.assertEqual(ref1, 1) self.assertEqual(ref1, ref1) self.assertEqual(ref1, ref2) self.assertNotEqual(ref1, 2) self.assertNotEqual(ref1, ref3) # ConfigDict inside FieldReference ref1 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 1})) ref2 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 1})) ref3 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 2})) self.assertEqual(ref1, ml_collections.ConfigDict({'a': 1})) self.assertEqual(ref1, ref1) self.assertEqual(ref1, ref2) self.assertNotEqual(ref1, ml_collections.ConfigDict({'a': 2})) self.assertNotEqual(ref1, ref3) def testLessEqual(self): # Simple case ref1 = ml_collections.FieldReference(1) ref2 = ml_collections.FieldReference(1) ref3 = ml_collections.FieldReference(2) self.assertLessEqual(ref1, 1) self.assertLessEqual(ref1, 2) self.assertLessEqual(0, ref1) self.assertLessEqual(1, ref1) self.assertGreater(ref1, 0) self.assertLessEqual(ref1, ref1) self.assertLessEqual(ref1, ref2) self.assertLessEqual(ref1, ref3) self.assertGreater(ref3, ref1) def testControlFlowError(self): ref1 = ml_collections.FieldReference(True) ref2 = ml_collections.FieldReference(False) with self.assertRaises(NotImplementedError): if ref1: pass with self.assertRaises(NotImplementedError): _ = ref1 and 
ref2 with self.assertRaises(NotImplementedError): _ = ref1 or ref2 with self.assertRaises(NotImplementedError): _ = not ref1 if __name__ == '__main__': absltest.main() ml_collections-0.1.1/ml_collections/config_dict/tests/config_dict_test.py0000640000175000017500000013701014174507605026355 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Tests for ml_collections.ConfigDict.""" import abc from collections import abc as collections_abc import functools import json import pickle import sys from absl.testing import absltest from absl.testing import parameterized import ml_collections from ml_collections.config_dict import config_dict import mock import six import yaml _TEST_FIELD = {'int': 0} _TEST_DICT = { 'float': 2.34, 'string': 'tom', 'int': 2, 'list': [1, 2], 'dict': { 'float': -1.23, 'int': 23 }, } def _test_function(): pass # Having ABCMeta as a metaclass shouldn't break yaml serialization. 
class _TestClass(six.with_metaclass(abc.ABCMeta, object)): def __init__(self): self.variable_1 = 1 self.variable_2 = '2' _test_object = _TestClass() class _TestClassNoStr(): pass _TEST_DICT_BEST_EFFORT = dict(_TEST_DICT) _TEST_DICT_BEST_EFFORT.update({ 'unserializable': _TestClass, 'unserializable_no_str': _TestClassNoStr, 'function': _test_function, 'object': _test_object, 'set': {1, 2, 3} }) # This is how we expect the _TEST_DICT to look after we change the name float to # double using the function configdict.recursive_rename _TEST_DICT_CHANGE_FLOAT_NAME = { 'double': 2.34, 'string': 'tom', 'int': 2, 'list': [1, 2], 'dict': { 'double': -1.23, 'int': 23 }, } def _get_test_dict(): test_dict = dict(_TEST_DICT) field = ml_collections.FieldReference(_TEST_FIELD) test_dict['ref'] = field test_dict['ref2'] = field return test_dict def _get_test_dict_best_effort(): test_dict = dict(_TEST_DICT_BEST_EFFORT) field = ml_collections.FieldReference(_TEST_FIELD) test_dict['ref'] = field test_dict['ref2'] = field return test_dict def _get_test_config_dict(): return ml_collections.ConfigDict(_get_test_dict()) def _get_test_config_dict_best_effort(): return ml_collections.ConfigDict(_get_test_dict_best_effort()) _JSON_TEST_DICT = ('{"dict": {"float": -1.23, "int": 23},' ' "float": 2.34,' ' "int": 2,' ' "list": [1, 2],' ' "ref": {"int": 0},' ' "ref2": {"int": 0},' ' "string": "tom"}') if six.PY2: _DICT_TYPE = "!!python/name:__builtin__.dict ''" _UNSERIALIZABLE_MSG = "unserializable object of type: " else: _DICT_TYPE = "!!python/name:builtins.dict ''" _UNSERIALIZABLE_MSG = ( "unserializable object: ") _TYPES = { 'dict_type': _DICT_TYPE, 'configdict_type': '!!python/object:ml_collections.config_dict.config_dict' '.ConfigDict', 'fieldreference_type': '!!python/object:ml_collections.config_dict' '.config_dict.FieldReference' } _JSON_BEST_EFFORT_TEST_DICT = ( '{"dict": {"float": -1.23, "int": 23},' ' "float": 2.34,' ' "function": "function _test_function",' ' "int": 2,' ' "list": [1, 
2],' ' "object": {"variable_1": 1, "variable_2": "2"},' ' "ref": {"int": 0},' ' "ref2": {"int": 0},' ' "set": [1, 2, 3],' ' "string": "tom",' ' "unserializable": "unserializable object: ' '",' ' "unserializable_no_str": "%s"}') % _UNSERIALIZABLE_MSG _REPR_TEST_DICT = """ dict: float: -1.23 int: 23 float: 2.34 int: 2 list: - 1 - 2 ref: &id001 {fieldreference_type} _field_type: {dict_type} _ops: [] _required: false _value: int: 0 ref2: *id001 string: tom """.format(**_TYPES) _STR_TEST_DICT = """ dict: {float: -1.23, int: 23} float: 2.34 int: 2 list: [1, 2] ref: &id001 {int: 0} ref2: *id001 string: tom """ _STR_NESTED_TEST_DICT = """ dict: {float: -1.23, int: 23} float: 2.34 int: 2 list: [1, 2] nested_dict: float: -1.23 int: 23 nested_dict: float: -1.23 int: 23 non_nested_dict: {float: -1.23, int: 23} nested_list: - 1 - 2 - [3, 4, 5] - 6 ref: &id001 {int: 0} ref2: *id001 string: tom """ class ConfigDictTest(parameterized.TestCase): """Tests ConfigDict in config flags library.""" def assertEqualConfigs(self, cfg, dictionary): """Asserts recursive equality of config and a dictionary.""" self.assertEqual(cfg.to_dict(), dictionary) def testCreating(self): """Tests basic config creation.""" cfg = ml_collections.ConfigDict() cfg.field = 2.34 self.assertEqual(cfg.field, 2.34) def testDir(self): """Test that dir() works correctly on config.""" cfg = ml_collections.ConfigDict() cfg.field = 2.34 self.assertIn('field', dir(cfg)) self.assertIn('lock', dir(cfg)) def testFromDictConstruction(self): """Tests creation of config from existing dictionary.""" cfg = ml_collections.ConfigDict(_TEST_DICT) self.assertEqualConfigs(cfg, _TEST_DICT) def testOverridingValues(self): """Tests basic values overriding.""" cfg = ml_collections.ConfigDict() cfg.field = 2.34 self.assertEqual(cfg.field, 2.34) cfg.field = -2.34 self.assertEqual(cfg.field, -2.34) def testDictAttributeTurnsIntoConfigDict(self): """Tests that dicts in a ConfigDict turn to ConfigDicts (recursively).""" cfg = 
ml_collections.ConfigDict(_TEST_DICT) # Test conversion to dict on creation. self.assertIsInstance(cfg.dict, ml_collections.ConfigDict) # Test conversion to dict on setting attribute. new_dict = {'inside_dict': {'inside_key': 0}} cfg.new_dict = new_dict self.assertIsInstance(cfg.new_dict, ml_collections.ConfigDict) self.assertIsInstance(cfg.new_dict.inside_dict, ml_collections.ConfigDict) self.assertEqual(cfg.new_dict.to_dict(), new_dict) def testOverrideExceptions(self): """Test the `int` and unicode-string exceptions to overriding. ConfigDict forces strong type-checking with two exceptions. The first is that `int` values can be stored to fields of type `float`. And secondly, all string types can be stored in fields of type `str` or `unicode`. """ cfg = ml_collections.ConfigDict() # Test that overriding 'float' fields with int works. cfg.float_field = 2.34 cfg.float_field = 2 self.assertEqual(cfg.float_field, 2.0) # Test that overriding with Unicode strings works. cfg.string_field = '42' cfg.string_field = u'42' self.assertEqual(cfg.string_field, '42') # Test that overriding a Unicode field with a `str` type works. cfg.unicode_string_field = u'42' cfg.unicode_string_field = '42' self.assertEqual(cfg.unicode_string_field, u'42') # Test that overriding a list with a tuple works. cfg.tuple_field = [1, 2, 3] cfg.tuple_field = (1, 2) self.assertEqual(cfg.tuple_field, [1, 2]) # Test that overriding a tuple with a list works. cfg.list_field = [23, 42] cfg.list_field = (8, 9, 10) self.assertEqual(cfg.list_field, [8, 9, 10]) # Test that int <-> long conversions work. 
int_value = 1 # In Python 2, int(very large number) returns a long long_value = int(1e100) cfg.int_field = int_value cfg.int_field = long_value self.assertEqual(cfg.int_field, long_value) if sys.version_info.major == 2: expected = long else: expected = int self.assertIsInstance(cfg.int_field, expected) cfg.long_field = long_value cfg.long_field = int_value self.assertEqual(cfg.long_field, int_value) self.assertIsInstance(cfg.long_field, expected) def testOverrideCallable(self): """Test that overriding a callable with a callable works.""" class SomeClass: def __init__(self, x, power=1): self.y = x**power def factory(self, x): return SomeClass(self.y + x) fn1 = SomeClass fn2 = lambda x: SomeClass(x, power=2) fn3 = functools.partial(SomeClass, power=3) fn4 = SomeClass(4.0).factory cfg = ml_collections.ConfigDict() for orig in [fn1, fn2, fn3, fn4]: for new in [fn1, fn2, fn3, fn4]: cfg.fn_field = orig cfg.fn_field = new self.assertEqual(cfg.fn_field, new) def testOverrideFieldReference(self): """Test overriding with FieldReference objects.""" cfg = ml_collections.ConfigDict() cfg.field_1 = 'field_1' cfg.field_2 = 'field_2' # Override using a FieldReference. cfg.field_1 = ml_collections.FieldReference('override_1') # Override FieldReference field using another FieldReference. cfg.field_1 = ml_collections.FieldReference('override_2') # Override using empty FieldReference. cfg.field_2 = ml_collections.FieldReference(None, field_type=str) # Override FieldReference field using string. cfg.field_2 = 'field_2' # Check a TypeError is raised when using FieldReference's with wrong type. with self.assertRaises(TypeError): cfg.field_2 = ml_collections.FieldReference(1) with self.assertRaises(TypeError): cfg.field_2 = ml_collections.FieldReference(None, field_type=int) def testTypeSafe(self): """Tests type safe checking.""" cfg = _get_test_config_dict() with self.assertRaisesRegex(TypeError, 'field \'float\''): cfg.float = 'tom' # Test that float cannot be assigned to int. 
with self.assertRaisesRegex(TypeError, 'field \'int\''): cfg.int = 12.8 with self.assertRaisesRegex(TypeError, 'field \'string\''): cfg.string = -123 with self.assertRaisesRegex(TypeError, 'field \'float\''): cfg.dict.float = 'string' # Ensure None is ignored by type safety cfg.string = None cfg.string = 'tom' def testIgnoreType(self): cfg = ml_collections.ConfigDict({ 'string': 'This is a string', 'float': 3.0, 'list': [ml_collections.ConfigDict({'float': 1.0})], 'tuple': [ml_collections.ConfigDict({'float': 1.0})], 'dict': { 'float': 1.0 } }) with cfg.ignore_type(): cfg.string = -123 cfg.float = 'string' cfg.list[0].float = 'string' cfg.tuple[0].float = 'string' cfg.dict.float = 'string' def testTypeUnsafe(self): """Tests lack of type safe checking.""" cfg = ml_collections.ConfigDict(_get_test_dict(), type_safe=False) cfg.float = 'tom' cfg.string = -123 cfg.int = 12.8 def testLocking(self): """Tests lock mechanism.""" cfg = ml_collections.ConfigDict() cfg.field = 2 cfg.dict_field = {'float': 1.23, 'integer': 3} cfg.ref = ml_collections.FieldReference( ml_collections.ConfigDict({'integer': 0})) cfg.lock() cfg.field = -4 with self.assertRaises(AttributeError): cfg.new_field = 2 with self.assertRaises(AttributeError): cfg.dict_field.new_field = -1.23 with self.assertRaises(AttributeError): cfg.ref.new_field = 1 with self.assertRaises(AttributeError): del cfg.field def testUnlocking(self): """Tests unlock mechanism.""" cfg = ml_collections.ConfigDict() cfg.dict_field = {'float': 1.23, 'integer': 3} cfg.ref = ml_collections.FieldReference( ml_collections.ConfigDict({'integer': 0})) cfg.lock() with cfg.unlocked(): cfg.new_field = 2 cfg.dict_field.new_field = -1.23 cfg.ref.new_field = 1 def testGetMethod(self): """Tests get method.""" cfg = _get_test_config_dict() self.assertEqual(cfg.get('float', -1), cfg.float) self.assertEqual(cfg.get('ref', -1), cfg.ref) self.assertEqual(cfg.get('another_key', -1), -1) self.assertIsNone(cfg.get('another_key')) def 
testItemsMethod(self): """Tests items method.""" cfg = _get_test_config_dict() self.assertEqual(dict(**cfg), dict(cfg.items())) items = cfg.items() self.assertEqual(len(items), len(_get_test_dict())) for entry in _TEST_DICT.items(): if isinstance(entry[1], dict): entry = (entry[0], ml_collections.ConfigDict(entry[1])) self.assertIn(entry, items) self.assertIn(('ref', cfg.ref), items) self.assertIn(('ref2', cfg.ref2), items) ind_ref = items.index(('ref', cfg.ref)) ind_ref2 = items.index(('ref2', cfg.ref2)) self.assertIs(items[ind_ref][1], items[ind_ref2][1]) cfg = ml_collections.ConfigDict() self.assertEqual(dict(**cfg), dict(cfg.items())) # Test that items are sorted self.assertEqual(sorted(dict(**cfg).items()), cfg.items()) def testGetItemRecursively(self): """Tests getting items recursively (e.g., config['a.b']).""" cfg = _get_test_config_dict() self.assertEqual(cfg['dict.float'], -1.23) self.assertEqual('%(dict.int)i' % cfg, '23') def testIterItemsMethod(self): """Tests iteritems method.""" cfg = _get_test_config_dict() self.assertEqual(dict(**cfg), dict(cfg.iteritems())) cfg = ml_collections.ConfigDict() self.assertEqual(dict(**cfg), dict(cfg.iteritems())) def testIterKeysMethod(self): """Tests iterkeys method.""" some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'} cfg = ml_collections.ConfigDict(some_dict) self.assertEqual(set(six.iterkeys(some_dict)), set(six.iterkeys(cfg))) # Test that keys are sorted for k_ref, k in zip(sorted(six.iterkeys(cfg)), six.iterkeys(cfg)): self.assertEqual(k_ref, k) def testKeysMethod(self): """Tests keys method.""" some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'} cfg = ml_collections.ConfigDict(some_dict) self.assertEqual(set(some_dict.keys()), set(cfg.keys())) # Test that keys are sorted for k_ref, k in zip(sorted(cfg.keys()), cfg.keys()): self.assertEqual(k_ref, k) def testLenMethod(self): """Tests len method.""" some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'} cfg = ml_collections.ConfigDict(some_dict) self.assertLen(cfg, 
len(some_dict)) def testIterValuesMethod(self): """Tests itervalues method.""" some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'} cfg = ml_collections.ConfigDict(some_dict) self.assertEqual(set(six.itervalues(some_dict)), set(six.itervalues(cfg))) # Test that items are sorted for k_ref, v in zip(sorted(six.iterkeys(cfg)), six.itervalues(cfg)): self.assertEqual(cfg[k_ref], v) def testValuesMethod(self): """Tests values method.""" some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'} cfg = ml_collections.ConfigDict(some_dict) self.assertEqual(set(some_dict.values()), set(cfg.values())) # Test that items are sorted for k_ref, v in zip(sorted(cfg.keys()), cfg.values()): self.assertEqual(cfg[k_ref], v) def testIterValuesResolvesReferences(self): """Tests itervalues FieldReference resolution.""" cfg = ml_collections.ConfigDict({'x1': 32, 'x2': 5.2, 'x3': 'str'}) ref = ml_collections.FieldReference(0) cfg['x4'] = ref for v in cfg.itervalues(): self.assertNotIsInstance(v, ml_collections.FieldReference) self.assertIn(ref, cfg.itervalues(preserve_field_references=True)) def testValuesResolvesReferences(self): """Tests values FieldReference resolution.""" cfg = ml_collections.ConfigDict({'x1': 32, 'x2': 5.2, 'x3': 'str'}) ref = ml_collections.FieldReference(0) cfg['x4'] = ref for v in cfg.values(): self.assertNotIsInstance(v, ml_collections.FieldReference) self.assertIn(ref, cfg.values(preserve_field_references=True)) def testIterItemsResolvesReferences(self): """Tests iteritems FieldReference resolution.""" cfg = ml_collections.ConfigDict({'x1': 32, 'x2': 5.2, 'x3': 'str'}) ref = ml_collections.FieldReference(0) cfg['x4'] = ref for _, v in cfg.iteritems(): self.assertNotIsInstance(v, ml_collections.FieldReference) self.assertIn(('x4', ref), cfg.iteritems(preserve_field_references=True)) def testItemsResolvesReferences(self): """Tests items FieldReference resolution.""" cfg = ml_collections.ConfigDict({'x1': 32, 'x2': 5.2, 'x3': 'str'}) ref = ml_collections.FieldReference(0) cfg['x4'] = 
ref for _, v in cfg.items(): self.assertNotIsInstance(v, ml_collections.FieldReference) self.assertIn(('x4', ref), cfg.items(preserve_field_references=True)) def testEquals(self): """Tests __eq__ and __ne__ methods.""" some_dict = { 'float': 1.23, 'integer': 3, 'list': [1, 2], 'dict': { 'a': {}, 'b': 'string' } } cfg = ml_collections.ConfigDict(some_dict) cfg_other = ml_collections.ConfigDict(some_dict) self.assertEqual(cfg, cfg_other) self.assertEqual(ml_collections.ConfigDict(cfg), cfg_other) cfg_other.float = 3 self.assertNotEqual(cfg, cfg_other) cfg_other.float = cfg.float cfg_other.list = ['a', 'b'] self.assertNotEqual(cfg, cfg_other) cfg_other.list = cfg.list cfg_other.lock() self.assertNotEqual(cfg, cfg_other) cfg_other.unlock() cfg_other = ml_collections.ConfigDict(some_dict, type_safe=False) self.assertNotEqual(cfg, cfg_other) cfg = ml_collections.ConfigDict(some_dict) # References that have the same id should be equal (even if self-references) cfg_other = ml_collections.ConfigDict(some_dict) cfg_other.me = cfg cfg.me = cfg self.assertEqual(cfg, cfg_other) cfg = ml_collections.ConfigDict(some_dict) cfg.me = cfg self.assertEqual(cfg, cfg) # Self-references that do not have the same id loop infinitely cfg_other = ml_collections.ConfigDict(some_dict) cfg_other.me = cfg_other # Temporarily disable coverage trace while testing runtime is exceeded trace_func = sys.gettrace() sys.settrace(None) with self.assertRaises(RuntimeError): cfg == cfg_other # pylint:disable=pointless-statement sys.settrace(trace_func) def testEqAsConfigDict(self): """Tests .eq_as_configdict() method.""" cfg_1 = _get_test_config_dict() cfg_2 = _get_test_config_dict() cfg_2.added_field = 3.14159 cfg_self_ref = _get_test_config_dict() cfg_self_ref.self_ref = cfg_self_ref frozen_cfg_1 = ml_collections.FrozenConfigDict(cfg_1) frozen_cfg_2 = ml_collections.FrozenConfigDict(cfg_2) self.assertTrue(cfg_1.eq_as_configdict(cfg_1)) self.assertTrue(cfg_1.eq_as_configdict(frozen_cfg_1)) 
self.assertTrue(frozen_cfg_1.eq_as_configdict(cfg_1)) self.assertTrue(frozen_cfg_1.eq_as_configdict(frozen_cfg_1)) self.assertFalse(cfg_1.eq_as_configdict(cfg_2)) self.assertFalse(cfg_1.eq_as_configdict(frozen_cfg_2)) self.assertFalse(frozen_cfg_1.eq_as_configdict(cfg_self_ref)) self.assertFalse(frozen_cfg_1.eq_as_configdict(frozen_cfg_2)) self.assertFalse(cfg_self_ref.eq_as_configdict(cfg_1)) def testHash(self): some_dict = {'float': 1.23, 'integer': 3} cfg = ml_collections.ConfigDict(some_dict) with self.assertRaisesRegex(TypeError, 'unhashable type'): hash(cfg) # Ensure Python realizes ConfigDict is not hashable. self.assertNotIsInstance(cfg, collections_abc.Hashable) def testDidYouMeanFeature(self): """Tests 'did you mean' suggestions.""" cfg = ml_collections.ConfigDict() cfg.learning_rate = 0.01 cfg.lock() with self.assertRaisesRegex(AttributeError, 'Did you mean.*learning_rate.*'): _ = cfg.laerning_rate with cfg.unlocked(): with self.assertRaisesRegex(AttributeError, 'Did you mean.*learning_rate.*'): del cfg.laerning_rate with self.assertRaisesRegex(AttributeError, 'Did you mean.*learning_rate.*'): cfg.laerning_rate = 0.02 self.assertEqual(cfg.learning_rate, 0.01) with self.assertRaises(AttributeError): _ = self.laerning_rate def testReferences(self): """Tests assigning references in the dict.""" cfg = _get_test_config_dict() cfg.dict_ref = cfg.dict self.assertEqual(cfg.dict_ref, cfg.dict) def testPreserveReferences(self): """Tests that initializing with another ConfigDict preserves references.""" cfg = _get_test_config_dict() # In the original, "ref" and "ref2" are the same FieldReference self.assertIs(cfg.get_ref('ref'), cfg.get_ref('ref2')) # Create a copy from the original cfg2 = ml_collections.ConfigDict(cfg) # If the refs had not been preserved, get_ref would create a new # reference for each call self.assertIs(cfg2.get_ref('ref'), cfg2.get_ref('ref2')) self.assertIs(cfg2.ref, cfg2.ref2) # the values are also the same object def testUnpacking(self): 
"""Tests ability to pass ConfigDict instance with ** operator.""" cfg = ml_collections.ConfigDict() cfg.x = 2 def foo(x): return x + 3 self.assertEqual(foo(**cfg), 5) def testUnpackingWithFieldReference(self): """Tests ability to pass ConfigDict instance with ** operator.""" cfg = ml_collections.ConfigDict() cfg.x = ml_collections.FieldReference(2) def foo(x): return x + 3 self.assertEqual(foo(**cfg), 5) def testReadingIncorrectField(self): """Tests whether accessing non-existing fields raises an exception.""" cfg = ml_collections.ConfigDict() with self.assertRaises(AttributeError): _ = cfg.non_existing_field with self.assertRaises(KeyError): _ = cfg['non_existing_field'] def testIteration(self): """Tests whether one can iterate over ConfigDict.""" cfg = ml_collections.ConfigDict() for i in range(10): cfg['field{}'.format(i)] = 'field{}'.format(i) for field in cfg: self.assertEqual(cfg[field], getattr(cfg, field)) def testDeYaml(self): """Tests YAML deserialization.""" cfg = _get_test_config_dict() deyamled = yaml.load(cfg.to_yaml(), yaml.UnsafeLoader) self.assertEqual(cfg, deyamled) def testJSONConversion(self): """Tests JSON serialization.""" cfg = _get_test_config_dict() self.assertEqual( cfg.to_json(sort_keys=True).strip(), _JSON_TEST_DICT.strip()) cfg = _get_test_config_dict_best_effort() with self.assertRaises(TypeError): cfg.to_json() def testJSONConversionCustomEncoder(self): """Tests JSON serialization with custom encoder.""" cfg = _get_test_config_dict() encoder = json.JSONEncoder() mock_encoder_cls = mock.MagicMock() mock_encoder_cls.return_value = encoder with mock.patch.object(encoder, 'default') as mock_default: mock_default.return_value = '' cfg.to_json(json_encoder_cls=mock_encoder_cls) mock_default.assert_called() def testJSONConversionBestEffort(self): """Tests JSON serialization.""" # Check that best effort option doesn't break default functionality cfg = _get_test_config_dict() self.assertEqual( cfg.to_json_best_effort(sort_keys=True).strip(), 
_JSON_TEST_DICT.strip()) cfg_best_effort = _get_test_config_dict_best_effort() self.assertEqual( cfg_best_effort.to_json_best_effort(sort_keys=True).strip(), _JSON_BEST_EFFORT_TEST_DICT.strip()) def testReprConversion(self): """Tests repr conversion.""" cfg = _get_test_config_dict() self.assertEqual(repr(cfg).strip(), _REPR_TEST_DICT.strip()) def testLoadFromRepr(self): cfg_dict = ml_collections.ConfigDict() field = ml_collections.FieldReference(1) cfg_dict.r1 = field cfg_dict.r2 = field cfg_load = yaml.load(repr(cfg_dict), yaml.UnsafeLoader) # Test FieldReferences are preserved cfg_load['r1'].set(2) self.assertEqual(cfg_load['r1'].get(), cfg_load['r2'].get()) def testStrConversion(self): """Tests conversion to str.""" cfg = _get_test_config_dict() # Verify str(cfg) doesn't raise errors. _ = str(cfg) test_dict_2 = _get_test_dict() test_dict_2['nested_dict'] = { 'float': -1.23, 'int': 23, 'nested_dict': { 'float': -1.23, 'int': 23, 'non_nested_dict': { 'float': -1.23, 'int': 233, }, }, 'nested_list': [1, 2, [3, 44, 5], 6], } cfg_2 = ml_collections.ConfigDict(test_dict_2) # Demonstrate that dot-access works. cfg_2.nested_dict.nested_dict.non_nested_dict.int = 23 cfg_2.nested_dict.nested_list[2][1] = 4 # Verify str(cfg_2) doesn't raise errors. _ = str(cfg_2) def testDotInField(self): """Tests trying to create a dot containing field.""" cfg = ml_collections.ConfigDict() with self.assertRaises(ValueError): cfg['invalid.name'] = 2.3 def testToDictConversion(self): """Tests whether references are correctly handled when calling to_dict.""" cfg = ml_collections.ConfigDict() field = ml_collections.FieldReference('a string') cfg.dict = { 'float': 2.3, 'integer': 1, 'field_ref1': field, 'field_ref2': field } cfg.ref = cfg.dict cfg.self_ref = cfg pure_dict = cfg.to_dict() self.assertEqual(type(pure_dict), dict) self.assertIs(pure_dict, pure_dict['self_ref']) self.assertIs(pure_dict['dict'], pure_dict['ref']) # Ensure ConfigDict has been converted to dict. 
self.assertEqual(type(pure_dict['dict']), dict) # Ensure FieldReferences are not preserved, by default. self.assertNotIsInstance(pure_dict['dict']['field_ref1'], ml_collections.FieldReference) self.assertNotIsInstance(pure_dict['dict']['field_ref2'], ml_collections.FieldReference) self.assertEqual(pure_dict['dict']['field_ref1'], field.get()) self.assertEqual(pure_dict['dict']['field_ref2'], field.get()) pure_dict_with_refs = cfg.to_dict(preserve_field_references=True) self.assertEqual(type(pure_dict_with_refs), dict) self.assertEqual(type(pure_dict_with_refs['dict']), dict) self.assertIsInstance(pure_dict_with_refs['dict']['field_ref1'], ml_collections.FieldReference) self.assertIsInstance(pure_dict_with_refs['dict']['field_ref2'], ml_collections.FieldReference) self.assertIs(pure_dict_with_refs['dict']['field_ref1'], pure_dict_with_refs['dict']['field_ref2']) # Ensure FieldReferences in the dict are not the same as the FieldReferences # in the original ConfigDict. self.assertIsNot(pure_dict_with_refs['dict']['field_ref1'], cfg.dict['field_ref1']) def testToDictTypeUnsafe(self): """Tests interaction between ignore_type() and to_dict().""" cfg = ml_collections.ConfigDict() cfg.string = ml_collections.FieldReference(None, field_type=str) with cfg.ignore_type(): cfg.string = 1 self.assertEqual(1, cfg.to_dict(preserve_field_references=True)['string']) def testCopyAndResolveReferences(self): """Tests the .copy_and_resolve_references() method.""" cfg = ml_collections.ConfigDict() field = ml_collections.FieldReference('a string') int_field = ml_collections.FieldReference(5) cfg.dict = { 'float': 2.3, 'integer': 1, 'field_ref1': field, 'field_ref2': field, 'field_ref_int1': int_field, 'field_ref_int2': int_field + 5, 'placeholder': config_dict.placeholder(str), 'cfg': ml_collections.ConfigDict({ 'integer': 1, 'int_field': int_field }) } cfg.ref = cfg.dict cfg.self_ref = cfg cfg_resolved = cfg.copy_and_resolve_references() for field, value in [('float', 2.3), ('integer', 
1), ('field_ref1', 'a string'), ('field_ref2', 'a string'), ('field_ref_int1', 5), ('field_ref_int2', 10), ('placeholder', None)]: self.assertEqual(getattr(cfg_resolved.dict, field), value) for field, value in [('integer', 1), ('int_field', 5)]: self.assertEqual(getattr(cfg_resolved.dict.cfg, field), value) self.assertIs(cfg_resolved, cfg_resolved['self_ref']) self.assertIs(cfg_resolved['dict'], cfg_resolved['ref']) def testCopyAndResolveReferencesConfigTypes(self): """Tests that .copy_and_resolve_references() handles special types.""" cfg_type_safe = ml_collections.ConfigDict() int_field = ml_collections.FieldReference(5) cfg_type_safe.field_ref1 = int_field cfg_type_safe.field_ref2 = int_field + 5 cfg_type_safe.lock() cfg_type_safe_locked_resolved = cfg_type_safe.copy_and_resolve_references() self.assertTrue(cfg_type_safe_locked_resolved.is_locked) self.assertTrue(cfg_type_safe_locked_resolved.is_type_safe) cfg = ml_collections.ConfigDict(type_safe=False) cfg.field_ref1 = int_field cfg.field_ref2 = int_field + 5 cfg_resolved = cfg.copy_and_resolve_references() self.assertFalse(cfg_resolved.is_locked) self.assertFalse(cfg_resolved.is_type_safe) cfg.lock() cfg_locked_resolved = cfg.copy_and_resolve_references() self.assertTrue(cfg_locked_resolved.is_locked) self.assertFalse(cfg_locked_resolved.is_type_safe) for resolved in [ cfg_type_safe_locked_resolved, cfg_resolved, cfg_locked_resolved ]: self.assertEqual(resolved.field_ref1, 5) self.assertEqual(resolved.field_ref2, 10) frozen_cfg = ml_collections.FrozenConfigDict(cfg_type_safe) frozen_cfg_resolved = frozen_cfg.copy_and_resolve_references() for resolved in [frozen_cfg, frozen_cfg_resolved]: self.assertEqual(resolved.field_ref1, 5) self.assertEqual(resolved.field_ref2, 10) self.assertIsInstance(resolved, ml_collections.FrozenConfigDict) def testInitConfigDict(self): """Tests initializing a ConfigDict on a ConfigDict.""" cfg = _get_test_config_dict() cfg_2 = ml_collections.ConfigDict(cfg) self.assertIsNot(cfg_2, 
cfg) self.assertIs(cfg_2.float, cfg.float) self.assertIs(cfg_2.dict, cfg.dict) # Ensure ConfigDict fields are initialized as is dict_with_cfg_field = {'cfg': cfg} cfg_3 = ml_collections.ConfigDict(dict_with_cfg_field) self.assertIs(cfg_3.cfg, cfg) # Now ensure it works with locking and type_safe cfg_4 = ml_collections.ConfigDict(cfg, type_safe=False) cfg_4.lock() self.assertEqual(cfg_4, ml_collections.ConfigDict(cfg_4)) def testInitReferenceStructure(self): """Ensures initialization preserves reference structure.""" x = [1, 2, 3] self_ref_dict = { 'float': 2.34, 'test_dict_1': _TEST_DICT, 'test_dict_2': _TEST_DICT, 'list': x } self_ref_dict['self'] = self_ref_dict self_ref_dict['self_fr'] = ml_collections.FieldReference(self_ref_dict) self_ref_cd = ml_collections.ConfigDict(self_ref_dict) self.assertIs(self_ref_cd.test_dict_1, self_ref_cd.test_dict_2) self.assertIs(self_ref_cd, self_ref_cd.self) self.assertIs(self_ref_cd, self_ref_cd.self_fr) self.assertIs(self_ref_cd.list, x) self.assertEqual(self_ref_cd, self_ref_cd.self) self_ref_cd.self.int = 1 self.assertEqual(self_ref_cd.int, 1) self_ref_cd_2 = ml_collections.ConfigDict(self_ref_cd) self.assertIsNot(self_ref_cd_2, self_ref_cd) self.assertIs(self_ref_cd_2.self, self_ref_cd_2) self.assertIs(self_ref_cd_2.test_dict_1, self_ref_cd.test_dict_1) def testInitFieldReference(self): """Tests initialization with FieldReferences.""" test_dict = dict(x=1, y=1) # Reference to a dict. reference = ml_collections.FieldReference(test_dict) cfg = ml_collections.ConfigDict() cfg.reference = reference self.assertIsInstance(cfg.reference, ml_collections.ConfigDict) self.assertEqual(test_dict['x'], cfg.reference.x) self.assertEqual(test_dict['y'], cfg.reference.y) # Reference to a ConfigDict. 
test_configdict = ml_collections.ConfigDict(test_dict) reference = ml_collections.FieldReference(test_configdict) cfg = ml_collections.ConfigDict() cfg.reference = reference test_configdict.x = 2 self.assertEqual(test_configdict.x, cfg.reference.x) self.assertEqual(test_configdict.y, cfg.reference.y) # Reference to a reference. reference_int = ml_collections.FieldReference(0) reference = ml_collections.FieldReference(reference_int) cfg = ml_collections.ConfigDict() cfg.reference = reference reference_int.set(1) self.assertEqual(reference_int.get(), cfg.reference) def testDeletingFields(self): """Tests whether it is possible to delete fields.""" cfg = ml_collections.ConfigDict() cfg.field1 = 123 cfg.field2 = 123 self.assertIn('field1', cfg) self.assertIn('field2', cfg) del cfg.field1 self.assertNotIn('field1', cfg) self.assertIn('field2', cfg) del cfg.field2 self.assertNotIn('field2', cfg) with self.assertRaises(AttributeError): del cfg.keys with self.assertRaises(KeyError): del cfg['keys'] def testDeletingNestedFields(self): """Tests whether it is possible to delete nested fields.""" cfg = ml_collections.ConfigDict({ 'a': { 'aa': [1, 2], }, 'b': { 'ba': { 'baa': 2, 'bab': 3, }, 'bb': {1, 2, 3}, }, }) self.assertIn('a', cfg) self.assertIn('aa', cfg.a) self.assertIn('baa', cfg.b.ba) del cfg['a.aa'] self.assertIn('a', cfg) self.assertNotIn('aa', cfg.a) del cfg['a'] self.assertNotIn('a', cfg) del cfg['b.ba.baa'] self.assertIn('ba', cfg.b) self.assertIn('bab', cfg.b.ba) self.assertNotIn('baa', cfg.b.ba) del cfg['b.ba'] self.assertNotIn('ba', cfg.b) self.assertIn('bb', cfg.b) with self.assertRaises(AttributeError): del cfg.keys with self.assertRaises(KeyError): del cfg['keys'] def testSetAttr(self): """Tests whether it is possible to override an attribute.""" cfg = ml_collections.ConfigDict() with self.assertRaises(AttributeError): cfg.__setattr__('__class__', 'abc') def testPickling(self): """Tests whether ConfigDict can be pickled and unpickled.""" cfg = 
_get_test_config_dict() cfg.lock() pickle_cfg = pickle.loads(pickle.dumps(cfg)) self.assertTrue(pickle_cfg.is_locked) self.assertIsInstance(pickle_cfg, ml_collections.ConfigDict) self.assertEqual(str(cfg), str(pickle_cfg)) def testPlaceholder(self): """Tests whether FieldReference works correctly as a placeholder.""" cfg_element = ml_collections.FieldReference(0) cfg = ml_collections.ConfigDict({ 'element': cfg_element, 'nested': { 'element': cfg_element } }) # Type mismatch. with self.assertRaises(TypeError): cfg.element = 'string' cfg.element = 1 self.assertEqual(cfg.element, cfg.nested.element) def testOptional(self): """Tests whether FieldReference works correctly as an optional field.""" # Type mismatch at construction. with self.assertRaises(TypeError): ml_collections.FieldReference(0, field_type=str) # None default and no field_type. with self.assertRaises(ValueError): ml_collections.FieldReference(None) cfg = ml_collections.ConfigDict({ 'default': ml_collections.FieldReference(0), }) cfg.default = 1 self.assertEqual(cfg.default, 1) def testOptionalNoDefault(self): """Tests optional field with no default value.""" cfg = ml_collections.ConfigDict({ 'nodefault': ml_collections.FieldReference(None, field_type=str), }) # Type mismatch with field with no default value. with self.assertRaises(TypeError): cfg.nodefault = 1 cfg.nodefault = 'string' self.assertEqual(cfg.nodefault, 'string') def testGetType(self): """Tests whether types are correct for FieldReference fields.""" cfg = ml_collections.ConfigDict() cfg.integer = 123 cfg.ref = ml_collections.FieldReference(123) cfg.ref_nodefault = ml_collections.FieldReference(None, field_type=int) self.assertEqual(cfg.get_type('integer'), int) self.assertEqual(cfg.get_type('ref'), int) self.assertEqual(cfg.get_type('ref_nodefault'), int) # Check errors in case of misspelled key. 
with self.assertRaisesRegex(AttributeError, 'Did you.*ref_nodefault.*'): cfg.get_type('ref_nodefualt') with self.assertRaisesRegex(AttributeError, 'Did you.*integer.*'): cfg.get_type('integre') class ConfigDictUpdateTest(absltest.TestCase): def testUpdateSimple(self): """Tests updating from one ConfigDict to another.""" first = ml_collections.ConfigDict() first.x = 5 first.y = 'truman' first.q = 2.0 second = ml_collections.ConfigDict() second.x = 9 second.y = 'wilson' second.z = 'washington' first.update(second) self.assertEqual(first.x, 9) self.assertEqual(first.y, 'wilson') self.assertEqual(first.z, 'washington') self.assertEqual(first.q, 2.0) def testUpdateNothing(self): """Tests updating a ConfigDict with no arguments.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 cfg.update() self.assertLen(cfg, 2) self.assertEqual(cfg.x, 5) self.assertEqual(cfg.y, 9) def testUpdateFromDict(self): """Tests updating a ConfigDict from a dict.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 cfg.update({'x': 6, 'z': 2}) self.assertEqual(cfg.x, 6) self.assertEqual(cfg.y, 9) self.assertEqual(cfg.z, 2) def testUpdateFromKwargs(self): """Tests updating a ConfigDict from kwargs.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 cfg.update(x=6, z=2) self.assertEqual(cfg.x, 6) self.assertEqual(cfg.y, 9) self.assertEqual(cfg.z, 2) def testUpdateFromDictAndKwargs(self): """Tests updating a ConfigDict from a dict and kwargs.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 cfg.update({'x': 4, 'z': 2}, x=6) self.assertEqual(cfg.x, 6) # kwarg overrides value from dict self.assertEqual(cfg.y, 9) self.assertEqual(cfg.z, 2) def testUpdateFromMultipleDictTypeError(self): """Tests that updating a ConfigDict from two dicts raises a TypeError.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 with self.assertRaisesRegex(TypeError, 'update expected at most 1 arguments, got 2'): cfg.update({'x': 4}, {'z': 2}) def testUpdateNested(self): """Tests updating a 
ConfigDict from a nested dict.""" cfg = ml_collections.ConfigDict() cfg.subcfg = ml_collections.ConfigDict() cfg.p = 5 cfg.q = 6 cfg.subcfg.y = 9 cfg.update({'p': 4, 'subcfg': {'y': 10, 'z': 5}}) self.assertEqual(cfg.p, 4) self.assertEqual(cfg.q, 6) self.assertEqual(cfg.subcfg.y, 10) self.assertEqual(cfg.subcfg.z, 5) def _assert_associated(self, cfg1, cfg2, key): self.assertEqual(cfg1[key], cfg2[key]) cfg1[key] = 1 cfg2[key] = 2 self.assertEqual(cfg1[key], 2) cfg1[key] = 3 self.assertEqual(cfg2[key], 3) def testUpdateFieldReference(self): """Tests updating to/from FieldReference fields.""" # Updating FieldReference... ref = ml_collections.FieldReference(1) cfg = ml_collections.ConfigDict(dict(a=ref, b=ref)) # from value. cfg.update(ml_collections.ConfigDict(dict(a=2))) self.assertEqual(cfg.a, 2) self.assertEqual(cfg.b, 2) # from FieldReference. error_message = 'Cannot update a FieldReference from another FieldReference' with self.assertRaisesRegex(TypeError, error_message): cfg.update( ml_collections.ConfigDict(dict(a=ml_collections.FieldReference(2)))) with self.assertRaisesRegex(TypeError, error_message): cfg.update( ml_collections.ConfigDict(dict(b=ml_collections.FieldReference(2)))) # Updating empty ConfigDict with FieldReferences. ref = ml_collections.FieldReference(1) cfg_from = ml_collections.ConfigDict(dict(a=ref, b=ref)) cfg = ml_collections.ConfigDict() cfg.update(cfg_from) self._assert_associated(cfg, cfg_from, 'a') self._assert_associated(cfg, cfg_from, 'b') # Updating values with FieldReferences. 
ref = ml_collections.FieldReference(1) cfg_from = ml_collections.ConfigDict(dict(a=ref, b=ref)) cfg = ml_collections.ConfigDict(dict(a=2, b=3)) cfg.update(cfg_from) self._assert_associated(cfg, cfg_from, 'a') self._assert_associated(cfg, cfg_from, 'b') def testUpdateFromFlattened(self): cfg = ml_collections.ConfigDict({'a': 1, 'b': {'c': {'d': 2}}}) updates = {'a': 2, 'b.c.d': 3} cfg.update_from_flattened_dict(updates) self.assertEqual(cfg.a, 2) self.assertEqual(cfg.b.c.d, 3) def testUpdateFromFlattenedWithPrefix(self): cfg = ml_collections.ConfigDict({'a': 1, 'b': {'c': {'d': 2}}}) updates = {'a': 2, 'b.c.d': 3} cfg.b.update_from_flattened_dict(updates, 'b.') self.assertEqual(cfg.a, 1) self.assertEqual(cfg.b.c.d, 3) def testUpdateFromFlattenedNotFound(self): cfg = ml_collections.ConfigDict({'a': 1, 'b': {'c': {'d': 2}}}) updates = {'a': 2, 'b.d.e': 3} with self.assertRaisesRegex( KeyError, 'Key "b.d.e" cannot be set as "b.d" was not found.'): cfg.update_from_flattened_dict(updates) def testUpdateFromFlattenedWrongType(self): cfg = ml_collections.ConfigDict({'a': 1, 'b': {'c': {'d': 2}}}) updates = {'a.b.c': 2} with self.assertRaisesRegex( KeyError, 'Key "a.b.c" cannot be updated as "a" is not a ConfigDict.'): cfg.update_from_flattened_dict(updates) def testUpdateFromFlattenedTupleListConversion(self): cfg = ml_collections.ConfigDict({ 'a': 1, 'b': { 'c': { 'd': (1, 2, 3, 4, 5), } } }) updates = { 'b.c.d': [2, 4, 6, 8], } cfg.update_from_flattened_dict(updates) self.assertIsInstance(cfg.b.c.d, tuple) self.assertEqual(cfg.b.c.d, (2, 4, 6, 8)) def testDecodeError(self): # ConfigDict containing two strings with incompatible encodings. 
cfg = ml_collections.ConfigDict({ 'dill': pickle.dumps(_test_function, protocol=pickle.HIGHEST_PROTOCOL), 'unicode': u'unicode string' }) expected_error = config_dict.JSONDecodeError if six.PY2 else TypeError with self.assertRaises(expected_error): cfg.to_json() def testConvertDict(self): """Test automatic conversion, or not, of dict to ConfigDict.""" cfg = ml_collections.ConfigDict() cfg.a = dict(b=dict(c=0)) self.assertIsInstance(cfg.a, ml_collections.ConfigDict) self.assertIsInstance(cfg.a.b, ml_collections.ConfigDict) cfg = ml_collections.ConfigDict(convert_dict=False) cfg.a = dict(b=dict(c=0)) self.assertNotIsInstance(cfg.a, ml_collections.ConfigDict) self.assertIsInstance(cfg.a, dict) self.assertIsInstance(cfg.a['b'], dict) def testConvertDictInInitialValue(self): """Test automatic conversion, or not, of dict to ConfigDict.""" initial_dict = dict(a=dict(b=dict(c=0))) cfg = ml_collections.ConfigDict(initial_dict) self.assertIsInstance(cfg.a, ml_collections.ConfigDict) self.assertIsInstance(cfg.a.b, ml_collections.ConfigDict) cfg = ml_collections.ConfigDict(initial_dict, convert_dict=False) self.assertNotIsInstance(cfg.a, ml_collections.ConfigDict) self.assertIsInstance(cfg.a, dict) self.assertIsInstance(cfg.a['b'], dict) def testConvertDictInCopyAndResolveReferences(self): """Test conversion, or not, of dict in copy and resolve references.""" cfg = ml_collections.ConfigDict() cfg.a = dict(b=dict(c=0)) copied_cfg = cfg.copy_and_resolve_references() self.assertIsInstance(copied_cfg.a, ml_collections.ConfigDict) self.assertIsInstance(copied_cfg.a.b, ml_collections.ConfigDict) cfg = ml_collections.ConfigDict(convert_dict=False) cfg.a = dict(b=dict(c=0)) copied_cfg = cfg.copy_and_resolve_references() self.assertNotIsInstance(copied_cfg.a, ml_collections.ConfigDict) self.assertIsInstance(copied_cfg.a, dict) self.assertIsInstance(copied_cfg.a['b'], dict) def testConvertDictTypeCompat(self): """Test that automatic conversion to ConfigDict doesn't trigger type 
errors.""" cfg = ml_collections.ConfigDict() cfg.a = {} self.assertIsInstance(cfg.a, ml_collections.ConfigDict) # This checks that dict to configdict casting doesn't produce type mismatch. cfg.a = {} def testYamlNoConvert(self): """Test deserialisation from YAML without convert dict. This checks backward compatibility of deserialisation. """ cfg = ml_collections.ConfigDict(dict(a=1)) self.assertTrue(yaml.load(cfg.to_yaml(), yaml.UnsafeLoader)._convert_dict) def testRecursiveRename(self): """Test recursive_rename. The dictionary should be the same but with the specified name changed. """ cfg = ml_collections.ConfigDict(_TEST_DICT) new_cfg = config_dict.recursive_rename(cfg, 'float', 'double') # Check that the new config has float changed to double as we expect self.assertEqual(new_cfg.to_dict(), _TEST_DICT_CHANGE_FLOAT_NAME) # Check that the original config is unchanged self.assertEqual(cfg.to_dict(), _TEST_DICT) def testGetOnewayRef(self): cfg = config_dict.create(a=1) cfg.b = cfg.get_oneway_ref('a') cfg.a = 2 self.assertEqual(2, cfg.b) cfg.b = 3 self.assertEqual(2, cfg.a) self.assertEqual(3, cfg.b) class CreateTest(absltest.TestCase): def testBasic(self): config = config_dict.create(a=1, b='b') dct = {'a': 1, 'b': 'b'} self.assertEqual(config.to_dict(), dct) def testNested(self): config = config_dict.create( data=config_dict.create(game='freeway'), model=config_dict.create(num_hidden=1000)) dct = {'data': {'game': 'freeway'}, 'model': {'num_hidden': 1000}} self.assertEqual(config.to_dict(), dct) class PlaceholderTest(absltest.TestCase): def testBasic(self): config = config_dict.create(a=1, b=config_dict.placeholder(int)) self.assertEqual(config.to_dict(), {'a': 1, 'b': None}) config.b = 5 self.assertEqual(config.to_dict(), {'a': 1, 'b': 5}) def testTypeChecking(self): config = config_dict.create(a=1, b=config_dict.placeholder(int)) with self.assertRaises(TypeError): config.b = 'chutney' def testRequired(self): config = 
config_dict.create(a=config_dict.required_placeholder(int))
    ref = config.get_ref('a')
    with self.assertRaises(config_dict.RequiredValueError):
      config.a  # pylint: disable=pointless-statement
    with self.assertRaises(config_dict.RequiredValueError):
      config.to_dict()
    with self.assertRaises(config_dict.RequiredValueError):
      ref.get()

    config.a = 10
    self.assertEqual(config.to_dict(), {'a': 10})
    self.assertEqual(str(config), yaml.dump({'a': 10}))

    # Reset to None and check we still get an error.
    config.a = None
    with self.assertRaises(config_dict.RequiredValueError):
      config.a  # pylint: disable=pointless-statement

    # Set to a different value using the reference obtained by calling
    # get_ref().
    ref.set(5)
    self.assertEqual(config.to_dict(), {'a': 5})
    self.assertEqual(str(config), yaml.dump({'a': 5}))

    # dict placeholder.
    test_dict = {'field': 10}
    config = config_dict.create(
        a=config_dict.required_placeholder(dict),
        b=ml_collections.FieldReference(test_dict.copy()))
    # ConfigDict initialization converts dict to ConfigDict.
    self.assertEqual(test_dict, config.b.to_dict())
    config.a = test_dict
    self.assertEqual(test_dict, config.a)


class CycleTest(absltest.TestCase):

  def testCycle(self):
    config = config_dict.create(a=1)
    config.b = config.get_ref('a') + config.get_ref('a')
    self.assertFalse(config.get_ref('b').has_cycle())
    with self.assertRaises(config_dict.MutabilityError):
      config.a = config.get_ref('a')
    with self.assertRaises(config_dict.MutabilityError):
      config.a = config.get_ref('b')


if __name__ == '__main__':
  absltest.main()

# ml_collections-0.1.1/ml_collections/config_dict/tests/frozen_config_dict_test.py

# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Tests for ml_collections.FrozenConfigDict.""" from collections import abc as collections_abc import copy import pickle from absl.testing import absltest import ml_collections _TEST_DICT = { 'int': 2, 'list': [1, 2], 'nested_list': [[1, [2]]], 'set': {1, 2}, 'tuple': (1, 2), 'frozenset': frozenset({1, 2}), 'dict': { 'float': -1.23, 'list': [1, 2], 'dict': {}, 'tuple_containing_list': (1, 2, (3, [4, 5], (6, 7))), 'list_containing_tuple': [1, 2, [3, 4], (5, 6)], }, 'ref': ml_collections.FieldReference({'int': 0}) } def _test_dict_deepcopy(): return copy.deepcopy(_TEST_DICT) def _test_configdict(): return ml_collections.ConfigDict(_TEST_DICT) def _test_frozenconfigdict(): return ml_collections.FrozenConfigDict(_TEST_DICT) class FrozenConfigDictTest(absltest.TestCase): """Tests FrozenConfigDict in config flags library.""" def assertFrozenRaisesValueError(self, input_list): """Assert initialization on all elements of input_list raise ValueError.""" for initial_dictionary in input_list: with self.assertRaises(ValueError): _ = ml_collections.FrozenConfigDict(initial_dictionary) def testBasicEquality(self): """Tests basic equality with different types of initialization.""" fcd = _test_frozenconfigdict() fcd_cd = ml_collections.FrozenConfigDict(_test_configdict()) fcd_fcd = ml_collections.FrozenConfigDict(fcd) self.assertEqual(fcd, fcd_cd) self.assertEqual(fcd, fcd_fcd) def testImmutability(self): """Tests immutability of frozen config.""" fcd = _test_frozenconfigdict() self.assertEqual(fcd.list, tuple(_TEST_DICT['list'])) self.assertEqual(fcd.tuple, 
                     _TEST_DICT['tuple'])
    self.assertEqual(fcd.set, frozenset(_TEST_DICT['set']))
    self.assertEqual(fcd.frozenset, _TEST_DICT['frozenset'])
    # Must manually check set to frozenset conversion, since Python's == does
    # not distinguish a set from a frozenset.
    self.assertIsInstance(fcd.set, frozenset)

    self.assertEqual(fcd.dict.list, tuple(_TEST_DICT['dict']['list']))
    self.assertNotEqual(fcd.dict.tuple_containing_list,
                        _TEST_DICT['dict']['tuple_containing_list'])
    self.assertEqual(fcd.dict.tuple_containing_list[2][1],
                     tuple(_TEST_DICT['dict']['tuple_containing_list'][2][1]))

    self.assertIsInstance(fcd.dict, ml_collections.FrozenConfigDict)

    with self.assertRaises(AttributeError):
      fcd.newitem = 0
    with self.assertRaises(AttributeError):
      fcd.dict.int = 0
    with self.assertRaises(AttributeError):
      fcd['newitem'] = 0
    with self.assertRaises(AttributeError):
      del fcd.int
    with self.assertRaises(AttributeError):
      del fcd['int']

  def testLockAndFreeze(self):
    """Ensures .lock() and .freeze() raise errors."""
    fcd = _test_frozenconfigdict()

    self.assertFalse(fcd.is_locked)
    self.assertFalse(fcd.as_configdict().is_locked)

    with self.assertRaises(AttributeError):
      fcd.lock()
    with self.assertRaises(AttributeError):
      fcd.unlock()
    with self.assertRaises(AttributeError):
      fcd.freeze()
    with self.assertRaises(AttributeError):
      fcd.unfreeze()

  def testInitConfigDict(self):
    """Tests that ConfigDict initialization handles FrozenConfigDict.

    Initializing a ConfigDict on a dictionary with FrozenConfigDict values
    should unfreeze these values.
""" dict_without_fcd_node = _test_dict_deepcopy() dict_without_fcd_node.pop('ref') dict_with_fcd_node = copy.deepcopy(dict_without_fcd_node) dict_with_fcd_node['dict'] = ml_collections.FrozenConfigDict( dict_with_fcd_node['dict']) cd_without_fcd_node = ml_collections.ConfigDict(dict_without_fcd_node) cd_with_fcd_node = ml_collections.ConfigDict(dict_with_fcd_node) fcd_without_fcd_node = ml_collections.FrozenConfigDict( dict_without_fcd_node) fcd_with_fcd_node = ml_collections.FrozenConfigDict(dict_with_fcd_node) self.assertEqual(cd_without_fcd_node, cd_with_fcd_node) self.assertEqual(fcd_without_fcd_node, fcd_with_fcd_node) def testInitCopying(self): """Tests that initialization copies when and only when necessary. Ensures copying only occurs when converting mutable type to immutable type, regardless of whether the FrozenConfigDict is initialized by a dict or a FrozenConfigDict. Also ensures no copying occurs when converting from FrozenConfigDict back to ConfigDict. """ fcd = _test_frozenconfigdict() # These should be uncopied when creating fcd fcd_unchanged_from_test_dict = [ (_TEST_DICT['tuple'], fcd.tuple), (_TEST_DICT['frozenset'], fcd.frozenset), (_TEST_DICT['dict']['tuple_containing_list'][2][2], fcd.dict.tuple_containing_list[2][2]), (_TEST_DICT['dict']['list_containing_tuple'][3], fcd.dict.list_containing_tuple[3]) ] # These should be copied when creating fcd fcd_different_from_test_dict = [ (_TEST_DICT['list'], fcd.list), (_TEST_DICT['dict']['tuple_containing_list'][2][1], fcd.dict.tuple_containing_list[2][1]) ] for (x, y) in fcd_unchanged_from_test_dict: self.assertEqual(id(x), id(y)) for (x, y) in fcd_different_from_test_dict: self.assertNotEqual(id(x), id(y)) # Also make sure that converting back to ConfigDict makes no copies self.assertEqual( id(_TEST_DICT['dict']['tuple_containing_list']), id(ml_collections.ConfigDict(fcd).dict.tuple_containing_list)) def testAsConfigDict(self): """Tests that converting FrozenConfigDict to ConfigDict works correctly. 
In particular, ensures that FrozenConfigDict does the inverse of ConfigDict regarding type_safe, lock, and attribute mutability. """ # First ensure conversion to ConfigDict works on empty FrozenConfigDict self.assertEqual( ml_collections.ConfigDict(ml_collections.FrozenConfigDict()), ml_collections.ConfigDict()) cd = _test_configdict() cd_fcd_cd = ml_collections.ConfigDict(ml_collections.FrozenConfigDict(cd)) self.assertEqual(cd, cd_fcd_cd) # Make sure locking is respected cd.lock() self.assertEqual( cd, ml_collections.ConfigDict(ml_collections.FrozenConfigDict(cd))) # Make sure type_safe is respected cd = ml_collections.ConfigDict(_TEST_DICT, type_safe=False) self.assertEqual( cd, ml_collections.ConfigDict(ml_collections.FrozenConfigDict(cd))) def testInitSelfReferencing(self): """Ensure initialization fails on self-referencing dicts.""" self_ref = {} self_ref['self'] = self_ref parent_ref = {'dict': {}} parent_ref['dict']['parent'] = parent_ref tuple_parent_ref = {'dict': {}} tuple_parent_ref['dict']['tuple'] = (1, 2, tuple_parent_ref) attribute_cycle = {'dict': copy.deepcopy(self_ref)} self.assertFrozenRaisesValueError( [self_ref, parent_ref, tuple_parent_ref, attribute_cycle]) def testInitCycles(self): """Ensure initialization fails if an attribute of input is cyclic.""" inner_cyclic_list = [1, 2] cyclic_list = [3, inner_cyclic_list] inner_cyclic_list.append(cyclic_list) cyclic_tuple = tuple(cyclic_list) test_dict_cyclic_list = _test_dict_deepcopy() test_dict_cyclic_tuple = _test_dict_deepcopy() test_dict_cyclic_list['cyclic_list'] = cyclic_list test_dict_cyclic_tuple['dict']['cyclic_tuple'] = cyclic_tuple self.assertFrozenRaisesValueError( [test_dict_cyclic_list, test_dict_cyclic_tuple]) def testInitDictInList(self): """Ensure initialization fails on dict and ConfigDict in lists/tuples.""" list_containing_dict = {'list': [1, 2, 3, {'a': 4, 'b': 5}]} tuple_containing_dict = {'tuple': (1, 2, 3, {'a': 4, 'b': 5})} list_containing_cd = {'list': [1, 2, 3, 
_test_configdict()]} tuple_containing_cd = {'tuple': (1, 2, 3, _test_configdict())} fr_containing_list_containing_dict = { 'fr': ml_collections.FieldReference([1, { 'a': 2 }]) } self.assertFrozenRaisesValueError([ list_containing_dict, tuple_containing_dict, list_containing_cd, tuple_containing_cd, fr_containing_list_containing_dict ]) def testInitFieldReferenceInList(self): """Ensure initialization fails on FieldReferences in lists/tuples.""" list_containing_fr = {'list': [1, 2, 3, ml_collections.FieldReference(4)]} tuple_containing_fr = { 'tuple': (1, 2, 3, ml_collections.FieldReference('a')) } self.assertFrozenRaisesValueError([list_containing_fr, tuple_containing_fr]) def testInitInvalidAttributeName(self): """Ensure initialization fails on attributes with invalid names.""" dot_name = {'dot.name': None} immutable_name = {'__hash__': None} with self.assertRaises(ValueError): ml_collections.FrozenConfigDict(dot_name) with self.assertRaises(AttributeError): ml_collections.FrozenConfigDict(immutable_name) def testFieldReferenceResolved(self): """Tests that FieldReferences are resolved.""" cfg = ml_collections.ConfigDict({'fr': ml_collections.FieldReference(1)}) frozen_cfg = ml_collections.FrozenConfigDict(cfg) self.assertNotIsInstance(frozen_cfg._fields['fr'], ml_collections.FieldReference) hash(frozen_cfg) # with FieldReference resolved, frozen_cfg is hashable def testFieldReferenceCycle(self): """Tests that FieldReferences may not contain reference cycles.""" frozenset_fr = {'frozenset': frozenset({1, 2})} frozenset_fr['fr'] = ml_collections.FieldReference( frozenset_fr['frozenset']) list_fr = {'list': [1, 2]} list_fr['fr'] = ml_collections.FieldReference(list_fr['list']) cyclic_fr = {'a': 1} cyclic_fr['fr'] = ml_collections.FieldReference(cyclic_fr) cyclic_fr_parent = {'dict': {}} cyclic_fr_parent['dict']['fr'] = ml_collections.FieldReference( cyclic_fr_parent) # FieldReference is allowed to point to non-cyclic objects: _ = 
ml_collections.FrozenConfigDict(frozenset_fr) _ = ml_collections.FrozenConfigDict(list_fr) # But not cycles: self.assertFrozenRaisesValueError([cyclic_fr, cyclic_fr_parent]) def testDeepCopy(self): """Ensure deepcopy works and does not affect equality.""" fcd = _test_frozenconfigdict() fcd_deepcopy = copy.deepcopy(fcd) self.assertEqual(fcd, fcd_deepcopy) def testEquals(self): """Tests that __eq__() respects hidden mutability.""" fcd = _test_frozenconfigdict() # First, ensure __eq__() returns False when comparing to other types self.assertNotEqual(fcd, (1, 2)) self.assertNotEqual(fcd, fcd.as_configdict()) list_to_tuple = _test_dict_deepcopy() list_to_tuple['list'] = tuple(list_to_tuple['list']) fcd_list_to_tuple = ml_collections.FrozenConfigDict(list_to_tuple) set_to_frozenset = _test_dict_deepcopy() set_to_frozenset['set'] = frozenset(set_to_frozenset['set']) fcd_set_to_frozenset = ml_collections.FrozenConfigDict(set_to_frozenset) self.assertNotEqual(fcd, fcd_list_to_tuple) # Because set == frozenset in Python: self.assertEqual(fcd, fcd_set_to_frozenset) # Items are not affected by hidden mutability self.assertCountEqual(fcd.items(), fcd_list_to_tuple.items()) self.assertCountEqual(fcd.items(), fcd_set_to_frozenset.items()) def testEqualsAsConfigDict(self): """Tests that eq_as_configdict respects hidden mutability but not type.""" fcd = _test_frozenconfigdict() # First, ensure eq_as_configdict() returns True with an equal ConfigDict but # False for other types. self.assertFalse(fcd.eq_as_configdict([1, 2])) self.assertTrue(fcd.eq_as_configdict(fcd.as_configdict())) empty_fcd = ml_collections.FrozenConfigDict() self.assertTrue(empty_fcd.eq_as_configdict(ml_collections.ConfigDict())) # Now, ensure it has the same immutability detection as __eq__(). 
list_to_tuple = _test_dict_deepcopy() list_to_tuple['list'] = tuple(list_to_tuple['list']) fcd_list_to_tuple = ml_collections.FrozenConfigDict(list_to_tuple) set_to_frozenset = _test_dict_deepcopy() set_to_frozenset['set'] = frozenset(set_to_frozenset['set']) fcd_set_to_frozenset = ml_collections.FrozenConfigDict(set_to_frozenset) self.assertFalse(fcd.eq_as_configdict(fcd_list_to_tuple)) # Because set == frozenset in Python: self.assertTrue(fcd.eq_as_configdict(fcd_set_to_frozenset)) def testHash(self): """Ensures __hash__() respects hidden mutability.""" list_to_tuple = _test_dict_deepcopy() list_to_tuple['list'] = tuple(list_to_tuple['list']) self.assertEqual( hash(_test_frozenconfigdict()), hash(ml_collections.FrozenConfigDict(_test_dict_deepcopy()))) self.assertNotEqual( hash(_test_frozenconfigdict()), hash(ml_collections.FrozenConfigDict(list_to_tuple))) # Ensure Python realizes FrozenConfigDict is hashable self.assertIsInstance(_test_frozenconfigdict(), collections_abc.Hashable) def testUnhashableType(self): """Ensures __hash__() fails if FrozenConfigDict has unhashable value.""" unhashable_fcd = ml_collections.FrozenConfigDict( {'unhashable': bytearray()}) with self.assertRaises(TypeError): hash(unhashable_fcd) def testToDict(self): """Ensure to_dict() does not care about hidden mutability.""" list_to_tuple = _test_dict_deepcopy() list_to_tuple['list'] = tuple(list_to_tuple['list']) self.assertEqual(_test_frozenconfigdict().to_dict(), ml_collections.FrozenConfigDict(list_to_tuple).to_dict()) def testPickle(self): """Make sure FrozenConfigDict can be dumped and loaded with pickle.""" fcd = _test_frozenconfigdict() locked_fcd = ml_collections.FrozenConfigDict(_test_configdict().lock()) unpickled_fcd = pickle.loads(pickle.dumps(fcd)) unpickled_locked_fcd = pickle.loads(pickle.dumps(locked_fcd)) self.assertEqual(fcd, unpickled_fcd) self.assertEqual(locked_fcd, unpickled_locked_fcd) if __name__ == '__main__': absltest.main() 
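
The tests above check that freezing turns lists into tuples and sets into frozensets, and that cyclic containers are rejected with a `ValueError`. The following is a minimal, standard-library-only sketch of that idea. The helper name `freeze_value` and its `_ancestors` parameter are hypothetical and are not part of `ml_collections`; the sketch also keeps dicts as plain dicts and does not model the "hidden mutability" bookkeeping that `FrozenConfigDict` uses for `__eq__` and `__hash__`.

```python
def freeze_value(value, _ancestors=frozenset()):
  """Hypothetical helper: returns an immutable copy of a nested container."""
  if isinstance(value, (list, tuple, set, frozenset, dict)):
    # A container that is its own ancestor forms a cycle and cannot be frozen.
    if id(value) in _ancestors:
      raise ValueError('Cyclic container cannot be frozen.')
    _ancestors = _ancestors | {id(value)}
  if isinstance(value, (list, tuple)):
    return tuple(freeze_value(v, _ancestors) for v in value)
  if isinstance(value, (set, frozenset)):
    return frozenset(value)
  if isinstance(value, dict):
    # Assumption for brevity: dicts stay plain dicts; only their values freeze.
    return {k: freeze_value(v, _ancestors) for k, v in value.items()}
  return value


frozen = freeze_value({'list': [1, 2], 'set': {1, 2}, 'nested': (1, [2, 3])})
print(frozen['list'])    # (1, 2)
print(frozen['nested'])  # (1, (2, 3))

cyclic = [1, 2]
cyclic.append(cyclic)
try:
  freeze_value(cyclic)
except ValueError as e:
  print(e)
```

Tracking ancestors rather than all visited nodes means an object that appears twice in sibling positions is still accepted; only a container nested inside itself is treated as a cycle.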
# ml_collections-0.1.1/ml_collections/config_dict/config_dict.py

# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Classes for defining configurations of experiments and models.

This file defines the classes `ConfigDict` and `FrozenConfigDict`, which are
"dict-like" data structures with Lua-like access and some other nice features.
Together, they are supposed to be used as a main way of expressing
configurations of experiments and models.
"""

import abc
import collections
from collections import abc as collections_abc
import contextlib
import difflib
import functools
import inspect
import json
import operator
from typing import Any, Mapping, Optional

from absl import logging
import contextlib2
import six
import yaml
from yaml import representer

# Workaround for https://github.com/yaml/pyyaml/issues/36. Classes that have
# `abc.ABCMeta` as a metaclass are incorrectly handled as objects. This results
# in the unbound `__reduce_ex__` method being called with the protocol version
# as its sole argument, resulting in a `TypeError`. A solution is to add a
# custom representer that represents `abc.ABCMeta` by name.
representer.Representer.add_representer(
    data_type=abc.ABCMeta,
    representer=representer.Representer.represent_name)


class RequiredValueError(Exception):
  pass


class MutabilityError(Exception):
  pass


class JSONDecodeError(Exception):
  pass


_NoneType = type(None)


def _is_callable_type(field_type):
  """Tries to ensure: `_is_callable_type(type(obj)) == callable(obj)`."""
  return any('__call__' in c.__dict__ for c in field_type.__mro__)


def _is_type_safety_violation(value, field_type):
  """Helper function for type safety exceptions.

  This function determines whether or not assigning a value to a field violates
  type safety.

  Args:
    value: The value to be assigned.
    field_type: Type of the field that we would like to assign value to.

  Returns:
    True if assigning value to field violates type safety, False otherwise.
  """
  # Allow None to override and be overridden by any type.
  if value is None or field_type == _NoneType:
    return False
  elif isinstance(value, field_type):
    return False
  else:
    # A callable can override a callable.
    return not (callable(value) and _is_callable_type(field_type))


def _safe_cast(value, field_type, type_safe=False):
  """Helper function to handle the exceptional type conversions.

  This function implements the following exceptions for type-checking rules:

  * An `int` will be converted to a `float` if overriding a `float` field.
  * Any string value can override a `str` or `unicode` field. The value is
    converted to `field_type`.
  * A `tuple` will be converted to a `list` if overriding a `list` field.
  * A `list` will be converted to a `tuple` if overriding a `tuple` field.
  * Short and long integers are indistinguishable. The final value will always
    be a `long` if both types are present.

  Args:
    value: The value to be assigned.
    field_type: The type for the field that we would like to assign value to.
    type_safe: If True, the method will throw an error if the `value` is not of
      type `field_type` after safe type conversions.

  Returns:
    The converted type-safe version of the value if it is one of the cases
    described. Otherwise, return the value without conversion.

  Raises:
    TypeError: if types don't match after safe type conversions.
  """
  original_value_type = type(value)

  # The int->float exception.
  if isinstance(value, int) and field_type is float:
    return float(value)

  # The unicode/string to string exception.
  if isinstance(value, six.string_types) and field_type in (str, six.text_type):
    return field_type(value)

  # tuple<->list conversion. JSON serialization converts lists to tuples, so
  # we need this to avoid errors when overriding a list field with its
  # deserialized version. See b/34805906 for more details.
  if isinstance(value, tuple) and field_type is list:
    return list(value)
  if isinstance(value, list) and field_type is tuple:
    return tuple(value)

  if isinstance(value, six.integer_types) and field_type in six.integer_types:
    if six.PY2 and (isinstance(value, long) or field_type is long):
      # Overriding an int with a long and vice versa should be possible.
      # https://www.python.org/dev/peps/pep-0237/
      return long(value)
    else:
      # In Python 3, there is only the `int` type.
      return value

  if type_safe and _is_type_safety_violation(value, field_type):
    raise TypeError('{} is of type {} but should be of type {}'
                    .format(value, str(original_value_type), str(field_type)))

  return value


def _get_computed_value(value_or_fieldreference):
  if isinstance(value_or_fieldreference, FieldReference):
    return value_or_fieldreference.get()
  return value_or_fieldreference


class _Op(collections.namedtuple('_Op', ['fn', 'args'])):
  """A named tuple representing a lazily computed op.

  The _Op named tuple has two fields:
    fn: The function to be applied.
    args: a tuple/list of arguments that are used with the op.
  """


@functools.total_ordering
class FieldReference:
  """Reference to a configuration element.

  Typed configuration element that can take a None default value.
  Example::

    from ml_collections import config_dict

    cfg_field = config_dict.FieldReference(0)
    cfg = config_dict.ConfigDict({
        'optional': config_dict.FieldReference(None, field_type=str),
        'field': cfg_field,
        'nested': {'field': cfg_field}
    })

    try:
      cfg.optional = 10  # Raises a TypeError because the field is defined
                         # as a `str` field.
    except TypeError as e:
      print(e)

    cfg.field = 1  # Changes the value of both cfg.field and cfg.nested.field.
    print(cfg)

  This class also supports lazy computation.

  Example::

    ref = config_dict.FieldReference(0)

    # Using ref in a standard operation returns another FieldReference. The new
    # reference ref_plus_ten will evaluate ref's value only when we call
    # ref_plus_ten.get()
    ref_plus_ten = ref + 10

    ref.set(3)  # change ref's value
    print(ref_plus_ten.get())  # Prints 13 because ref's value is 3

    ref.set(-2)  # change ref's value again
    print(ref_plus_ten.get())  # Prints 8 because ref's value is -2
  """

  def __init__(self, default, field_type=None, op=None, required=False):
    """Creates an instance of FieldReference.

    Args:
      default: Default value.
      field_type: Type for the values contained by the configuration element. If
          None the type will be inferred from the default value. This value is
          used as the second argument in calls to isinstance, so it has to
          follow that function's requirements (class, type or a tuple containing
          classes, types or tuples).
      op: An optional operation that is applied to the underlying value when
          `get()` is called.
      required: If True, the `get()` method will raise an error if the reference
          does not contain a value. This argument has no effect when a default
          value is provided. Setting this to True will raise an error if `op`
          is not None.

    Raises:
      TypeError: If field_type is not None and is different from the type of the
          default value.
      ValueError: If both the default value and field_type are None.
    """
    if field_type is None:
      if default is None:
        raise ValueError('Default value cannot be None if field_type is None')
      elif isinstance(default, FieldReference):
        field_type = default.get_type()
      else:
        field_type = type(default)
    else:
      # Raise an early error when field_type doesn't have a structure
      # compatible with isinstance (class, type or tuple containing classes,
      # types or tuples). The easiest way to check this is to call isinstance
      # and catch TypeError exceptions.
      try:
        isinstance(None, field_type)
      except TypeError:
        raise TypeError('field_type should be a type, not {}'
                        .format(type(field_type)))
    self._field_type = field_type
    self.set(default)
    if required and op is not None:
      raise ValueError('Cannot set required to True if op is not None')
    self._required = required
    self._ops = [] if op is None else [op]

  def has_cycle(self, visited=None):
    """Finds cycles in the reference graph.

    Args:
      visited: Set containing the ids of all visited nodes in the graph. The
          default value is the empty set.

    Returns:
      True if there is a cycle in the reference graph.
    """
    visited = visited or set()
    if id(self) in visited:
      return True
    visited.add(id(self))

    # Verify the reference to the parent FieldReference doesn't introduce a
    # cycle.
    value = self._value
    if isinstance(value, FieldReference) and value.has_cycle(visited.copy()):
      return True

    # Verify references in the operator arguments don't introduce cycles.
    for op in self._ops:
      for arg in op.args:
        if isinstance(arg, FieldReference) and arg.has_cycle(visited.copy()):
          return True

    return False

  def set(self, value, type_safe=True):
    """Overwrites the value pointed by a FieldReference.

    Args:
      value: New value.
      type_safe: Check that old and new values are of the same type.

    Raises:
      TypeError: If type_safe is true and old and new values are not of the same
          type.
      MutabilityError: If a cycle is found in the reference graph.
    """
    # Disable ops.
    self._ops = []

    if value is None:
      self._value = None
    elif isinstance(value, FieldReference):
      if type_safe and value.get_type() is not self.get_type():
        raise TypeError('Reference is of type {} but should be of type {}'
                        .format(value.get_type(), self.get_type()))
      old_value = getattr(self, '_value', None)
      self._value = value
      if self.has_cycle():
        self._value = old_value
        raise MutabilityError('Found cycle in reference graph.')
    else:
      # TODO(sergomez): Update reference type.
      self._value = _safe_cast(value, self._field_type, type_safe)

  def empty(self):
    """Returns True if the reference points to a None value."""
    return self._value is None

  def get(self):
    """Gets the value of the `FieldReference` object.

    This will dereference `_value` and apply all ops to its value.

    Returns:
      The result of applying all ops to the dereferenced value.

    Raises:
      RequiredValueError: if `required` is True and the underlying value for the
          reference is None.
    """
    if self._required and self._value is None:
      raise RequiredValueError('None value found in required reference')

    value = _get_computed_value(self._value)
    for op in self._ops:
      # Dereference any FieldReference objects
      args = [_get_computed_value(arg) for arg in op.args]
      if value is None or None in args:
        value = None
        logging.debug('Cannot apply `%s` to `None`; skipping.', op)
      else:
        value = op.fn(value, *args)
      value = _get_computed_value(value)
    return value

  def get_type(self):
    return self._field_type

  def __eq__(self, other):
    if isinstance(other, FieldReference):
      return self.get() == other.get()
    else:
      return self.get() == other

  def __le__(self, other):
    if isinstance(other, FieldReference):
      return self.get() <= other.get()
    else:
      return self.get() <= other

  # Make FieldReference unhashable (as it's mutable).
  __hash__ = None

  def _apply_op(self, fn, *args):
    args = [_safe_cast(arg, self._field_type) for arg in args]
    return FieldReference(
        self, field_type=self._field_type, op=_Op(fn, args))

  def _apply_cast_op(self, field_type):
    """Apply a cast op that changes the field_type of this FieldReference.

    `_apply_op` assumes that the `field_type` does not change after the op is
    applied, whereas `_apply_cast_op` generates a FieldReference with a cast
    field_type. Since the op is invoked as `fn(value, *args)`, `fn` must ignore
    `value`, which now holds a dummy default value of `field_type`.

    Args:
      field_type: data type to cast to.

    Returns:
      A new FieldReference of type `field_type`.
    """
    return FieldReference(
        field_type(),  # Creates dummy default value matching `field_type`.
        field_type=field_type,
        op=_Op(lambda _, val: field_type(val),  # `fn` ignores `field_type()`.
               [self]),
    )

  def identity(self):
    return self._apply_op(lambda x: x)

  def attr(self, attr_name):
    return self._apply_op(operator.attrgetter(attr_name))

  def __add__(self, other):
    return self._apply_op(operator.add, other)

  def __radd__(self, other):
    radd = functools.partial(operator.add, other)
    return self._apply_op(radd)

  def __sub__(self, other):
    return self._apply_op(operator.sub, other)

  def __rsub__(self, other):
    rsub = functools.partial(operator.sub, other)
    return self._apply_op(rsub)

  def __mul__(self, other):
    return self._apply_op(operator.mul, other)

  def __rmul__(self, other):
    rmul = functools.partial(operator.mul, other)
    return self._apply_op(rmul)

  def __div__(self, other):
    return self._apply_op(operator.truediv, other)

  def __rdiv__(self, other):
    rdiv = functools.partial(operator.truediv, other)
    return self._apply_op(rdiv)

  def __truediv__(self, other):
    return self._apply_op(operator.truediv, other)

  def __rtruediv__(self, other):
    rtruediv = functools.partial(operator.truediv, other)
    return self._apply_op(rtruediv)

  def __floordiv__(self, other):
    return self._apply_op(operator.floordiv, other)

  def __rfloordiv__(self, other):
    rfloordiv =
functools.partial(operator.floordiv, other) return self._apply_op(rfloordiv) def __pow__(self, other): return self._apply_op(operator.pow, other) def __mod__(self, other): return self._apply_op(operator.mod, other) def __and__(self, other): return self._apply_op(operator.and_, other) def __or__(self, other): return self._apply_op(operator.or_, other) def __xor__(self, other): return self._apply_op(operator.xor, other) def __neg__(self): return self._apply_op(operator.neg) def __abs__(self): return self._apply_op(operator.abs) def to_int(self): return self._apply_cast_op(int) def to_float(self): return self._apply_cast_op(float) def to_str(self): return self._apply_cast_op(str) def __setstate__(self, state): self._value = state['_value'] self._field_type = state['_field_type'] self._ops = state['_ops'] # TODO(sergomez): Remove default for _required (and potentially the whole # __setstate__ method) after June 2019. self._required = state.get('_required', False) def __nonzero__(self): raise NotImplementedError( 'FieldReference cannot be used for control flow. For boolean ' 'operations use "&" (logical "and") or "|" (logical "or").') def __bool__(self): raise NotImplementedError( 'FieldReference cannot be used for control flow. For boolean ' 'operations use "&" (logical "and") or "|" (logical "or").') def _configdict_fill_seed(seed, initial_dictionary, visit_map=None): """Fills an empty ConfigDict without copying previously visited nodes. Turns seed (an empty ConfigDict) into a ConfigDict version of initial_dictionary. Avoids infinite looping on a self-referencing initial_dictionary because if a value of initial_dictionary has been previously visited, that value is not re-converted to a ConfigDict. If a FieldReference is encountered which contains a dict or FrozenConfigDict, its contents will be converted to ConfigDict. 
Note: As described in the __init__() documentation, this will not replicate the structure of initial_dictionary if it contains self-references within lists, tuples, or other types. There is no warning or error in this case. Args: seed: Empty ConfigDict, to be filled in. initial_dictionary: The template on which seed is built. May be of type dict, ConfigDict or FrozenConfigDict. visit_map: Dictionary from memory addresses to values, storing the ConfigDict versions of dictionary values. visit_map need not contain (id(initial_dictionary), seed) as a key/value pair. Raises: TypeError: If seed is not a ConfigDict. ValueError: If seed is not an empty ConfigDict. """ # These should be impossible to raise, since the public call-site in # __init__() pass in valid input, as does this method recursively. assert isinstance(seed, ConfigDict) assert not seed visit_map = visit_map or {} visit_map[id(initial_dictionary)] = seed if isinstance(initial_dictionary, ConfigDict): iteritems = initial_dictionary.iteritems(preserve_field_references=True) else: iteritems = six.iteritems(initial_dictionary) for key, value in iteritems: if id(value) in visit_map: value = visit_map[id(value)] elif (isinstance(value, FieldReference) and value.get_type() is dict and seed.convert_dict): # If the reference is empty, we don't have to do dict -> ConfigDict # conversion. # Calling get() on an empty required reference would raise an error so we # need a special case for this. 
      if value.empty():
        pass
      elif id(value.get()) in visit_map:
        value.set(visit_map[id(value.get())], False)
      else:
        value_cd = ConfigDict(type_safe=seed.is_type_safe)
        _configdict_fill_seed(value_cd, value.get(), visit_map)
        value.set(value_cd, False)

    elif isinstance(value, dict) and seed.convert_dict:
      value_cd = ConfigDict(type_safe=seed.is_type_safe)
      _configdict_fill_seed(value_cd, value, visit_map)
      value = value_cd

    elif isinstance(value, FrozenConfigDict):
      value = ConfigDict(value)

    seed.__setattr__(key, value)


class ConfigDict:
  # pylint: disable=line-too-long
  """Base class for configuration objects used in DeepMind.

  This is a container for configurations. It behaves similarly to Lua tables.
  Specifically:

  - it has dot-based access as well as dict-style key access,
  - it is type safe (once a value is set one cannot change its type).

  Typical usage example::

    import ml_collections

    cfg = ml_collections.ConfigDict()
    cfg.float_field = 12.6
    cfg.integer_field = 123
    cfg.another_integer_field = 234
    cfg.nested = ml_collections.ConfigDict()
    cfg.nested.string_field = 'tom'

    print(cfg)

  Config dictionaries can also be used to pass named arguments to functions::

    from ml_collections import config_dict

    def print_point(x, y):
      print("({},{})".format(x, y))

    point = config_dict.ConfigDict()
    point.x = 1
    point.y = 2
    print_point(**point)

  Note that, depending on your use case, it may be easier to use the `create`
  function in this package to construct a `ConfigDict`::

    from ml_collections.config_dict import config_dict
    point = config_dict.create(x=1, y=2)

  Unlike standard `dicts`, `ConfigDicts` also have the nice property that
  iterating over them is deterministic, in a fashion similar to
  `collections.OrderedDicts`.
  """
  # pylint: enable=line-too-long

  # Loosen the static type checking requirements.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(
      self,
      initial_dictionary: Optional[Mapping[str, Any]] = None,
      type_safe: bool = True,
      convert_dict: bool = True,
  ):
    """Creates an instance of ConfigDict.

    Warning: In most cases, this faithfully reproduces the reference structure
    of initial_dictionary, even if initial_dictionary is self-referencing.
    However, unexpected behavior occurs if self-references are contained within
    list, tuple, or custom types. For example::

      d = {}
      d['a'] = d
      d['b'] = [d]
      cd = ConfigDict(d)
      cd.a    # refers to cd, type ConfigDict. Expected behavior.
      cd.b    # refers to d, type dict. Unexpected behavior.

    Warning: FieldReference values may be changed. If initial_dictionary
    contains a FieldReference with a value of type dict or FrozenConfigDict,
    that value is converted to ConfigDict.

    Args:
      initial_dictionary: May be one of the following:

        1) dict. In this case, all values of initial_dictionary that are
        dictionaries are also converted to ConfigDict. However,
        dictionaries within values of non-dict type are untouched.

        2) ConfigDict. In this case, all attributes are uncopied, and only the
        top-level object (self) is re-addressed. This is the same behavior
        as Python dict, list, and tuple.

        3) FrozenConfigDict. In this case, initial_dictionary is converted to
        a ConfigDict version of the initial dictionary for the
        FrozenConfigDict (reversing any mutability changes FrozenConfigDict
        made).

      type_safe: If set to True, once an attribute value is assigned, its type
        cannot be overridden without .ignore_type() context manager
        (default: True).
      convert_dict: If set to True, all dict used as value in the ConfigDict
        will automatically be converted to ConfigDict (default: True).
    """
    if isinstance(initial_dictionary, FrozenConfigDict):
      initial_dictionary = initial_dictionary.as_configdict()

    super(ConfigDict, self).__setattr__('_fields', {})
    super(ConfigDict, self).__setattr__('_locked', False)
    super(ConfigDict, self).__setattr__('_type_safe', type_safe)
    super(ConfigDict, self).__setattr__('_convert_dict', convert_dict)

    if initial_dictionary is not None:
      _configdict_fill_seed(self, initial_dictionary)

    if isinstance(initial_dictionary, ConfigDict):
      super(ConfigDict, self).__setattr__('_locked',
                                          initial_dictionary.is_locked)
      super(ConfigDict, self).__setattr__('_type_safe',
                                          initial_dictionary.is_type_safe)

  @property
  def is_type_safe(self) -> bool:
    """Returns True if config dict is type safe."""
    return self._type_safe

  @property
  def convert_dict(self):
    """Returns True if it is converting dicts to ConfigDict automatically."""
    return self._convert_dict

  def lock(self) -> 'ConfigDict':
    """Locks object, preventing user from adding new fields.

    Returns:
      self
    """
    if self.is_locked:
      return self

    super(ConfigDict, self).__setattr__('_locked', True)
    for field in self._fields:
      element = self._fields[field]
      element = _get_computed_value(element)
      if isinstance(element, ConfigDict):
        element.lock()
    return self

  @property
  def is_locked(self) -> bool:
    """Returns True if object is locked."""
    return self._locked

  def unlock(self) -> 'ConfigDict':
    """Grants user the ability to add new fields to ConfigDict.

    In most cases, the unlocked() context manager should be preferred to the
    direct use of the unlock method.

    Returns:
      self
    """
    super(ConfigDict, self).__setattr__('_locked', False)
    for element in six.itervalues(self._fields):
      element = _get_computed_value(element)
      if isinstance(element, ConfigDict) and element.is_locked:
        element.unlock()
    return self

  def get(self, key: str, default=None):
    """Returns value if key is present, or a user defined value otherwise."""
    try:
      return self[key]
    except KeyError:
      return default

  # TODO(sergomez): replace this with get_oneway_ref.
  # The first step is to log usage patterns of this. How many users are
  # overriding the value of the reference returned by this and expect the
  # referenced field to change too?
  def get_ref(self, key):
    """Returns a FieldReference initialized on key's value."""
    field = self._fields[key]
    if field is None:
      raise ValueError('Cannot create reference to a field whose value is None')
    if not isinstance(field, FieldReference):
      field = FieldReference(field)
      with self.ignore_type():
        self[key] = field
    return field

  def get_oneway_ref(self, key):
    """Returns a one-way FieldReference.

    Example::

      cfg = ml_collections.ConfigDict(dict(a=1))
      cfg.b = cfg.get_oneway_ref('a')

      cfg.a = 2
      print(cfg.b)  # 2

      cfg.b = 3
      print(cfg.a)  # 2 (would have been 3 if using get_ref())
      print(cfg.b)  # 3

    Args:
      key: Key for field we want to reference.
    """
    # Using the result of applying an operation on the reference means that
    # calling set() on this object won't propagate the new value up the
    # reference chain.
    return self.get_ref(key).identity()

  def items(self, preserve_field_references=False):
    """Returns list of dictionary key, value pairs, sorted by key.

    Args:
      preserve_field_references: (bool) Whether to preserve FieldReferences if
        the ConfigDict has them. By default, False: any FieldReferences will be
        resolved in the result.

    Returns:
      The key, value pairs in the config, sorted by key.
    """
    if preserve_field_references:
      return six.iteritems(self._ordered_fields)
    else:
      return [(k, self[k]) for k in self._ordered_fields]

  @property
  def _ordered_fields(self):
    """Returns ordered dict shallow cast of _fields member."""
    return collections.OrderedDict(sorted(self._fields.items()))

  def iteritems(self, preserve_field_references=False):
    """Deterministically iterates over dictionary key, value pairs.

    Args:
      preserve_field_references: (bool) Whether to preserve FieldReferences if
        the ConfigDict has them. By default, False: any FieldReferences will be
        resolved in the result.

    Yields:
      The key, value pairs in the config, sorted by key.
    """
    for k in self._ordered_fields:
      if preserve_field_references:
        yield k, self._fields[k]
      else:
        yield k, self[k]

  def _ensure_mutability(self, attribute):
    if attribute in dir(super(ConfigDict, self)):
      raise KeyError('{} cannot be overridden.'.format(attribute))

  def __setattr__(self, attribute, value):
    try:
      self._ensure_mutability(attribute)
      self[attribute] = value
    except KeyError as e:
      raise AttributeError(e)

  def __delattr__(self, attribute):
    try:
      self._ensure_mutability(attribute)
      del self[attribute]
    except KeyError as e:
      raise AttributeError(e)

  def __getattr__(self, attribute):
    try:
      return self[attribute]
    except KeyError as e:
      raise AttributeError(e)

  def __setitem__(self, key, value):
    if '.' in key:
      raise ValueError('ConfigDict does not accept dots in field names, but '
                       'the key {} contains one.'.format(key))

    if self.is_locked and key not in self._fields:
      error_msg = ('Key "{}" does not exist and cannot be added since the '
                   'config is locked')
      raise KeyError(
          self._generate_did_you_mean_message(key, error_msg.format(key)))

    if key in self._fields:
      field = self._fields[key]
      try:
        if isinstance(field, FieldReference):
          field.set(value, self._type_safe)
          return
        # Skip type checking if the value is a FieldReference of the same type.
        if (not isinstance(value, FieldReference) or
            value.get_type() is not type(field)):
          if isinstance(value, dict) and self._convert_dict:
            value = ConfigDict(value, self._type_safe)
          value = _safe_cast(value, type(field), self._type_safe)
      except TypeError as e:
        raise TypeError('Could not override field \'{}\' (reference). {}'
                        .format(key, str(e)))

    if isinstance(value, dict) and self._convert_dict:
      value = ConfigDict(value, self._type_safe)
    elif isinstance(value, FieldReference):
      # TODO(sergomez): We should consider using value.get_type().
      ref_type = _NoneType if value.empty() else type(value.get())
      if ref_type is dict or ref_type is FrozenConfigDict:
        value_cd = ConfigDict(value.get(), self._type_safe)
        value.set(value_cd, False)

    self._fields[key] = value

  def _generate_did_you_mean_message(self, request, message=''):
    matches = difflib.get_close_matches(request, self.keys(), 1, 0.75)
    if matches:
      if message:
        message += '\n'
      message += 'Did you mean "{}" instead of "{}"?'.format(matches[0],
                                                             request)
    return message

  def __delitem__(self, key: str):
    if self.is_locked:
      raise KeyError('This ConfigDict is locked, you have to unlock it before '
                     'trying to delete a field.')

    if '.' in key:
      # As per the check in __setitem__ above, keys cannot contain dots.
      # Hence, we can use dots to do recursive calls.
      key, rest = key.split('.', 1)
      del self[key][rest]
      return

    try:
      del self._fields[key]
    except KeyError as e:
      raise KeyError(self._generate_did_you_mean_message(key, str(e)))

  def __getitem__(self, key: str):
    if '.' in key:
      # As per the check in __setitem__ above, keys cannot contain dots.
      # Hence, we can use dots to do recursive calls.
      key, rest = key.split('.', 1)
      return self[key][rest]

    try:
      field = self._fields[key]
      if isinstance(field, FieldReference):
        return field.get()
      else:
        return field
    except KeyError as e:
      raise KeyError(self._generate_did_you_mean_message(key, str(e)))

  def __contains__(self, key: str):
    return key in self._fields

  def __repr__(self) -> str:
    # We want __repr__ to always run without throwing an exception,
    # even if the config dict is not YAML serialisable.
    try:
      return yaml.dump(self.to_dict(preserve_field_references=True),
                       default_flow_style=False)
    except Exception:  # pylint: disable=broad-except
      return repr(self.to_dict())

  def __str__(self) -> str:
    # We want __str__ to always run without throwing an exception,
    # even if the config dict is not YAML serialisable.
    try:
      return yaml.dump(self.to_dict())
    except Exception:  # pylint: disable=broad-except
      return str(self.to_dict())

  def keys(self):
    """Returns the sorted list of all the keys defined in a config."""
    return list(self._ordered_fields.keys())

  def iterkeys(self):
    """Deterministically iterates over dictionary keys, in sorted order."""
    return six.iterkeys(self._ordered_fields)

  def values(self, preserve_field_references=False):
    """Returns the list of all values in a config, sorted by their keys.

    Args:
      preserve_field_references: (bool) Whether to preserve FieldReferences if
        the ConfigDict has them. By default, False: any FieldReferences will be
        resolved in the result.

    Returns:
      The values in the config, sorted by their corresponding keys.
    """
    if preserve_field_references:
      return list(self._ordered_fields.values())
    else:
      return [self[k] for k in self._ordered_fields]

  def itervalues(self, preserve_field_references=False):
    """Deterministically iterates over values in a config, sorted by their keys.

    Args:
      preserve_field_references: (bool) Whether to preserve FieldReferences if
        the ConfigDict has them. By default, False: any FieldReferences will be
        resolved in the result.

    Yields:
      The values in the config, sorted by their corresponding keys.
    """
    for k in self._ordered_fields:
      if preserve_field_references:
        yield self._fields[k]
      else:
        yield self[k]

  def __dir__(self):
    return self.keys() + dir(ConfigDict)

  def __len__(self):
    return self._ordered_fields.__len__()

  def __iter__(self):
    return self._ordered_fields.__iter__()

  def __eq__(self, other):
    """Override the default Equals behavior."""
    if isinstance(other, self.__class__):
      same = self._fields == other._fields
      same &= self._locked == other.is_locked
      same &= self._type_safe == other.is_type_safe
      return same
    return False

  def __ne__(self, other):
    """Define a non-equality test."""
    return not self.__eq__(other)

  def eq_as_configdict(self, other):
    """Type-invariant equals.

    This is like `__eq__`, except it does not distinguish `FrozenConfigDict`
    from `ConfigDict`. For example::

      cd = ConfigDict()
      fcd = FrozenConfigDict()
      fcd.eq_as_configdict(cd)  # Returns True

    Args:
      other: Object to compare self to.

    Returns:
      same: `True` if `self == other` after conversion to `ConfigDict`.
    """
    if isinstance(other, ConfigDict):
      return ConfigDict(self) == ConfigDict(other)
    else:
      return False

  # Make ConfigDict unhashable
  __hash__ = None

  def to_yaml(self, **kwargs):
    """Returns a YAML representation of the object.

    ConfigDict serializes types of fields as well as the values of fields
    themselves. Deserializing the YAML representation hence requires using
    YAML's UnsafeLoader:

    ```
    yaml.load(cfg.to_yaml(), Loader=yaml.UnsafeLoader)
    ```

    or equivalently:

    ```
    yaml.unsafe_load(cfg.to_yaml())
    ```

    Please see the PyYAML documentation and https://msg.pyyaml.org/load
    for more details on the consequences of this.

    Args:
      **kwargs: Keyword arguments for yaml.dump.

    Returns:
      YAML representation of the object.
    """
    return yaml.dump(self, **kwargs)

  def _json_dumps_wrapper(self, **kwargs):
    """Wrapper for json.dumps() method.

    Produces a more informative error message when there is a problem with
    string encodings in the ConfigDict.

    Args:
      **kwargs: Keyword arguments for json.dumps.

    Returns:
      JSON representation of the object.

    Raises:
      JSONDecodeError: If there is a problem with string encodings.
    """
    try:
      return json.dumps(self._fields, **kwargs)
    except UnicodeDecodeError as error:
      # Re-raise exception with more informative error message.
      new_message = (
          'Decoding error. Make sure all strings in your ConfigDict use ASCII-'
          'compatible encodings. See '
          'https://docs.python.org/2.7/howto/unicode.html#the-unicode-type '
          'for details. Original error message: {}'.format(error))
      raise JSONDecodeError(new_message)

  def to_json(self, json_encoder_cls=None, **kwargs):
    """Returns a JSON representation of the object, fails if there is a cycle.

    Args:
      json_encoder_cls: An optional JSON encoder class to customize JSON
        serialization.
      **kwargs: Keyword arguments for json.dumps. They cannot contain "cls"
        as this method specifies it on its own.

    Returns:
      JSON representation of the object.

    Raises:
      TypeError: If self contains set, frozenset, custom type fields or any
        other objects that are not JSON serializable.
    """
    json_encoder_cls = json_encoder_cls or CustomJSONEncoder
    return self._json_dumps_wrapper(cls=json_encoder_cls, **kwargs)

  def to_json_best_effort(self, **kwargs):
    """Returns a best effort JSON representation of the object.

    Tries to serialize objects not inherently supported by JSON encoder. This
    may result in the configdict being partially serialized, skipping the
    unserializable bits. Ensures that no errors are thrown. Fails if there is a
    cycle.

    Args:
      **kwargs: Keyword arguments for json.dumps. They cannot contain "cls"
        as this method specifies it on its own.

    Returns:
      JSON representation of the object.
    """
    return self._json_dumps_wrapper(cls=_BestEffortCustomJSONEncoder, **kwargs)

  def to_dict(self, visit_map=None, preserve_field_references=False):
    """Converts ConfigDict to regular dict recursively with valid references.

    By default, the output dict will not contain FieldReferences, any present in
    the ConfigDict will be resolved.
    However, if `preserve_field_references` is True, the output dict will
    contain FieldReferences where the original ConfigDict has them. They will
    not be the same as the ConfigDict's, and their ops will be applied and
    dropped.

    Note: As with __eq__() and __init__(), this may not behave as expected on a
    ConfigDict with self-references contained in lists, tuples, or custom types.

    Args:
      visit_map: A mapping from object ids to their dict representation. Method
        is recursive in nature, and it will call ".to_dict(visit_map)" on each
        encountered object, unless it is already in visit_map.
      preserve_field_references: (bool) Whether the output dict should have
        FieldReferences if the ConfigDict has them. By default, False: any
        FieldReferences will be resolved and the result will go to the dict.

    Returns:
      Dictionary with the same values and references structure as a calling
        ConfigDict.
    """
    visit_map = visit_map or {}
    dict_self = {}
    visit_map[id(self)] = dict_self

    for key in self:
      if (isinstance(self._fields[key], FieldReference) and
          preserve_field_references):
        reference = self._fields[key]
        value = reference.get()
      else:
        value = self[key]
        reference = value

      if id(reference) in visit_map:
        dict_self[key] = visit_map[id(reference)]
      elif isinstance(value, ConfigDict):
        if isinstance(reference, FieldReference):
          # Create a new reference of type dict instead of ConfigDict.
          old_reference = reference
          reference = FieldReference({}, dict)
          visit_map[id(old_reference)] = reference
          reference.set(value.to_dict(visit_map, preserve_field_references))
        else:
          reference = value.to_dict(visit_map, preserve_field_references)
        dict_self[key] = reference
      else:
        if isinstance(reference, FieldReference):
          # Create a new reference to put in the new dict, which will be
          # reused whenever we find the same original reference.
          # Notice that ops are lost in the copy, but they are applied when
          # we do old_reference.get().
          old_reference = reference
          # Disable type safety since value in the field reference might have
          # been previously set with type safety disabled (e.g. ignore_type
          # context, as in b/119393923).
          reference = FieldReference(None, old_reference.get_type())
          reference.set(old_reference.get(), type_safe=False)
          visit_map[id(old_reference)] = reference
        dict_self[key] = reference

    return dict_self

  def copy_and_resolve_references(self, visit_map=None):
    """Returns a ConfigDict copy with FieldReferences replaced by values.

    If the object is a FrozenConfigDict, the copy returned is also a
    FrozenConfigDict.
    However, note that FrozenConfigDict should already have FieldReferences
    resolved to values, so this method effectively produces a deep copy.

    Note: As with __eq__() and __init__(), this may not behave as expected on a
    ConfigDict with self-references contained in lists, tuples, or custom types.

    Args:
      visit_map: A mapping from ConfigDict object ids to their copy. Method
        is recursive in nature, and it will call
        ".copy_and_resolve_references(visit_map)" on each encountered object,
        unless it is already in visit_map.

    Returns:
      ConfigDict copy with previous FieldReferences replaced by values.
    """
    visit_map = visit_map or {}
    config_dict_copy = self.__class__()
    super(ConfigDict, config_dict_copy).__setattr__('_convert_dict',
                                                    self.convert_dict)
    visit_map[id(self)] = config_dict_copy

    for key, value in six.iteritems(self._fields):
      if isinstance(value, FieldReference):
        value = value.get()

      if id(value) in visit_map:
        value = visit_map[id(value)]
      elif isinstance(value, ConfigDict):
        value = value.copy_and_resolve_references(visit_map)

      if isinstance(self, FrozenConfigDict):
        config_dict_copy._frozen_setattr(  # pylint:disable=protected-access
            key, value)
      else:
        config_dict_copy[key] = value

    super(ConfigDict, config_dict_copy).__setattr__(
        '_locked', self.is_locked)
    super(ConfigDict, config_dict_copy).__setattr__(
        '_type_safe', self.is_type_safe)
    return config_dict_copy

  def __setstate__(self, state):
    """Recreates ConfigDict from its dict representation."""
    self.__init__()
    super(ConfigDict, self).__setattr__('_type_safe', state['_type_safe'])
    super(ConfigDict, self).__setattr__('_convert_dict',
                                        state.get('_convert_dict', True))
    for field in state['_fields']:
      self[field] = state['_fields'][field]

    if state['_locked']:
      self.lock()

  @contextlib.contextmanager
  def unlocked(self):
    """Context manager which temporarily unlocks a ConfigDict."""
    was_locked = self._locked
    if was_locked:
      self.unlock()
    try:
      yield self
    finally:
      if was_locked:
        self.lock()

  @contextlib.contextmanager
  def ignore_type(self):
    """Context manager which temporarily turns off type safety recursively."""
    original_type_safety = self._type_safe
    managers = []
    visited = set()
    fields = list(six.iteritems(self._fields))
    while fields:
      field = fields.pop()
      if id(field) in visited:
        continue
      visited.add(id(field))
      if isinstance(field, ConfigDict):
        managers.append(field.ignore_type)
      # Recursively add elements in nested collections.
      elif isinstance(field, collections_abc.Mapping):
        fields.extend(six.iteritems(field))
      elif isinstance(field, (collections_abc.Sequence, collections_abc.Set)):
        fields.extend(field)

    super(ConfigDict, self).__setattr__('_type_safe', False)
    try:
      with contextlib2.ExitStack() as stack:
        for manager in managers:
          stack.enter_context(manager())
        yield self
    finally:
      super(ConfigDict, self).__setattr__('_type_safe', original_type_safety)

  def get_type(self, key):
    """Returns type of the field associated with a key."""
    # We access the field outside of the if/else statement to raise in all cases
    # AttributeErrors (potentially including "did you mean" messages) for
    # non-existent keys.
    field = self.__getattr__(key)
    if isinstance(self._fields[key], FieldReference):
      return self._fields[key].get_type()
    else:
      return type(field)

  def update(self, *other, **kwargs):
    """Update values based on matching keys in another dict-like object.

    Mimics the built-in dict's update method: iterates over a given
    mapping object and adds/overwrites keys with the given mapping's
    values for those keys.

    Differs from dict.update in that it operates recursively on existing keys
    that are already a ConfigDict (i.e. calls their update() on the
    corresponding value from other), and respects the ConfigDict's type safety
    status.

    If keyword arguments are specified, the ConfigDict is updated with those
    key/value pairs.

    Args:
      *other: A (single) dict-like container, e.g. a dict or ConfigDict.
      **kwargs: Additional keyword arguments to update the ConfigDict.

    Raises:
      TypeError: if more than one value for `other` is specified.
    """
    if len(other) > 1:
      raise TypeError(
          'update expected at most 1 arguments, got {}'.format(len(other)))
    for other in other + (kwargs,):
      iteritems_kwargs = {}
      if isinstance(other, ConfigDict):
        iteritems_kwargs['preserve_field_references'] = True

      for key, value in six.iteritems(other, **iteritems_kwargs):  # pytype: disable=wrong-keyword-args
        if key not in self:
          self[key] = value
        elif isinstance(self._fields[key], ConfigDict):
          self[key].update(other[key])
        elif (isinstance(self._fields[key], FieldReference) and
              isinstance(value, FieldReference)):
          # Updating FieldReferences from FieldReferences is not allowed.
          # One option could be to just grab the value from `other` and try to
          # update the reference in `self` using that. But that could result in
          # losing links between fields in `other`.
          # Example:
          #   other = ConfigDict(dict(a=1))
          #   other.b = other.get_ref('a')
          #   this = ConfigDict(dict(a=2))
          #   this.c = this.get_ref('a')
          #
          #   # Say we update `this` with `other`. The links between fields
          #   # in `other` could be lost in `this`.
          #   this.update(other)
          #
          #   # It is unclear what `this.b` should be when `this.a` is updated.
          #   this.a = 10
          #   # this.b?
          raise TypeError('Cannot update a FieldReference from another '
                          'FieldReference')
        else:
          self[key] = value

  def update_from_flattened_dict(self, flattened_dict, strip_prefix=''):
    """In-place updates values taken from a flattened dict.

    This allows a potentially nested source `ConfigDict` of the following form::

      cfg = ConfigDict({
          'a': 1,
          'b': {
              'c': {
                  'd': 2
              }
          }
      })

    to be updated from a dict containing paths navigating to child items, of the
    following form::

      updates = {
          'a': 2,
          'b.c.d': 3
      }

    This filters `flattened_dict` to only contain paths starting with
    `strip_prefix` and strips the prefix when applying the update.

    For example, suppose we have the following values returned as flags::

      flags = {
          'flag1': x,
          'flag2': y,
          'config': 'some_file.py',
          'config.a.b': 1,
          'config.a.c': 2
      }

      config = ConfigDict({
          'a': {
              'b': 0,
              'c': 0
          }
      })

      config.update_from_flattened_dict(flags, 'config.')

    Then we will now have::

      config = ConfigDict({
          'a': {
              'b': 1,
              'c': 2
          }
      })

    Args:
      flattened_dict: A mapping (key path) -> value.
      strip_prefix: A prefix to be stripped from each key path. If specified,
        only paths starting with `strip_prefix` will be processed.

    Raises:
      KeyError: if any of the key paths can't be found.
    """
    if strip_prefix:
      interesting_items = {
          key: value
          for key, value in six.iteritems(flattened_dict)
          if key.startswith(strip_prefix)
      }
    else:
      interesting_items = flattened_dict

    # Keep track of any children that we want to update. Make sure that we
    # recurse into each one only once.
    children_to_update = set()

    for full_key, value in six.iteritems(interesting_items):
      key = full_key[len(strip_prefix):] if strip_prefix else full_key

      if '.' in key:
        # If the path is hierarchical, we'll need to tell the first component
        # to update itself.
        child = key.split('.')[0]
        if child in self:
          if isinstance(self[child], ConfigDict):
            children_to_update.add(child)
          else:
            raise KeyError('Key "{}" cannot be updated as "{}" is not a '
                           'ConfigDict.'.format(full_key, strip_prefix + child))
        else:
          raise KeyError('Key "{}" cannot be set as "{}" was not found.'
                         .format(full_key, strip_prefix + child))
      else:
        self[key] = value

    for child in children_to_update:
      child_strip_prefix = strip_prefix + child + '.'
      self[child].update_from_flattened_dict(interesting_items,
                                             child_strip_prefix)


def _frozenconfigdict_valid_input(obj, ancestor_list=None):
  """Raises error if obj is NOT a valid input for FrozenConfigDict.

  Args:
    obj: Object to check. In the first call (with ancestor_list = None), obj
      should be of type ConfigDict. During recursion, it may be any type except
      dict.
    ancestor_list: List of ancestors of obj in the attribute/element structure,
      used to detect reference cycles in obj.

  Raises:
    ValueError: If obj is an invalid input for FrozenConfigDict, i.e. if it
      contains a dict within a list/tuple or contains a reference cycle. Also
      if obj is a dict, which means it wasn't already converted to ConfigDict.
  """
  ancestor_list = ancestor_list or []

  # Dicts must be converted to ConfigDict before _frozenconfigdict_valid_input()
  assert not isinstance(obj, dict)

  if id(obj) in ancestor_list:
    raise ValueError('Bad FrozenConfigDict initialization: Cannot contain a '
                     'cycle in its reference structure.')
  ancestor_list.append(id(obj))

  if isinstance(obj, ConfigDict):
    for value in obj.values():
      _frozenconfigdict_valid_input(value, ancestor_list)
  elif isinstance(obj, FieldReference):
    _frozenconfigdict_valid_input(obj.get(), ancestor_list)
  elif isinstance(obj, (list, tuple)):
    for element in obj:
      if isinstance(element, (dict, ConfigDict, FieldReference)):
        raise ValueError('Bad FrozenConfigDict initialization: Cannot '
                         'contain a dict, ConfigDict, or FieldReference '
                         'within a list or tuple.')
      _frozenconfigdict_valid_input(element, ancestor_list)

  ancestor_list.pop()


def _tuple_to_immutable(value, visit_map):
  """Convert tuple to fully immutable tuple.

  Args:
    value: Tuple to be made fully immutable (including its elements).
    visit_map: As used elsewhere. See _frozenconfigdict_fill_seed()
      documentation. Must not contain id(value) as a key (if it does, an
      immutable version of value already exists).

  Returns:
    immutable_value: Immutable version of value, created with minimal copying
      (for example, if a value contains no mutable elements, it is returned
      untouched).
    same_value: Whether the same value was returned untouched, i.e. with the
      same memory address. Boolean.
    visit_map: Updated visit_map

  Raises:
    TypeError: If one of the following:
      1) value is not a tuple.
      2) value contains a dict, ConfigDict, or FieldReference.
      If it does, value is an invalid attribute of FrozenConfigDict, and this
      should have been caught in valid_input at initialization.
    ValueError: id(value) is in visit_map.
  """
  # Ensure there are no cycles
  assert id(value) not in visit_map

  value_copy = []
  same_value = True
  for element in value:
    # Sanity check: element cannot be dict, ConfigDict, or FieldReference
    assert not isinstance(element, (dict, ConfigDict, FieldReference))
    if isinstance(element, (list, tuple, set)):
      new_element, uncopied_element, visit_map = _convert_to_immutable(
          element, visit_map)
      same_value &= uncopied_element
      value_copy.append(new_element)
    else:
      value_copy.append(element)

  if same_value:
    return value, True, visit_map
  else:
    return tuple(value_copy), False, visit_map


def _convert_to_immutable(value, visit_map):
  """Convert Python built-in type to immutable, copying if necessary.

  Args:
    value: To be made immutable type (including its elements). Must have
      type list, tuple, or set.
    visit_map: As used elsewhere. See _frozenconfigdict_fill_seed()
      documentation.

  Returns:
    immutable_value: Immutable version of value, created with minimal copying.
    same_value: Whether the same value was returned untouched, i.e. with the
      same memory address. Boolean.
    visit_map: Updated visit_map.

  Raises:
    TypeError: If value is an invalid type (not a list, tuple, or set).
  """
  value_id = id(value)
  if value_id in visit_map:
    return visit_map[value_id], True, visit_map

  same_value = False
  if isinstance(value, set):
    immutable_value = frozenset(value)
  elif isinstance(value, tuple):
    immutable_value, same_value, visit_map = _tuple_to_immutable(
        value, visit_map)
  elif isinstance(value, list):
    immutable_value, _, visit_map = _tuple_to_immutable(tuple(value), visit_map)
  else:
    # Type-check the input
    assert False

  visit_map[value_id] = immutable_value
  return immutable_value, same_value, visit_map


def _frozenconfigdict_fill_seed(seed, initial_configdict, visit_map=None):
  """Fills an empty FrozenConfigDict without copying previously visited nodes.

  Turns seed (an empty FrozenConfigDict) into a FrozenConfigDict version of
  initial_configdict. Avoids duplicating nodes of initial_configdict because if
  a value of initial_configdict has been previously visited, that value is not
  re-converted to a FrozenConfigDict. If a FieldReference is encountered which
  contains a dict, its contents will be converted to FrozenConfigDict.

  Note: As described in the __init__() documentation, this will not
  replicate the structure of initial_configdict if it contains
  self-references within lists, tuples, or other types. There is no
  warning or error in this case.

  Args:
    seed: Empty FrozenConfigDict, to be filled in.
    initial_configdict: The template on which seed is built. Must be of type
      ConfigDict.
    visit_map: Dictionary from memory addresses to values, storing the
      FrozenConfigDict versions of dictionary values. Lists which have been
      converted to tuples and sets to frozensets are also stored in visit_map
      to preserve the reference structure of initial_configdict.
      visit_map need not contain (id(initial_configdict), seed) as a key/value
      pair.

  Raises:
    ValueError: If one of the following, both of which can never happen in
      practice:
      1) seed is not an empty FrozenConfigDict.
      2) initial_configdict contains a FieldReference.
  """
  # These should be impossible to raise, since the public call-site in
  # __init__() passes in valid input, as does this method recursively.
  assert isinstance(seed, FrozenConfigDict)
  assert not seed

  # This is where we define self._configdict for the FrozenConfigDict() class.
  # It is defined here instead of in FrozenConfigDict.__init__() because we must
  # fill in an empty FrozenConfigDict but do not want to have an unexpected
  # signature for FrozenConfigDict.__init__() by passing it initial_configdict.
  object.__setattr__(seed, '_configdict', initial_configdict)

  visit_map = visit_map or {}
  visit_map[id(initial_configdict)] = seed

  for key, value in six.iteritems(initial_configdict):
    # This should never be raised due to elimination of references by
    # ConfigDict's iteritems
    if isinstance(value, FieldReference):
      raise ValueError('Trying to initialize a FrozenConfigDict value with '
                       'a FieldReference. This should never happen, please '
                       'file a bug.')

    if id(value) in visit_map:
      value = visit_map[id(value)]
    elif (isinstance(value, ConfigDict) and
          not isinstance(value, FrozenConfigDict)):
      value_frozenconfigdict = FrozenConfigDict(type_safe=value.is_type_safe)
      _frozenconfigdict_fill_seed(value_frozenconfigdict, value, visit_map)
      value = value_frozenconfigdict

    seed._frozen_setattr(key, value,  # pylint:disable=protected-access
                         visit_map)


class FrozenConfigDict(ConfigDict):
  """Immutable and hashable type of ConfigDict.

  See ConfigDict() documentation above for details and usage.

  FrozenConfigDict is fully immutable. It contains no lists or sets (at
  initialization, lists and sets are converted to tuples and frozensets). The
  only potential sources of mutability are attributes with custom types, which
  are not touched.

  It is recommended to convert a ConfigDict to FrozenConfigDict after
  construction if possible.
  """

  def __init__(self, initial_dictionary=None, type_safe=True):
    """Creates an instance of FrozenConfigDict.

    Lists and sets are copied into tuples and frozensets.
However, copying is kept to a minimum so tuples, frozensets, and other immutable types are not copied unless they contain mutable types. Prohibited initial_dictionary structures: initial_dictionary may not contain any lists or tuples with dictionary, ConfigDict, or FieldReference elements, or else an error is raised at initialization. It also may not contain loops in the reference structure, i.e. the reference structure must be a Directed Acyclic Graph. This includes loops in list-element and tuple-element references. initial_dictionary's reference structure need not be a tree. Warning: Unexpected behavior may occur with types other than Python's built-in types. See ConfigDict() documentation for details. Warning: As with ConfigDict, FieldReference values may be changed. If initial_dictionary contains a FieldReference with a value of type dict or ConfigDict, that value will be converted to FrozenConfigDict. Args: initial_dictionary: May be one of the following: 1) dict. In this case all values of initial_dictionary that are dictionaries are also converted to FrozenConfigDict. If there are dictionaries contained in lists or tuples, an error is raised. 2) ConfigDict. In this case all ConfigDict attributes are also converted to FrozenConfigDict. 3) FrozenConfigDict. In this case all attributes are uncopied, and only the top-level object (self) is re-addressed. type_safe: See ConfigDict documentation. Note that this only matters if the FrozenConfigDict is converted to ConfigDict at some point. 
""" super(FrozenConfigDict, self).__init__() initial_configdict = ConfigDict(initial_dictionary=initial_dictionary, type_safe=type_safe) _frozenconfigdict_valid_input(initial_configdict) # This will define the self._configdict attribute _frozenconfigdict_fill_seed(self, initial_configdict) object.__setattr__(self, '_locked', initial_configdict.is_locked) object.__setattr__(self, '_type_safe', initial_configdict.is_type_safe) def _frozen_setattr(self, key, value, visit_map=None): """Sets attribute, analogous to __setattr__(). Args: key: Name of the attribute to set. value: Value of the attribute to set. visit_map: Dictionary from memory addresses to values, storing the FrozenConfigDict versions of value's elements. Lists which have been converted to tuples and sets to frozensets are also stored in visit_map. Returns: visit_map: Updated visit_map. Raises: ValueError: If there is a dot in key, or value contains dicts inside lists or tuples. Also if key is already an attribute, since redefining an attribute is prohibited for FrozenConfigDict. AttributeError: If key is protected (such as '_type_safe' and '_locked'). """ visit_map = visit_map or {} # These should always pass because of conversion to ConfigDict at # initialization self._ensure_mutability(key) assert '.' not in key if key in self._fields: raise ValueError('Cannot redefine attribute of FrozenConfigDict.') if isinstance(value, (list, tuple, set)): immutable_value, _, visit_map = _convert_to_immutable(value, visit_map) self._fields[key] = immutable_value else: self._fields[key] = value return visit_map def __setstate__(self, state): """Recreates FrozenConfigDict from its dict representation.""" self.__init__(state['_configdict']) def __setattr__(self, attribute, value): raise AttributeError('FrozenConfigDict is immutable. Cannot call ' '__setattr__().') def __delattr__(self, attribute): raise AttributeError('FrozenConfigDict is immutable. 
Cannot call ' '__delattr__().') def __setitem__(self, attribute, value): raise AttributeError('FrozenConfigDict is immutable. Cannot call ' '__setitem__().') def __delitem__(self, attribute): raise AttributeError('FrozenConfigDict is immutable. Cannot call ' '__delitem__().') def lock(self): raise AttributeError('FrozenConfigDict is immutable. Cannot call lock().') def unlock(self): raise AttributeError('FrozenConfigDict is immutable. Cannot call unlock().') def __hash__(self): """Computes hash. The hash depends not only on the immutable aspects of the FrozenConfigDict, but on the types of the initial_dictionary at initialization (i.e. on the _configdict attribute). For example, in the following, hash(frozen_1) will not equal hash(frozen_2): d_1 = {'a': (1, )} d_2 = {'a': [1]} frozen_1 = FrozenConfigDict(d_1) frozen_2 = FrozenConfigDict(d_2) Note: This implementation relies on the particulars of the FrozenConfigDict structure. For example, the fact that lists and tuples cannot contain dicts or ConfigDicts is crucial, as is the fact that any loop in the reference structure is prohibited. Note: Due to hash randomization, this hash will likely differ in different Python sessions. For comparisons across sessions, please directly use equality of the serialization. For more, see https://bugs.python.org/issue13703 Returns: frozenconfigdict_hash: The hash value. Raises: TypeError: self contains an unhashable type. 
""" def value_hash(value): """Hashes a single value.""" if isinstance(value, set): return hash((hash(frozenset(value)), 1)) elif isinstance(value, (list, tuple)): value_hash_list = [isinstance(value, list)] for element in value: element_hash = value_hash(element) value_hash_list.append(element_hash) return hash(tuple(value_hash_list)) elif isinstance(value, FieldReference): return value_hash(value.get()) else: try: return hash(value) except TypeError: raise TypeError('FrozenConfigDict contains unhashable type ' '{}'.format(type(value))) fields_hash = 0 for key, value in six.iteritems(self._fields): if isinstance(value, FrozenConfigDict): fields_hash += hash((hash(key), hash(value))) else: # Use self._configdict value to ensure attending to mutability fields_hash += hash((hash(key), value_hash(self._configdict._fields[key]))) frozenconfigdict_hash = hash((fields_hash, self.is_locked, self.is_type_safe)) return frozenconfigdict_hash def __eq__(self, other): """Override default Equals behavior. Like __hash__(), this pays attention to the type of initial_dictionary. See __hash__() documentation for details. Warning: This distinguishes FrozenConfigDict from ConfigDict. For example: cd = ConfigDict() fcd = FrozenConfigDict() fcd.__eq__(cd) # Returns False Args: other: Object to compare self to. Returns: same: Boolean self == other. """ if isinstance(other, FrozenConfigDict): return ConfigDict(self) == ConfigDict(other) else: return False def as_configdict(self): return self._configdict class CustomJSONEncoder(json.JSONEncoder): """JSON encoder for ConfigDict and FieldReference. The encoder throws an exception for non-supported types. """ def default(self, obj): if isinstance(obj, FieldReference): return obj.get() elif isinstance(obj, ConfigDict): return obj._fields else: raise TypeError('{} is not JSON serializable. 
Instead use ' 'ConfigDict.to_json_best_effort()'.format(type(obj))) class _BestEffortCustomJSONEncoder(CustomJSONEncoder): """Best effort JSON encoder for ConfigDict. The encoder tries to serialize non-supported types (doesn't throw exceptions). """ def default(self, obj): try: return super(_BestEffortCustomJSONEncoder, self).default(obj) except TypeError: if isinstance(obj, set): return sorted(list(obj)) elif inspect.isfunction(obj): return 'function {}'.format(obj.__name__) elif (hasattr(obj, '__dict__') and obj.__dict__ and not inspect.isclass(obj)): # Instantiated object's variables return dict(obj.__dict__) elif hasattr(obj, '__str__'): return 'unserializable object: {}'.format(obj) else: return 'unserializable object of type: {}'.format(type(obj)) def create(**kwargs): """Creates a `ConfigDict` with the given named arguments as key-value pairs. This allows for simple dictionaries whose elements can be accessed directly using field access:: from ml_collections.config_dict import config_dict point = config_dict.create(x=1, y=2) print(point.x, point.y) This is particularly useful for compactly writing nested configurations:: config = config_dict.create( data=config_dict.create( game='freeway', frame_size=100), model=config_dict.create(num_hidden=1000)) The reason for the existence of this function is that it simplifies the code required for the majority of the use cases of `ConfigDict`, compared to using either `ConfigDict` or `namedtuple's`. Examples of such use cases include training script configuration, and returning multiple named values. Args: **kwargs: key-value pairs to be stored in the `ConfigDict`. Returns: A `ConfigDict` containing the key-value pairs in `kwargs`. """ return ConfigDict(initial_dictionary=kwargs) # TODO(sergomez): make placeholders required by default. def placeholder(field_type, required=False): """Defines an entry in a ConfigDict that has no value yet. 
Example:: config = configdict.create( batch_size = configdict.placeholder(int), frame_shape = configdict.placeholder(tf.TensorShape)) Args: field_type: type of value. required: If True, the placeholder will raise an error on access if the underlying value hasn't been set. Returns: A `FieldReference` with value None and the given type. """ return FieldReference(None, field_type=field_type, required=required) def required_placeholder(field_type): """Defines an entry in a ConfigDict with unknown but required value. Example:: config = configdict.create( batch_size = configdict.required_placeholder(int)) try: print(config.batch_size) except RequiredValueError: pass config.batch_size = 10 print(config.batch_size) # 10 Args: field_type: type of value. Returns: A `FieldReference` with value None and the given type. """ return placeholder(field_type, required=True) def recursive_rename(conf, old_name, new_name): """Returns copy of conf with old_name recursively replaced by new_name. This is not done in place, no changes are made to conf but a new ConfigDict is returned with the changes made. This is useful if the name of a parameter has been changed in code but you need to load an old config. Example usage: updated_conf = configdict.recursive_rename(conf, "config", "kwargs") Args: conf: a ConfigDict which needs updating old_name: the name used in the ConfigDict which is out of sync with the code new_name: the name used in the code Returns: A ConfigDict which is a copy of conf but with all instances of old_name replaced with new_name. """ new_conf = ConfigDict() for name, c in conf.items(): if isinstance(c, ConfigDict): new_c = recursive_rename(c, old_name, new_name) else: new_c = c if name == old_name: setattr(new_conf, new_name, new_c) else: setattr(new_conf, name, new_c) return new_conf ml_collections-0.1.1/ml_collections/config_dict/__init__.py0000640000175000017500000000260114174507605023440 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Classes for defining configurations of experiments and models.""" from .config_dict import _Op from .config_dict import ConfigDict from .config_dict import create from .config_dict import CustomJSONEncoder from .config_dict import FieldReference from .config_dict import FrozenConfigDict from .config_dict import JSONDecodeError from .config_dict import MutabilityError from .config_dict import placeholder from .config_dict import recursive_rename from .config_dict import required_placeholder from .config_dict import RequiredValueError __all__ = ("_Op", "ConfigDict", "create", "CustomJSONEncoder", "FieldReference", "FrozenConfigDict", "JSONDecodeError", "MutabilityError", "placeholder", "recursive_rename", "required_placeholder", "RequiredValueError") ml_collections-0.1.1/ml_collections/config_dict/examples/0000750000175000017500000000000014174510450023135 5ustar nileshnileshml_collections-0.1.1/ml_collections/config_dict/examples/config_dict_advanced.py0000640000175000017500000000765114174507605027626 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of ConfigDict usage. This example includes loading a ConfigDict in FLAGS, locking it, type safety, iteration over fields, checking for a particular field, unpacking with `**`, and loading a dictionary from its string representation. """ from absl import app from ml_collections.config_flags import config_flags import yaml _CONFIG = config_flags.DEFINE_config_file( 'my_config', default='ml_collections/config_dict/examples/config.py') def dummy_function(string, **unused_kwargs): return 'Hello {}'.format(string) def print_section(name): print() print() print('-' * len(name)) print(name.upper()) print('-' * len(name)) print() def main(_): # Config is already loaded in FLAGS.my_config due to the logic hidden # in app.run(). config = _CONFIG.value print_section('Printing config.') print(config) # Config is of our type ConfigDict. print('Type of the config {}'.format(type(config))) # By default it is locked, thus you cannot add new fields. # This prevents you from misspelling your attribute name. print_section('Locking.') print('config.is_locked={}'.format(config.is_locked)) try: config.object.new_field = -3 except AttributeError as e: print(e) # There is also a "did you mean" feature! try: config.object.floet = -3. except AttributeError as e: print(e) # However, if you want to modify it you can always unlock. print_section('Unlocking.') with config.unlocked(): config.object.new_field = -3 print('config.object.new_field={}'.format(config.object.new_field)) # By default config is also type-safe, so you cannot change the type of any # field.
print_section('Type safety.') try: config.float = 'jerry' except TypeError as e: print(e) config.float = -1.2 print('config.float={}'.format(config.float)) # NoneType is ignored by type safety and can both override and be overridden. config.float = None config.float = -1.2 # You can temporarily turn type safety off. with config.ignore_type(): config.float = 'tom' print('config.float={}'.format(config.float)) config.float = 2.3 print('config.float={}'.format(config.float)) # You can use ConfigDict as a regular dict in many typical use-cases: # Iteration over fields: print_section('Iteration over fields.') for field in config: print('config has field "{}"'.format(field)) # Checking if it contains a particular field using the "in" operator. print_section('Checking for a particular field.') for field in ('float', 'non_existing'): if field in config: print('"{}" is in config'.format(field)) else: print('"{}" is not in config'.format(field)) # Using ** unpacking to pass the config to a function as named arguments. print_section('Unpacking with **') print(dummy_function(**config)) # You can even load a dictionary (notice it is not a ConfigDict anymore) from # a yaml string representation of ConfigDict. # Note: __repr__ (not __str__) is the recommended representation, as it # preserves FieldReferences and placeholders. print_section('Loading dictionary from string representation.') dictionary = yaml.load(repr(config), yaml.UnsafeLoader) print('dict["object_reference"]["dict"]["dict"]["float"]={}'.format( dictionary['object_reference']['dict']['dict']['float'])) if __name__ == '__main__': app.run(main) ml_collections-0.1.1/ml_collections/config_dict/examples/config_dict_placeholder.py0000640000175000017500000000335014174507605030333 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of placeholder fields in a ConfigDict. This example shows how ConfigDict placeholder fields work. For a more complete example of ConfigDict features, see config_dict_advanced. """ from absl import app import ml_collections def main(_): placeholder = ml_collections.FieldReference(0) cfg = ml_collections.ConfigDict() cfg.placeholder = placeholder cfg.optional = ml_collections.FieldReference(0, field_type=int) cfg.nested = ml_collections.ConfigDict() cfg.nested.placeholder = placeholder try: cfg.optional = 'tom' # Raises TypeError as this field is an integer. except TypeError as e: print(e) cfg.optional = 1555 # Works fine. cfg.placeholder = 1 # Changes the value of both placeholder and # nested.placeholder fields. # Note that the indirection provided by FieldReferences will be lost if # accessed through a ConfigDict: placeholder = ml_collections.FieldReference(0) cfg.field1 = placeholder cfg.field2 = placeholder # This field will be tied to cfg.field1. cfg.field3 = cfg.field1 # This will just be an int field initialized to 0. print(cfg) if __name__ == '__main__': app.run(main) ml_collections-0.1.1/ml_collections/config_dict/examples/config.py0000640000175000017500000001043314174507605024766 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of a config file using ConfigDict. The idea of this configuration file is to show a typical use case of ConfigDict, as well as its limitations. This also exemplifies a self-referencing ConfigDict. """ import copy import ml_collections def _get_flat_config(): """Helper to generate simple config without references.""" # The suggested way to create a ConfigDict() is to call its constructor # and assign all relevant fields. config = ml_collections.ConfigDict() # In order to add new attributes you can just use . notation, like with any # python object. They will be tracked by ConfigDict, and you get type checking # etc. for free. config.integer = 23 config.float = 2.34 config.string = 'james' config.bool = True # It is possible to assign dictionaries to ConfigDict and they will be # automatically and recursively wrapped with ConfigDict. However, make sure # that the dict you are assigning does not use internal references/cycles as # this is not supported. Instead, create the dicts explicitly as demonstrated # by get_config(). But note that this operation makes an element-by-element # copy of your original dict. # Also note that the recursive wrapping on input dictionaries with ConfigDict # does not extend through non-dictionary types (including basic Python types # and custom classes). This causes unexpected behavior most commonly if a # value is a list of dictionaries, so avoid giving ConfigDict such inputs. 
config.dict = { 'integer': 1, 'float': 3.14, 'string': 'mark', 'bool': False, 'dict': { 'float': 5 } } return config def get_config(): """Returns a ConfigDict instance describing a complex config. Returns: A ConfigDict instance with the structure: ``` CONFIG-+-- integer |-- float |-- string |-- bool |-- dict +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- float | |-- object +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- float | |-- object_copy +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- float | |-- object_reference [reference pointing to CONFIG-+--object] ``` """ config = _get_flat_config() config.object = _get_flat_config() # References work just fine, so you will be able to override both # values at the same time. The rule is the same as for python objects, # everything that is mutable is passed as a reference, thus it will not work # with assigning integers or strings, but will work just fine with # ConfigDicts. # WARNING: Each time you assign a dictionary as a value it will create a new # instance of ConfigDict in memory, thus it will be a copy of the original # dict and not a reference to the original. config.object_reference = config.object # ConfigDict supports deepcopying. config.object_copy = copy.deepcopy(config.object) return config ml_collections-0.1.1/ml_collections/config_dict/examples/config_dict_lock.py0000640000175000017500000000245714174507605027010 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of ConfigDict usage of lock. This example shows the roles and scopes of ConfigDict's lock(). """ from absl import app import ml_collections def main(_): cfg = ml_collections.ConfigDict() cfg.integer_field = 123 # Locking prohibits the addition of new fields and the deletion of existing # ones, but allows modification of existing values. Locking happens # automatically during loading through flags. cfg.lock() try: cfg.intagar_field = 124 # Raises AttributeError and suggests valid field. except AttributeError as e: print(e) cfg.integer_field = -123 # Works fine. with cfg.unlocked(): cfg.intagar_field = 1555 # Works fine too. print(cfg) if __name__ == '__main__': app.run(main) ml_collections-0.1.1/ml_collections/config_dict/examples/config_dict_initialization.py0000640000175000017500000000630314174507605031101 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of initialization features and gotchas in a ConfigDict.
""" import copy from absl import app import ml_collections def print_section(name): print() print() print('-' * len(name)) print(name.upper()) print('-' * len(name)) print() def main(_): inner_dict = {'list': [1, 2], 'tuple': (1, 2, [3, 4], (5, 6))} example_dict = { 'string': 'tom', 'int': 2, 'list': [1, 2], 'set': {1, 2}, 'tuple': (1, 2), 'ref': ml_collections.FieldReference({'int': 0}), 'inner_dict_1': inner_dict, 'inner_dict_2': inner_dict } print_section('Initializing on dictionary.') # ConfigDict can be initialized on example_dict example_cd = ml_collections.ConfigDict(example_dict) # Dictionary fields are also converted to ConfigDict print(type(example_cd.inner_dict_1)) # And the reference structure is preserved print(id(example_cd.inner_dict_1) == id(example_cd.inner_dict_2)) print_section('Initializing on ConfigDict.') # ConfigDict can also be initialized on a ConfigDict example_cd_cd = ml_collections.ConfigDict(example_cd) # Yielding the same result: print(example_cd == example_cd_cd) # Note that the memory addresses are different print(id(example_cd) == id(example_cd_cd)) # The memory addresses of the attributes are not the same because of the # FieldReference, which gets removed on the second initialization list_to_ids = lambda x: [id(element) for element in x] print( set(list_to_ids(list(example_cd.values()))) == set( list_to_ids(list(example_cd_cd.values())))) print_section('Initializing on self-referencing dictionary.') # Initialization works on a self-referencing dict self_ref_dict = copy.deepcopy(example_dict) self_ref_dict['self'] = self_ref_dict self_ref_cd = ml_collections.ConfigDict(self_ref_dict) # And the reference structure is replicated print(id(self_ref_cd) == id(self_ref_cd.self)) print_section('Unexpected initialization behavior.') # ConfigDict initialization doesn't look inside lists, so doesn't convert a # dict in a list to ConfigDict dict_in_list_in_dict = {'list': [{'troublemaker': 0}]} dict_in_list_in_dict_cd = 
ml_collections.ConfigDict(dict_in_list_in_dict) print(type(dict_in_list_in_dict_cd.list[0])) # This can cause the reference structure to not be replicated referred_dict = {'key': 'value'} bad_reference = {'referred_dict': referred_dict, 'list': [referred_dict]} bad_reference_cd = ml_collections.ConfigDict(bad_reference) print(id(bad_reference_cd.referred_dict) == id(bad_reference_cd.list[0])) if __name__ == '__main__': app.run(main) ml_collections-0.1.1/ml_collections/config_dict/examples/field_reference.py0000640000175000017500000001132414174507605026622 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of FieldReference usage. This shows how to use FieldReferences for lazy computation. """ from absl import app import ml_collections from ml_collections.config_dict import config_dict def lazy_computation(): """Simple example of lazy computation with `configdict.FieldReference`.""" ref = ml_collections.FieldReference(1) print(ref.get()) # Prints 1 add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 11 because ref's value is 1 # Addition is lazily computed for FieldReferences, so changing ref will change # the value of add_ten_lazy; the eagerly computed add_ten is unaffected.
ref.set(5) print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 15 because ref's value is 5 def change_lazy_computation(): """Overriding lazily computed values.""" config = ml_collections.ConfigDict() config.reference = 1 config.reference_0 = config.get_ref('reference') + 10 config.reference_1 = config.get_ref('reference') + 20 config.reference_1_0 = config.get_ref('reference_1') + 100 print(config.reference) # Prints 1. print(config.reference_0) # Prints 11. print(config.reference_1) # Prints 21. print(config.reference_1_0) # Prints 121. config.reference_1 = 30 print(config.reference) # Prints 1 (unchanged). print(config.reference_0) # Prints 11 (unchanged). print(config.reference_1) # Prints 30. print(config.reference_1_0) # Prints 130. def create_cycle(): """Creates a cycle within a ConfigDict.""" config = ml_collections.ConfigDict() config.integer_field = 1 config.bigger_integer_field = config.get_ref('integer_field') + 10 try: # Raises a MutabilityError because setting config.integer_field would # cause a cycle. 
config.integer_field = config.get_ref('bigger_integer_field') + 2 except config_dict.MutabilityError as e: print(e) def lazy_configdict(): """Example usage of lazy computation with ConfigDict.""" config = ml_collections.ConfigDict() config.reference_field = ml_collections.FieldReference(1) config.integer_field = 2 config.float_field = 2.5 # No lazy evaluations because we didn't use get_ref() config.no_lazy = config.integer_field * config.float_field # This will lazily evaluate ONLY config.integer_field config.lazy_integer = config.get_ref('integer_field') * config.float_field # This will lazily evaluate ONLY config.float_field config.lazy_float = config.integer_field * config.get_ref('float_field') # This will lazily evaluate BOTH config.integer_field and config.float_field config.lazy_both = (config.get_ref('integer_field') * config.get_ref('float_field')) config.integer_field = 3 print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value print(config.lazy_integer) # Prints 7.5 config.float_field = 3.5 print(config.lazy_float) # Prints 7.0 print(config.lazy_both) # Prints 10.5 def lazy_configdict_advanced(): """Advanced lazy computation with ConfigDict.""" # FieldReferences can be used with ConfigDict as well config = ml_collections.ConfigDict() config.float_field = 12.6 config.integer_field = 123 config.list_field = [0, 1, 2] config.float_multiply_field = config.get_ref('float_field') * 3 print(config.float_multiply_field) # Prints 37.8 config.float_field = 10.0 print(config.float_multiply_field) # Prints 30.0 config.longer_list_field = config.get_ref('list_field') + [3, 4, 5] print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5] config.list_field = [-1] print(config.longer_list_field) # Prints [-1, 3, 4, 5] # Both operands can be references config.ref_subtraction = ( config.get_ref('float_field') - config.get_ref('integer_field')) print(config.ref_subtraction) # Prints -113.0 config.integer_field = 10 print(config.ref_subtraction) #
Prints 0.0 def main(argv=()): del argv # Unused. lazy_computation() lazy_configdict() change_lazy_computation() create_cycle() lazy_configdict_advanced() if __name__ == '__main__': app.run(main) ml_collections-0.1.1/ml_collections/config_dict/examples/frozen_config_dict.py0000640000175000017500000000576714174507605027372 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of basic FrozenConfigDict usage. This example shows the most basic usage of FrozenConfigDict, highlighting the differences between FrozenConfigDict and ConfigDict and including converting between the two.
""" from absl import app import ml_collections def print_section(name): print() print() print('-' * len(name)) print(name.upper()) print('-' * len(name)) print() def main(_): print_section('Attribute Types.') cfg = ml_collections.ConfigDict() cfg.int = 1 cfg.list = [1, 2, 3] cfg.tuple = (1, 2, 3) cfg.set = {1, 2, 3} cfg.frozenset = frozenset({1, 2, 3}) cfg.dict = { 'nested_int': 4, 'nested_list': [4, 5, 6], 'nested_tuple': ([4], 5, 6), } print('Types of cfg fields:') print('list: ', type(cfg.list)) # List print('set: ', type(cfg.set)) # Set print('nested_list: ', type(cfg.dict.nested_list)) # List print('nested_tuple[0]: ', type(cfg.dict.nested_tuple[0])) # List frozen_cfg = ml_collections.FrozenConfigDict(cfg) print('\nTypes of FrozenConfigDict(cfg) fields:') print('list: ', type(frozen_cfg.list)) # Tuple print('set: ', type(frozen_cfg.set)) # Frozenset print('nested_list: ', type(frozen_cfg.dict.nested_list)) # Tuple print('nested_tuple[0]: ', type(frozen_cfg.dict.nested_tuple[0])) # Tuple cfg_from_frozen = ml_collections.ConfigDict(frozen_cfg) print('\nTypes of ConfigDict(FrozenConfigDict(cfg)) fields:') print('list: ', type(cfg_from_frozen.list)) # List print('set: ', type(cfg_from_frozen.set)) # Set print('nested_list: ', type(cfg_from_frozen.dict.nested_list)) # List print('nested_tuple[0]: ', type(cfg_from_frozen.dict.nested_tuple[0])) # List print('\nCan use FrozenConfigDict.as_configdict() to convert to ConfigDict:') print(cfg_from_frozen == frozen_cfg.as_configdict()) # True print_section('Immutability.') try: frozen_cfg.new_field = 1 # Raises AttributeError because of immutability. 
except AttributeError as e: print(e) print_section('"==" and eq_as_configdict().') # FrozenConfigDict.__eq__() is not type-invariant with respect to ConfigDict print(frozen_cfg == cfg) # False # FrozenConfigDict.eq_as_configdict() is type-invariant with respect to # ConfigDict print(frozen_cfg.eq_as_configdict(cfg)) # True # .eq_as_configdict() is also a method of ConfigDict print(cfg.eq_as_configdict(frozen_cfg)) # True if __name__ == '__main__': app.run(main) ml_collections-0.1.1/ml_collections/config_dict/examples/config_dict_basic.py0000640000175000017500000000270514174507605027135 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of basic ConfigDict usage. This example shows the most basic usage of ConfigDict, including type safety. For examples of more features, see example_advanced. """ from absl import app import ml_collections def main(_): cfg = ml_collections.ConfigDict() cfg.float_field = 12.6 cfg.integer_field = 123 cfg.another_integer_field = 234 cfg.nested = ml_collections.ConfigDict() cfg.nested.string_field = 'tom' print(cfg.integer_field) # Prints 123. print(cfg['integer_field']) # Prints 123 as well. try: cfg.integer_field = 'tom' # Raises TypeError as this field is an integer. except TypeError as e: print(e) cfg.float_field = 12 # Works: `int` types can be assigned to `float`. cfg.nested.string_field = u'bob' # `String` fields can store Unicode strings.
print(cfg) if __name__ == '__main__': app.run(main) ml_collections-0.1.1/ml_collections/config_dict/examples/examples_test.py0000640000175000017500000000340214174507605026374 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Tests for ConfigDict examples. Ensures that config_dict_basic, config_dict_initialization, config_dict_lock, config_dict_placeholder, field_reference, frozen_config_dict run successfully. """ from absl.testing import absltest from absl.testing import parameterized from ml_collections.config_dict.examples import config_dict_advanced from ml_collections.config_dict.examples import config_dict_basic from ml_collections.config_dict.examples import config_dict_initialization from ml_collections.config_dict.examples import config_dict_lock from ml_collections.config_dict.examples import config_dict_placeholder from ml_collections.config_dict.examples import field_reference from ml_collections.config_dict.examples import frozen_config_dict class ConfigDictExamplesTest(parameterized.TestCase): @parameterized.parameters(config_dict_advanced, config_dict_basic, config_dict_initialization, config_dict_lock, config_dict_placeholder, field_reference, frozen_config_dict) def testScriptRuns(self, example_name): example_name.main(None) if __name__ == '__main__': absltest.main() ml_collections-0.1.1/ml_collections/config_flags/0000750000175000017500000000000014174510450021470 5ustar 
nileshnileshml_collections-0.1.1/ml_collections/config_flags/config_flags.py0000640000175000017500000007255014174507605024505 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Configuration command line parser.""" import errno import imp import os import re import sys import traceback from typing import Any, Dict, Generic, List, MutableMapping, Optional, Tuple, Type, TypeVar from absl import flags from absl import logging import dataclasses import ml_collections from ml_collections.config_flags import tuple_parser import six from six import string_types FLAGS = flags.FLAGS # We need this to work both for Python 2 and 3. # The initial code in Python 2 was: # _FIELD_TYPE_TO_PARSER = { # types.IntType: flags.IntegerParser(), # types.FloatType: flags.FloatParser(), # OK For Python 3 # types.BooleanType: flags.BooleanParser(), # OK For Python 3 # types.StringType: flags.ArgumentParser(), # types.TupleType: tuple_parser.TupleParser(), # OK For Python 3 # } # The possible breaking changes are: # - A Python 3 int could be a Python 2 long, which was not previously supported. # We then add support for long. # - Only Python 2 str were supported (not unicode). Python 3 will behave the # same with the str semantic change. _FIELD_TYPE_TO_PARSER = { float: flags.FloatParser(), bool: flags.BooleanParser(), # Implementing a custom parser to override `Tuple` arguments.
tuple: tuple_parser.TupleParser(), } for t in six.integer_types: _FIELD_TYPE_TO_PARSER[t] = flags.IntegerParser() for t in six.string_types: _FIELD_TYPE_TO_PARSER[t] = flags.ArgumentParser() _FIELD_TYPE_TO_PARSER[str] = flags.ArgumentParser() _FIELD_TYPE_TO_SERIALIZER = { t: flags.ArgumentSerializer() for t in _FIELD_TYPE_TO_PARSER } class UnsupportedOperationError(flags.Error): pass class FlagOrderError(flags.Error): pass class UnparsedFlagError(flags.Error): pass def DEFINE_config_file( # pylint: disable=g-bad-name name: str, default: Optional[str] = None, help_string: str = 'path to config file.', flag_values: flags.FlagValues = FLAGS, lock_config: bool = True, **kwargs) -> flags.FlagHolder: r"""Defines flag for `ConfigDict` files compatible with absl flags. The flag's value should be a path to a valid python file which contains a function called `get_config()` that returns a python object specifying a configuration. After the flag is parsed, `FLAGS.name` will contain a reference to this object, optionally with some values overridden. During flags parsing, every flag of form `--name.([a-zA-Z0-9]+\.?)+=value` and `-name.([a-zA-Z0-9]+\.?)+ value` will be treated as an override of a specific field in the config object returned by this flag. Field is essentially a dot delimited path inside the object where each path element has to be either an attribute or a key existing in the config object. For example `--my_config.field1.field2=val` means "assign value val to the attribute (or key) `field2` inside value of the attribute (or key) `field1` inside the value of `my_config` object". If there are both attribute and key-based access with the same name, attribute is preferred. 
Typical usage example: `script.py`:: from absl import flags from ml_collections.config_flags import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file('my_config') print(FLAGS.my_config) `config.py`:: def get_config(): return { 'field1': 1, 'field2': 'tom', 'nested': { 'field': 2.23, }, } The following command:: python script.py -- --my_config=config.py --my_config.field1 8 --my_config.nested.field=2.1 will print:: {'field1': 8, 'field2': 'tom', 'nested': {'field': 2.1}} It is possible to parameterise the get_config function, allowing it to return a differently structured result for different occasions. This is particularly useful when setting up hyperparameter sweeps across various network architectures. `parameterised_config.py`:: def get_config(config_string): possible_configs = { 'mlp': { 'constructor': 'snt.nets.MLP', 'config': { 'output_sizes': (128, 128, 1), } }, 'lstm': { 'constructor': 'snt.LSTM', 'config': { 'hidden_size': 128, 'forget_bias': 1.0, } } } return possible_configs[config_string] If a colon is present in the command line override for the config file, everything to the right of the colon is passed into the get_config function. The following command lines will both function correctly:: python script.py -- --my_config=parameterised_config.py:mlp --my_config.config.output_sizes="(256,256,1)" python script.py -- --my_config=parameterised_config.py:lstm --my_config.config.hidden_size=256 The following will produce an error, as the hidden_size flag does not exist when the "mlp" config_string is provided:: python script.py -- --my_config=parameterised_config.py:mlp --my_config.config.hidden_size=256 Args: name: Flag name, optionally including extra config after a colon. default: Default value of the flag (default: None). help_string: Help string to display when --helpfull is called. (default: "path to config file.") flag_values: FlagValues instance used for parsing. 
(default: absl.flags.FLAGS) lock_config: If set to True, loaded config will be locked through calling .lock() method on its instance (if it exists). (default: True) **kwargs: Optional keyword arguments passed to Flag constructor. Returns: a handle to defined flag. """ parser = _ConfigFileParser(name=name, lock_config=lock_config) flag = _ConfigFlag( parser=parser, serializer=flags.ArgumentSerializer(), name=name, default=default, help_string=help_string, flag_values=flag_values, **kwargs) # Get the module name for the frame at depth 1 in the call stack. module_name = sys._getframe(1).f_globals.get('__name__', None) # pylint: disable=protected-access module_name = sys.argv[0] if module_name == '__main__' else module_name return flags.DEFINE_flag(flag, flag_values, module_name=module_name) def DEFINE_config_dict( # pylint: disable=g-bad-name name: str, config: ml_collections.ConfigDict, help_string: str = 'ConfigDict instance.', flag_values: flags.FlagValues = FLAGS, lock_config: bool = True, **kwargs) -> flags.FlagHolder: """Defines flag for inline `ConfigDict's` compatible with absl flags. Similar to `DEFINE_config_file` except the flag's value should be a `ConfigDict` instead of a path to a file containing a `ConfigDict`. After the flag is parsed, `FLAGS.name` will contain a reference to the `ConfigDict`, optionally with some values overridden. Typical usage example: `script.py`:: from absl import flags import ml_collections from ml_collections.config_flags import config_flags config = ml_collections.ConfigDict({ 'field1': 1, 'field2': 'tom', 'nested': { 'field': 2.23, } }) FLAGS = flags.FLAGS config_flags.DEFINE_config_dict('my_config', config) ... print(FLAGS.my_config) The following command:: python script.py -- --my_config.field1 8 --my_config.nested.field=2.1 will print:: field1: 8 field2: tom nested: {field: 2.1} Args: name: Flag name. config: `ConfigDict` object. help_string: Help string to display when --helpfull is called. 
(default: "ConfigDict instance.") flag_values: FlagValues instance used for parsing. (default: absl.flags.FLAGS) lock_config: If set to True, loaded config will be locked through calling .lock() method on its instance (if it exists). (default: True) **kwargs: Optional keyword arguments passed to Flag constructor. Returns: a handle to defined flag. """ if not isinstance(config, ml_collections.ConfigDict): raise TypeError('config should be a ConfigDict') parser = _InlineConfigParser(name=name, lock_config=lock_config) flag = _ConfigFlag( parser=parser, serializer=flags.ArgumentSerializer(), name=name, default=config, help_string=help_string, flag_values=flag_values, **kwargs) # Get the module name for the frame at depth 1 in the call stack. module_name = sys._getframe(1).f_globals.get('__name__', None) # pylint: disable=protected-access module_name = sys.argv[0] if module_name == '__main__' else module_name return flags.DEFINE_flag(flag, flag_values, module_name=module_name) # Note that we would add a bound to constrain this to be a dataclass, except # that dataclasses don't have a specific base class, and structural typing for # attributes is currently (2021Q1) not supported in pytype (b/150927776). _T = TypeVar('_T') class _TypedFlagHolder(flags.FlagHolder, Generic[_T]): """A typed wrapper for a FlagHolder.""" def __init__(self, flag: flags.FlagHolder, ctor: Type[_T]): self._flag = flag self._ctor = ctor @property def value(self) -> _T: return self._ctor(**self._flag.value) @property def default(self) -> _T: return self._ctor(**self._flag.default) @property def name(self) -> str: return self._flag.name def DEFINE_config_dataclass( # pylint: disable=invalid-name name: str, config: _T, help_string: str = 'Configuration object. Must be a dataclass.', flag_values: flags.FlagValues = FLAGS, **kwargs, ) -> _TypedFlagHolder[_T]: """Defines a typed (dataclass) flag-overrideable configuration. Similar to `DEFINE_config_dict` except `config` should be a `dataclass`. 
Args: name: Flag name. config: A user-defined configuration object. Must be built via `dataclass`. help_string: Help string to display when --helpfull is called. flag_values: FlagValues instance used for parsing. **kwargs: Optional keyword arguments passed to Flag constructor. Returns: A handle to the defined flag. """ if not dataclasses.is_dataclass(config): raise ValueError('Configuration object must be a `dataclass`.') # Convert to configdict *without* recursing into leaf node(s). # If our config contains dataclasses (or other types) as fields, we want to # preserve them; dataclasses.asdict recursively turns all fields into dicts. dictionary = {field.name: getattr(config, field.name) for field in dataclasses.fields(config)} config_dict = ml_collections.ConfigDict(initial_dictionary=dictionary) # Define the flag. config_flag = DEFINE_config_dict( name, config=config_dict, help_string=help_string, flag_values=flag_values, **kwargs, ) return _TypedFlagHolder(flag=config_flag, ctor=config.__class__) def get_config_filename(config_flag) -> str: # pylint: disable=g-bad-name """Returns the path to the config file given the config flag. Args: config_flag: The flag instance obtained from FLAGS, e.g. FLAGS['config']. Returns: the path to the config file. """ if not is_config_flag(config_flag): raise TypeError('expect a config flag, found {}'.format(type(config_flag))) return config_flag.config_filename def get_override_values(config_flag) -> Dict[str, Any]: # pylint: disable=g-bad-name """Returns a flat dict containing overridden values from the config flag. Args: config_flag: The flag instance obtained from FLAGS, e.g. FLAGS['config']. Returns: a flat dict containing overridden values from the config flag. """ if not is_config_flag(config_flag): raise TypeError('expect a config flag, found {}'.format(type(config_flag))) return config_flag.override_values class _IgnoreFileNotFoundAndCollectErrors: """Helps recording "file not found" exceptions when loading config. 
Usage: ignore_errors = _IgnoreFileNotFoundAndCollectErrors() with ignore_errors.Attempt('Loading from foo', 'bar.id'): ... return True # successfully loaded from `foo` logging.error('Failed loading: {}'.format(ignore_errors.DescribeAttempts())) """ def __init__(self): self._attempts = [] # type: List[Tuple[Tuple[str, str], IOError]] def Attempt(self, description, path): """Creates a context manager that routes exceptions to this class.""" self._current_attempt = (description, path) ignore_errors = self class _ContextManager: def __enter__(self): return self def __exit__(self, exc_type, exc_value, unused_traceback): return ignore_errors.ProcessAttemptException(exc_type, exc_value) return _ContextManager() def ProcessAttemptException(self, exc_type, exc_value): expected_type = IOError if six.PY2 else FileNotFoundError # pylint: disable=undefined-variable if exc_type is expected_type and exc_value.errno == errno.ENOENT: self._attempts.append((self._current_attempt, exc_value)) # Returning a true value suppresses exceptions: # https://docs.python.org/2/reference/datamodel.html#object.__exit__ return True def DescribeAttempts(self): return '\n'.join( ' Attempted [{}]:\n {}\n {}'.format(attempt[0], attempt[1], e) for attempt, e in self._attempts) def _LoadConfigModule(name: str, path: str): """Loads a script from external file specified by path. Unprefixed path is looked for in the current working directory using regular file open operation. This should work with relative config paths. Args: name: Name of the new module. path: Path to the .py file containing the module. Returns: Module loaded from the given path. Raises: IOError: If the config file cannot be found. """ if not path: raise IOError('Path to config file is an empty string.') ignoring_errors = _IgnoreFileNotFoundAndCollectErrors() # Works for relative paths. with ignoring_errors.Attempt('Relative path', path): config_module = imp.load_source(name, path) return config_module # Nothing worked. 
Log the paths that were attempted. raise IOError('Failed loading config file {}\n{}'.format( name, ignoring_errors.DescribeAttempts())) class _ErrorConfig: """Dummy ConfigDict that raises an error on any attribute access.""" def __init__(self, error): super(_ErrorConfig, self).__init__() super(_ErrorConfig, self).__setattr__('_error', error) def __getattr__(self, attr): self._ReportError() def __setattr__(self, attr, value): self._ReportError() def __delattr__(self, attr): self._ReportError() def __getitem__(self, key): self._ReportError() def __setitem__(self, key, value): self._ReportError() def __delitem__(self, key): self._ReportError() def _ReportError(self): raise IOError('Configuration is not available because of an earlier ' 'failure to load: ' + # 'message' is not available in Python 3. getattr(self._error, 'message', str(self._error))) def _LockConfig(config): """Calls config.lock() if config has a lock method.""" if isinstance(config, _ErrorConfig): pass # Attempting to access _ErrorConfig.lock will raise its error. elif getattr(config, 'lock', None) and callable(config.lock): config.lock() else: pass # config.lock() does not have desired semantics, do nothing. class _ConfigFileParser(flags.ArgumentParser): """Parser for config files.""" def __init__(self, name, lock_config=True): self.name = name self._lock_config = lock_config def parse(self, path): """Loads a config module from `path` and returns the `get_config()` result. If a colon is present in `path`, everything to the right of the first colon is passed to `get_config` as an argument. This allows the structure of what is returned to be modified, which is useful when performing complex hyperparameter sweeps. Args: path: string, path pointing to the config file to execute. May also contain a config_string argument, e.g. be of the form "config.py:some_configuration". Returns: Result of calling `get_config` in the specified module. 
""" # This will be a 2 element list iff extra configuration args are present. split_path = path.split(':', 1) try: config_module = _LoadConfigModule('{}_config'.format(self.name), split_path[0]) config = config_module.get_config(*split_path[1:]) if config is None: logging.warning( '%s:get_config() returned None, did you forget a return statement?', path) except IOError as e: # Don't raise the error unless/until the config is actually accessed. config = _ErrorConfig(e) # Third party flags library catches TypeError and ValueError and rethrows, # removing useful information unless it is added here (b/63877430): except (TypeError, ValueError) as e: error_trace = traceback.format_exc() raise type(e)('Error whilst parsing config file:\n\n' + error_trace) if self._lock_config: _LockConfig(config) return config def flag_type(self): return 'config object' class _InlineConfigParser(flags.ArgumentParser): """Parser for a config defined inline (not from a file).""" def __init__(self, name, lock_config=True): self.name = name self._lock_config = lock_config def parse(self, config): if not isinstance(config, ml_collections.ConfigDict): raise TypeError('Overriding {} is not allowed.'.format(self.name)) if self._lock_config: _LockConfig(config) return config def flag_type(self): return 'config object' class _ConfigFlag(flags.Flag): """Flag definition for command-line overridable configs.""" def __init__(self, flag_values=FLAGS, **kwargs): # Parent constructor can already call .Parse, thus additional fields # have to be set here. 
self.flag_values = flag_values super(_ConfigFlag, self).__init__(**kwargs) def _GetOverrides(self, argv): """Parses the command line arguments for the overrides.""" overrides = [] config_index = self._FindConfigSpecified(argv) for i, arg in enumerate(argv): if re.match(r'-{{1,2}}(no)?{}\.'.format(self.name), arg): if config_index > 0 and i < config_index: raise FlagOrderError('Found {} in argv before a value for --{} ' 'was specified'.format(arg, self.name)) arg_name = arg.split('=', 1)[0] overrides.append(arg_name.split('.', 1)[1]) return overrides def _FindConfigSpecified(self, argv): """Finds element in argv specifying the value of the config flag. Args: argv: command line arguments as a list of strings. Returns: Index in argv if found and -1 otherwise. """ for i, arg in enumerate(argv): # '-(-)config' followed by '=' or at the end of the string. if re.match(r'^-{{1,2}}{}(=|$)'.format(self.name), arg) is not None: return i return -1 def _IsConfigSpecified(self, argv): """Returns `True` if the config file is specified on the command line.""" return self._FindConfigSpecified(argv) >= 0 def _set_default(self, default): if self._IsConfigSpecified(sys.argv): self.default = default else: super(_ConfigFlag, self)._set_default(default) # pytype: disable=attribute-error self.default_as_str = "'{}'".format(default) def _parse(self, argument): # Parse config config = super(_ConfigFlag, self)._parse(argument) # Get list of overrides overrides = self._GetOverrides(sys.argv) # Attach types definitions overrides_types = GetTypes(overrides, config) # Iterate over overridden fields and create valid parsers self._override_values = {} for field_path, field_type in zip(overrides, overrides_types): field_help = 'An override of {}\'s field {}'.format(self.name, field_path) field_name = '{}.{}'.format(self.name, field_path) if field_type in _FIELD_TYPE_TO_PARSER: parser = _ConfigFieldParser(_FIELD_TYPE_TO_PARSER[field_type], field_path, config, self._override_values) flags.DEFINE(
parser, field_name, GetValue(field_path, config), field_help, flag_values=self.flag_values, serializer=_FIELD_TYPE_TO_SERIALIZER[field_type]) flag = self.flag_values._flags().get(field_name) # pylint: disable=protected-access flag.boolean = field_type is bool else: raise UnsupportedOperationError( "Type {} of field {} is not supported for overriding. " "Currently supported types are: {}. (Note that tuples should " "be passed as a string on the command line: flag='(a, b, c)', " "rather than flag=(a, b, c).)".format( field_type, field_name, _FIELD_TYPE_TO_PARSER.keys())) self._config_filename = argument return config @property def config_filename(self): """Returns a path to a config file. Typical usage example: `script.py`: ```python ... from absl import flags from ml_collections.config_flags import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file( name='my_config', default='ml_collections/config_flags/tests/configdict_config.py', help_string='config file') ... FLAGS['my_config'].config_filename will output 'ml_collections/config_flags/tests/configdict_config.py' ``` Returns: A path to a config file. For a parameterised get_config, the config filename with the provided parameterisation is returned. Raises: UnparsedFlagError: if the flag has not been parsed. """ if not hasattr(self, '_config_filename'): raise UnparsedFlagError('The flag has not been parsed yet') return self._config_filename @property def override_values(self): """Returns a flat dictionary containing overridden values. Keys in the dictionary are dot-separated paths navigating to child items in the original configuration. 
For example, suppose that a `config` flag is defined and initialized to the following configuration: ```python { 'a': 1, 'nested': { 'b': 2 } } ``` and the user overrides both values using command-line flags: ``` --config.a=10 --config.nested.b=20 ``` Then `FLAGS['config'].override_values` will return: ```python { 'a': 10, 'nested.b': 20 } ``` The result can be passed to `ConfigDict.update_from_flattened_dict` to update the values in a configuration. Continuing with the example above: ```python import ml_collections config = ml_collections.ConfigDict({ 'a': 123, 'nested': { 'b': 456 } }) config.update_from_flattened_dict(FLAGS['config'].override_values) print(config.a) # Prints `10`. print(config.nested.b) # Prints `20`. ``` Returns: Flat dictionary with overridden values. Raises: UnparsedFlagError: if the flag has not been parsed. """ if not hasattr(self, '_override_values'): raise UnparsedFlagError('The flag has not been parsed yet') return self._override_values def is_config_flag(flag): # pylint: disable=g-bad-name """Returns True iff `flag` is an instance of `_ConfigFlag`. External users of the library may need to check if a flag is of this type or not, particularly because ConfigFlags should be parsed before any other flags. This function allows that test to be done without making the whole class public. Args: flag: Flag object. Returns: True iff `isinstance(flag, _ConfigFlag)` is true. """ return isinstance(flag, _ConfigFlag) class _ConfigFieldParser(flags.ArgumentParser): """Parser with config update after parsing. This class-based wrapper creates a new object, which uses existing parser to do actual parsing and attaches a single callback to SetValue afterwards, which is used to update a predefined path in the config object. """ def __init__( self, parser: flags.ArgumentParser, path: str, config: ml_collections.ConfigDict, override_values: MutableMapping[str, Any]): """Creates new parser with callback, using existing one to perform parsing.
Args: parser: ArgumentParser instance to wrap. path: Dot separated path in config to update with the result of parser.parse(...) config: Reference to the config object. override_values: Dictionary with override values. The 'parse' method will add the parsed value to this dictionary with key `path`. """ self._parser = parser self._path = path self._config = config self._override_values = override_values def __getattr__(self, attr): return getattr(self._parser, attr) def parse(self, argument): # pylint: disable=invalid-name value = self._parser.parse(argument) SetValue(self._path, self._config, value) self._override_values[self._path] = value return value def flag_type(self) -> str: return self._parser.flag_type() @property def syntactic_help(self) -> str: return self._parser.syntactic_help def _ExtractIndicesFromStep(step): """Separates identifier from indexes. Example: id[1][10] -> 'id', [1, 10].""" # Map returns an iterable in Python 3, thus the additional `list`. return step.split('[', 1)[0], list(map(int, re.findall(r'\[(\d+)\]', step))) def _AccessConfig(current, field, indices): """Returns member of the current config. The field in `current` specified by the `field` parameter is indexed using the values in the `indices` indices. Example: ```python current = {'first': [0], 'second': [[1]]} _AccessConfig(current, 'first', [0]) # Returns 0 _AccessConfig(current, 'second', [0, 0]) # Returns 1 ``` Args: current: Current config. field: Field to access (string). indices: List of indices. Returns: Member of the config. Raises: IndexError: if indices are invalid. KeyError: if the field is not found. 
""" if isinstance(field, string_types) and hasattr(current, field): current = getattr(current, field) elif hasattr(current, '__getitem__'): current = current[field] else: raise KeyError(field) for i, index in enumerate(indices): try: current = current[index] except TypeError: msg = 'Could not index config {}'.format(field) if i > 0: msg += '[{}]'.format(']['.join(map(str, indices[:i]))) raise IndexError(msg) except IndexError: msg = 'Index [{}] not found in config {}'.format( ']['.join(map(str, indices[i:])), field) if i > 0: msg += '[{}]'.format(']['.join(map(str, indices[:i]))) raise IndexError(msg) return current def _TakeStep(current, step): field, indices = _ExtractIndicesFromStep(step) return _AccessConfig(current, field, indices) def GetValue(path: str, config: ml_collections.ConfigDict): """Gets value of a single field.""" current = config for step in path.split('.'): current = _TakeStep(current, step) return current def GetType(path, config): """Gets type of field in config described by a dotted delimited path.""" steps = path.split('.') current = config for step in steps[:-1]: current = _TakeStep(current, step) # We need to check if there is list-type indexing in the last step. 
field, indices = _ExtractIndicesFromStep(steps[-1]) if indices: return type(_AccessConfig(current, field, indices)) else: # Check if current is a DM collection and hence has attribute get_type() is_dm_collection = isinstance(current, ml_collections.ConfigDict) or isinstance( current, ml_collections.FieldReference) if is_dm_collection: return current.get_type(steps[-1]) else: return type(_TakeStep(current, steps[-1])) def GetTypes(paths, config): """Gets types of fields in config described by dotted delimited paths.""" return [GetType(path, config) for path in paths] def SetValue(path: str, config: ml_collections.ConfigDict, value): """Sets value of a single field.""" current = config steps = path.split('.') if not steps or not path: raise ValueError('Path cannot be empty') for step in steps[:-1]: current = _TakeStep(current, step) field, indices = _ExtractIndicesFromStep(steps[-1]) if indices: current = _AccessConfig(current, field, indices[:-1]) try: current[indices[-1]] = value except: raise UnsupportedOperationError( 'Config does not have a setter for {}'.format(path)) else: if hasattr(current, '__setitem__') and field in current: current[field] = value elif hasattr(current, field): setattr(current, field, value) else: raise UnsupportedOperationError( 'Config does not have a setter for {}'.format(path)) ml_collections-0.1.1/ml_collections/config_flags/tuple_parser.py0000640000175000017500000000662514174507605024571 0ustar nileshnilesh# Copyright 2021 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Custom parser to override tuples in the config dict.""" import ast from absl import flags class TupleParser(flags.ArgumentParser): """Parser for tuple arguments. Custom flag parser for Tuple objects that is based on the existing parsers in `absl.flags`. This parser can be used to read in and override tuple arguments. It outputs a `tuple` object from an existing `tuple` or `str`. The only requirement is that the overriding parameter should be a `tuple` as well. The overriding parameter can have a different number of elements of different type than the original. For a detailed list of what `str` arguments are supported for overriding, look at `ast.literal_eval` from the Python Standard Library. """ def parse(self, argument): """Returns a `tuple` representing the input `argument`. Args: argument: The argument to be parsed. Valid types are `tuple` and `str`. An empty `tuple` is returned for arguments `NoneType`. Returns: A `TupleType` representing the input argument as a `tuple`. Raises: `TypeError`: If the argument is not of type `tuple`, `str`, or `NoneType`. `ValueError`: If the string is not a well formed `tuple`. """ if isinstance(argument, tuple): return argument elif isinstance(argument, str): return _convert_str_to_tuple(argument) elif argument is None: return () else: msg = ('Could not parse argument {} of type ' '{} for element of type `tuple`.' ).format(argument, type(argument)) raise TypeError(msg) def flag_type(self): return 'tuple' def _convert_str_to_tuple(string): """Function to convert a Python `str` object to a `tuple`. Args: string: The `str` to be converted. Returns: A `tuple` version of the string. Raises: ValueError: If the string is not a well formed `tuple`. """ # literal_eval converts strings to int, tuple, list, float and dict, # booleans and None. It can also handle nested tuples.
  # It does not, however, handle elements of type set.
  try:
    value = ast.literal_eval(string)
  except ValueError:
    # A ValueError is raised by literal_eval if the string is not well
    # formed. Catch it and print out a more readable statement.
    msg = 'Argument {} does not evaluate to a `tuple` object.'.format(string)
    raise ValueError(msg)
  except SyntaxError:
    # The only other error that may be raised is a `SyntaxError` because
    # `literal_eval` calls the Python in-built `compile`. This error is
    # caused by parsing issues.
    msg = 'Error while parsing string: {}'.format(string)
    raise ValueError(msg)

  # Make sure we got a tuple. If not, it's an error.
  if isinstance(value, tuple):
    return value
  else:
    raise ValueError('Expected a tuple argument, got {}'.format(type(value)))
ml_collections-0.1.1/ml_collections/config_flags/tests/0000750000175000017500000000000014174510450022632 5ustar nileshnilesh
ml_collections-0.1.1/ml_collections/config_flags/tests/typeerror_config.py0000640000175000017500000000223614174507605026600 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""Config file that raises TypeError on import.

When trying to load the configuration file as a flag, the flags library catches
TypeError exceptions, recasts them as an IllegalFlagValueError and rethrows
(b/63877430). The rethrow does not include the stack trace from the original
exception, so we manually add the stack trace in configflags.parse().
This is tested in `ConfigFileFlagTest.testTypeError` in
`config_overriding_test.py`.
"""


def type_error_function():
  raise TypeError('This is a TypeError.')


def get_config():
  return {'item': type_error_function()}
ml_collections-0.1.1/ml_collections/config_flags/tests/mock_config.py0000640000175000017500000000260714174507605025500 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""Dummy Config file."""

import copy


class TestConfig(object):
  """Just a dummy config."""

  def __init__(self):
    self.integer = 23
    self.float = 2.34
    self.string = 'james'
    self.bool = True
    self.dict = {
        'integer': 1,
        'float': 3.14,
        'string': 'mark',
        'bool': False,
        'dict': {
            'float': 5.
        },
        'list': [1, 2, [3]]
    }
    self.list = [1, 2, [3]]
    self.tuple = (1, 2, (3,))
    self.tuple_with_spaces = (1, 2, (3,))

  @property
  def readonly_field(self):
    return 42

  def __repr__(self):
    return str(self.__dict__)


def get_config():
  config = TestConfig()
  config.object = TestConfig()
  config.object_reference = config.object
  config.object_copy = copy.deepcopy(config.object)
  return config
ml_collections-0.1.1/ml_collections/config_flags/tests/config_overriding_test.py0000640000175000017500000010000714174507605027747 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""Tests for ml_collections.config_flags."""

import copy
import shlex
import sys

from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
import ml_collections
from ml_collections.config_flags import config_flags
from ml_collections.config_flags.tests import fieldreference_config
from ml_collections.config_flags.tests import mock_config
import six

_CHECK_TYPES = tuple(
    list(six.integer_types) + list(six.string_types) + [float, bool])

_TEST_DIRECTORY = 'ml_collections/config_flags/tests'
_TEST_CONFIG_FILE = '{}/mock_config.py'.format(_TEST_DIRECTORY)

# Parameters to test that config loading and overriding works with both
# one and two dashes.
_DASH_PARAMETERS = (
    ('WithTwoDashesAndEqual', '--test_config={}'.format(_TEST_CONFIG_FILE)),
    ('WithTwoDashes', '--test_config {}'.format(_TEST_CONFIG_FILE)),
    ('WithOneDashAndEqual', '-test_config={}'.format(_TEST_CONFIG_FILE)),
    ('WithOneDash', '-test_config {}'.format(_TEST_CONFIG_FILE)))

_CONFIGDICT_CONFIG_FILE = '{}/configdict_config.py'.format(_TEST_DIRECTORY)
_IOERROR_CONFIG_FILE = '{}/ioerror_config.py'.format(_TEST_DIRECTORY)
_VALUEERROR_CONFIG_FILE = '{}/valueerror_config.py'.format(_TEST_DIRECTORY)
_TYPEERROR_CONFIG_FILE = '{}/typeerror_config.py'.format(_TEST_DIRECTORY)
_FIELDREFERENCE_CONFIG_FILE = '{}/fieldreference_config.py'.format(
    _TEST_DIRECTORY)
_PARAMETERISED_CONFIG_FILE = '{}/parameterised_config.py'.format(
    _TEST_DIRECTORY)


def _parse_flags(command,
                 default=None,
                 config=None,
                 lock_config=True,
                 required=False):
  """Parses arguments simulating sys.argv."""

  if config is not None and default is not None:
    raise ValueError('If config is supplied a default should not be.')

  # Storing copy of the old sys.argv.
  old_argv = list(sys.argv)
  # Overwriting sys.argv, as sys has a global state it gets propagated.
  # The module shlex is useful here because it splits the input similar to
  # sys.argv. For instance, string arguments are not split by space.
  sys.argv = shlex.split(command)
  # Actual parsing.
  values = flags.FlagValues()
  if config is None:
    config_flags.DEFINE_config_file(
        'test_config',
        default=default,
        flag_values=values,
        lock_config=lock_config)
  else:
    config_flags.DEFINE_config_dict(
        'test_config',
        config=config,
        flag_values=values,
        lock_config=lock_config)

  if required:
    flags.mark_flag_as_required('test_config', flag_values=values)
  values(sys.argv)

  # Going back to original values.
  sys.argv = old_argv

  return values


def _get_override_flags(overrides, override_format):
  return ' '.join([override_format.format(path, value)
                   for path, value in six.iteritems(overrides)])


class _ConfigFlagTestCase(object):
  """Base class for tests with additional asserts for comparing configs."""

  def assert_subset_configs(self, config1, config2):
    """Checks if all attributes/values in config1 are present in config2."""

    if config1 is None:
      return

    if hasattr(config1, '__dict__'):
      keys = [key for key in config1.__dict__ if not key.startswith('_')]
      get_val = getattr
    elif hasattr(config1, 'keys') and callable(config1.keys):
      keys = [key for key in config1.keys() if not key.startswith('_')]
      get_val = lambda haystack, needle: haystack[needle]
    else:
      # This should not fail as it simply means we cannot iterate deeper.
      return

    for attribute in keys:
      if isinstance(get_val(config1, attribute), _CHECK_TYPES):
        self.assertEqual(get_val(config1, attribute),
                         get_val(config2, attribute))
      else:
        # Try to go deeper with comparison
        self.assert_subset_configs(get_val(config1, attribute),
                                   get_val(config2, attribute))

  def assert_equal_configs(self, config1, config2):
    """Checks if two configs are identical."""
    self.assert_subset_configs(config1, config2)
    self.assert_subset_configs(config2, config1)  # pylint: disable=arguments-out-of-order


class ConfigFileFlagTest(_ConfigFlagTestCase, parameterized.TestCase):
  """Tests config flags library."""

  @parameterized.named_parameters(*_DASH_PARAMETERS)
  def testLoading(self, config_flag):
    """Tests loading config from file."""

    values = _parse_flags('./program {}'.format(config_flag),
                          default='nonexisting.py')

    self.assertIn('test_config', values)

    self.assert_equal_configs(values.test_config,
                              mock_config.get_config())

  @parameterized.named_parameters(*_DASH_PARAMETERS)
  def testRequired(self, config_flag):
    """Tests making a config_file flag required."""
    with self.assertRaises(flags.IllegalFlagValueError):
      _parse_flags('./program ', required=True)
    values = _parse_flags('./program {}'.format(config_flag), required=True)
    self.assertIn('test_config', values)

  def testDefaultLoading(self):
    """Tests loading config from file using default path."""

    for required in [True, False]:
      values = _parse_flags(
          './program', default=_TEST_CONFIG_FILE, required=required)

      self.assertIn('test_config', values)

      self.assert_equal_configs(values.test_config, mock_config.get_config())

  def testLoadingNonExistingConfigLoading(self):
    """Tests whether loading non existing file raises an error."""

    nonexisting_file = 'nonexisting.py'

    # Test whether loading non existing files raises an Error with both
    # file loading formats i.e. with '--' and '-'.
    # The Error is not expected to be raised until the config dict actually
    # has one of its attributes accessed.
    values = _parse_flags(
        './program --test_config={}'.format(nonexisting_file),
        default=_TEST_CONFIG_FILE)
    with self.assertRaisesRegex(IOError, '.*{}.*'.format(nonexisting_file)):
      _ = values.test_config.a

    values = _parse_flags(
        './program -test_config {}'.format(nonexisting_file),
        default=_TEST_CONFIG_FILE)
    with self.assertRaisesRegex(IOError, '.*{}.*'.format(nonexisting_file)):
      _ = values.test_config.a

    values = _parse_flags(
        './program -test_config ""', default=_TEST_CONFIG_FILE)
    with self.assertRaisesRegex(IOError, 'empty string'):
      _ = values.test_config.a

  def testIOError(self):
    """Tests that IOErrors raised inside config files are reported correctly."""
    values = _parse_flags('./program --test_config={}'
                          .format(_IOERROR_CONFIG_FILE))
    with self.assertRaisesRegex(IOError, 'This is an IOError'):
      _ = values.test_config.a

  def testValueError(self):
    """Tests that ValueErrors raised when parsing config files are passed up."""
    allchars_regexp = r'[\s\S]*'
    error_regexp = allchars_regexp.join(['Error whilst parsing config file',
                                         'in get_config',
                                         'in value_error_function',
                                         'This is a ValueError'])
    with self.assertRaisesRegex(flags.IllegalFlagValueError, error_regexp):
      _ = _parse_flags('./program --test_config={}'
                       .format(_VALUEERROR_CONFIG_FILE))

  def testTypeError(self):
    """Tests that TypeErrors raised when parsing config files are passed up."""
    allchars_regexp = r'[\s\S]*'
    error_regexp = allchars_regexp.join(['Error whilst parsing config file',
                                         'in get_config',
                                         'in type_error_function',
                                         'This is a TypeError'])
    with self.assertRaisesRegex(flags.IllegalFlagValueError, error_regexp):
      _ = _parse_flags('./program --test_config={}'
                       .format(_TYPEERROR_CONFIG_FILE))

  # Note: While testing the overriding of parameters, we explicitly set
  # '!r' in the format string for the value. This ensures the 'repr()' is
  # called on the argument which basically means that for string arguments,
  # the quotes (' ') are left intact when we format the string.
  @parameterized.named_parameters(
      ('TwoDashConfigAndOverride',
       '--test_config={}'.format(_TEST_CONFIG_FILE), '--test_config.{}={!r}'),
      ('TwoDashSpaceConfigAndOverride',
       '--test_config {}'.format(_TEST_CONFIG_FILE), '--test_config.{} {!r}'),
      ('OneDashConfigAndOverride',
       '-test_config {}'.format(_TEST_CONFIG_FILE), '-test_config.{} {!r}'),
      ('OneDashEqualConfigAndOverride',
       '-test_config={}'.format(_TEST_CONFIG_FILE), '-test_config.{}={!r}'),
      ('OneDashConfigAndTwoDashOverride',
       '-test_config {}'.format(_TEST_CONFIG_FILE), '--test_config.{}={!r}'),
      ('TwoDashConfigAndOneDashOverride',
       '--test_config={}'.format(_TEST_CONFIG_FILE), '-test_config.{} {!r}'))
  def testOverride(self, config_flag, override_format):
    """Tests overriding config values from command line."""
    overrides = {
        'integer': 1,
        'float': -3,
        'dict.float': 3,
        'object.integer': 12,
        'object.float': 123,
        'object.string': 'tom',
        'object.dict.integer': -2,
        'object.dict.float': 3.15,
        'object.dict.list[0]': 101,
        'object.dict.list[2][0]': 103,
        'object.list[0]': 101,
        'object.list[2][0]': 103,
        'object.tuple': '(1,2,(1,2))',
        'object.tuple_with_spaces': '(1, 2, (1, 2))',
        'object_reference.dict.string': 'marry',
        'object.dict.dict.float': 123,
        'object_copy.float': 111.111
    }

    override_flags = _get_override_flags(overrides, override_format)
    values = _parse_flags('./program {} {}'.format(config_flag, override_flags))

    test_config = mock_config.get_config()
    test_config.integer = overrides['integer']
    test_config.float = overrides['float']
    test_config.dict['float'] = overrides['dict.float']
    test_config.object.integer = overrides['object.integer']
    test_config.object.float = overrides['object.float']
    test_config.object.string = overrides['object.string']
    test_config.object.dict['integer'] = overrides['object.dict.integer']
    test_config.object.dict['float'] = overrides['object.dict.float']
    test_config.object.dict['list'][0] = overrides['object.dict.list[0]']
    test_config.object.dict['list'][2][0] = overrides['object.dict.list[2][0]']
    test_config.object.list[0] = overrides['object.list[0]']
    test_config.object.list[2][0] = overrides['object.list[2][0]']
    test_config.object.tuple = (1, 2, (1, 2))
    test_config.object.tuple_with_spaces = (1, 2, (1, 2))
    test_config.object_reference.dict['string'] = overrides[
        'object_reference.dict.string']
    test_config.object.dict['dict']['float'] = overrides[
        'object.dict.dict.float']
    test_config.object_copy.float = overrides['object_copy.float']

    self.assert_equal_configs(values.test_config, test_config)

  def testOverrideBoolean(self):
    """Tests overriding boolean config values from command line."""
    prefix = './program --test_config={}'.format(_TEST_CONFIG_FILE)

    # The default for dict.bool is False.
    values = _parse_flags('{} --test_config.dict.bool'.format(prefix))
    self.assertTrue(values.test_config.dict['bool'])

    values = _parse_flags('{} --test_config.dict.bool=true'.format(prefix))
    self.assertTrue(values.test_config.dict['bool'])

    # The default for object.bool is True.
    values = _parse_flags('{} --test_config.object.bool=false'.format(prefix))
    self.assertFalse(values.test_config.object.bool)

    values = _parse_flags('{} --notest_config.object.bool'.format(prefix))
    self.assertFalse(values.test_config.object.bool)

  def testFieldReferenceOverride(self):
    """Tests whether types of FieldReference fields are valid."""
    overrides = {'ref_nodefault': 1, 'ref': 2}
    override_flags = _get_override_flags(overrides, '--test_config.{}={!r}')
    config_flag = '--test_config={}'.format(_FIELDREFERENCE_CONFIG_FILE)
    values = _parse_flags('./program {} {}'.format(config_flag, override_flags))
    cfg = values.test_config

    self.assertEqual(cfg.ref_nodefault, overrides['ref_nodefault'])
    self.assertEqual(cfg.ref, overrides['ref'])

  @parameterized.named_parameters(*_DASH_PARAMETERS)
  def testSetNotExistingKey(self, config_flag):
    """Tests setting value of not existing key."""

    with self.assertRaises(KeyError):
      _parse_flags('./program {} '
                   '--test_config.not_existing_key=1 '.format(config_flag))

  @parameterized.named_parameters(*_DASH_PARAMETERS)
  def testSetReadOnlyField(self, config_flag):
    """Tests setting value of key which is read only."""

    with self.assertRaises(AttributeError):
      _parse_flags('./program {} '
                   '--test_config.readonly_field=1 '.format(config_flag))

  @parameterized.named_parameters(*_DASH_PARAMETERS)
  def testNotSupportedOperation(self, config_flag):
    """Tests setting value to not supported type."""

    with self.assertRaises(config_flags.UnsupportedOperationError):
      _parse_flags('./program {} '
                   '--test_config.list=[1]'.format(config_flag))

  def testReadingNonExistingKey(self):
    """Tests that setting a non existing key in a dict is not supported."""

    test_config = mock_config.get_config()

    with self.assertRaises(config_flags.UnsupportedOperationError):
      config_flags.SetValue('dict.not_existing_key', test_config, 1)

  def testReadingSettingExistingKeyInDict(self):
    """Tests setting a value below a non existing key of a dict in config."""

    test_config = mock_config.get_config()

    with self.assertRaises(KeyError):
      config_flags.SetValue('dict.not_existing_key.key', test_config, 1)

  def testEmptyKey(self):
    """Tests calling an empty key update."""

    test_config = mock_config.get_config()

    with self.assertRaises(ValueError):
      config_flags.SetValue('', test_config, None)

  def testListExtraIndex(self):
    """Tries to index a non-indexable list element."""

    test_config = mock_config.get_config()

    with self.assertRaises(IndexError):
      config_flags.GetValue('dict.list[0][0]', test_config)

  def testListOutOfRangeGet(self):
    """Tries to access out-of-range value in list."""

    test_config = mock_config.get_config()

    with self.assertRaises(IndexError):
      config_flags.GetValue('dict.list[2][1]', test_config)

  def testListOutOfRangeSet(self):
    """Tries to override out-of-range value in list."""

    test_config = mock_config.get_config()

    with self.assertRaises(config_flags.UnsupportedOperationError):
      config_flags.SetValue('dict.list[2][1]', test_config, -1)

  def testParserWrapping(self):
    """Tests callback based Parser wrapping."""

    parser = flags.IntegerParser()

    test_config = mock_config.get_config()
    overrides = {}

    wrapped_parser = config_flags._ConfigFieldParser(parser, 'integer',
                                                     test_config, overrides)

    wrapped_parser.parse('12321')
    self.assertEqual(test_config.integer, 12321)
    self.assertEqual(overrides, {'integer': 12321})

    self.assertEqual(wrapped_parser.flag_type(), parser.flag_type())
    self.assertEqual(wrapped_parser.syntactic_help, parser.syntactic_help)

    self.assertEqual(wrapped_parser.convert('3'), parser.convert('3'))

  def testTypes(self):
    """Tests whether various types of objects are valid."""

    parser = config_flags._ConfigFileParser('test_config')
    self.assertEqual(parser.flag_type(), 'config object')

    test_config = mock_config.get_config()

    paths = (
        'float',
        'integer',
        'string',
        'bool',
        'dict',
        'dict.float',
        'dict.list',
        'list',
        'list[0]',
        'object.float',
        'object.integer',
        'object.string',
        'object.bool',
        'object.dict',
        'object.dict.float',
        'object.dict.list',
        'object.list',
        'object.list[0]',
        'object.tuple',
        'object_reference.float',
        'object_reference.integer',
        'object_reference.string',
        'object_reference.bool',
        'object_reference.dict',
        'object_reference.dict.float',
        'object_copy.float',
        'object_copy.integer',
        'object_copy.string',
        'object_copy.bool',
        'object_copy.dict',
        'object_copy.dict.float'
    )

    paths_types = [
        float,
        int,
        str,
        bool,
        dict,
        float,
        list,
        list,
        int,
        float,
        int,
        str,
        bool,
        dict,
        float,
        list,
        list,
        int,
        tuple,
        float,
        int,
        str,
        bool,
        dict,
        float,
        float,
        int,
        str,
        bool,
        dict,
        float,
    ]

    config_types = config_flags.GetTypes(paths, test_config)
    self.assertEqual(paths_types, config_types)

  def testFieldReferenceTypes(self):
    """Tests whether types of FieldReference fields are valid."""
    test_config = fieldreference_config.get_config()

    paths = ['ref_nodefault', 'ref']
    paths_types = [int, int]

    config_types = config_flags.GetTypes(paths, test_config)
    self.assertEqual(paths_types, config_types)

  @parameterized.named_parameters(
      ('WithTwoDashesAndEqual', '--test_config=config.py'),
      ('WithTwoDashes', '--test_config'),
      ('WithOneDashAndEqual', '-test_config=config.py'),
      ('WithOneDash', '-test_config'))
  def testConfigSpecified(self, config_argument):
    """Tests whether config is specified on the command line."""

    config_flag = config_flags._ConfigFlag(
        parser=flags.ArgumentParser(),
        serializer=None,
        name='test_config',
        default='defaultconfig.py',
        help_string=''
    )
    self.assertTrue(config_flag._IsConfigSpecified([config_argument]))
    self.assertFalse(config_flag._IsConfigSpecified(['']))

  def testFindConfigSpecified(self):
    """Tests whether config is specified on the command line."""

    config_flag = config_flags._ConfigFlag(
        parser=flags.ArgumentParser(),
        serializer=None,
        name='test_config',
        default='defaultconfig.py',
        help_string=''
    )
    self.assertEqual(config_flag._FindConfigSpecified(['']), -1)

    argv_length = 20
    for i in range(argv_length):
      # Generate list of '--test_config.i=0' args.
      argv = ['--test_config.{}=0'.format(arg) for arg in range(argv_length)]
      self.assertEqual(config_flag._FindConfigSpecified(argv), -1)

      # Override i-th arg with something specifying the value of 'test_config'.
      # After doing this, _FindConfigSpecified should return the value of i.
      argv[i] = '--test_config'
      self.assertEqual(config_flag._FindConfigSpecified(argv), i)
      argv[i] = '--test_config=config.py'
      self.assertEqual(config_flag._FindConfigSpecified(argv), i)
      argv[i] = '-test_config'
      self.assertEqual(config_flag._FindConfigSpecified(argv), i)
      argv[i] = '-test_config=config.py'
      self.assertEqual(config_flag._FindConfigSpecified(argv), i)

  def testLoadingLockedConfigDict(self):
    """Tests loading ConfigDict instance and that it is locked."""

    config_flag = '--test_config={}'.format(_CONFIGDICT_CONFIG_FILE)
    values = _parse_flags('./program {}'.format(config_flag),
                          lock_config=True)

    self.assertTrue(values.test_config.is_locked)
    self.assertTrue(values.test_config.nested_configdict.is_locked)

    values = _parse_flags('./program {}'.format(config_flag),
                          lock_config=False)

    self.assertFalse(values.test_config.is_locked)
    self.assertFalse(values.test_config.nested_configdict.is_locked)

  @parameterized.named_parameters(
      ('WithTwoDashesAndEqual', '--test_config={}'.format(_TEST_DIRECTORY)),
      ('WithTwoDashes', '--test_config {}'.format(_TEST_DIRECTORY)),
      ('WithOneDashAndEqual', '-test_config={}'.format(_TEST_DIRECTORY)),
      ('WithOneDash', '-test_config {}'.format(_TEST_DIRECTORY)))
  def testPriorityOfFieldLookup(self, config_flag):
    """Tests whether attributes have higher priority than key-based lookup."""

    values = _parse_flags('./program {}/mini_config.py'.format(config_flag),
                          lock_config=False)
    self.assertTrue(
        config_flags.GetValue('entry_with_collision', values.test_config))

  @parameterized.named_parameters(
      ('TypeAEnabled', ':type_a', {'thing_a': 23, 'thing_b': 42}, ['thing_c']),
      ('TypeASpecify', ':type_a --test_config.thing_a=24', {'thing_a': 24}, []),
      ('TypeBEnabled', ':type_b', {'thing_a': 19, 'thing_c': 65}, ['thing_b']))
  def testExtraConfigString(self, flag_override, should_exist, should_error):
    """Tests that the config_string argument is used properly."""
    values = _parse_flags(
        './program --test_config={}/parameterised_config.py{}'.format(
            _TEST_DIRECTORY, flag_override))
    # Ensure that the values exist in the ConfigDict, with desired values.
    for subfield_name, expected_value in six.iteritems(should_exist):
      self.assertEqual(values.test_config[subfield_name], expected_value)

    # Ensure the values which should not be part of the ConfigDict are really
    # not there.
    for should_error_name in should_error:
      with self.assertRaisesRegex(KeyError, 'Did you mean'):
        _ = values.test_config[should_error_name]

  def testExtraConfigInvalidFlag(self):
    with self.assertRaisesRegex(AttributeError, 'not_valid_item'):
      _parse_flags(
          ('./program --test_config={}/parameterised_config.py:type_a '
           '--test_config.not_valid_item=42').format(_TEST_DIRECTORY))

  def testOverridingConfigDict(self):
    """Tests overriding of ConfigDict fields."""

    config_flag = '--test_config={}'.format(_CONFIGDICT_CONFIG_FILE)
    overrides = {
        'integer': 2,
        'reference': 2,
        'list[0]': 5,
        'nested_list[0][0]': 5,
        'nested_configdict.integer': 5,
        'unusable_config.dummy_attribute': 5
    }
    override_flags = _get_override_flags(overrides, '--test_config.{}={}')
    values = _parse_flags('./program {} {}'.format(config_flag, override_flags))
    self.assertEqual(values.test_config.integer, overrides['integer'])
    self.assertEqual(values.test_config.reference, overrides['reference'])
    self.assertEqual(values.test_config.list[0], overrides['list[0]'])
    self.assertEqual(values.test_config.nested_list[0][0],
                     overrides['nested_list[0][0]'])
    self.assertEqual(values.test_config.nested_configdict.integer,
                     overrides['nested_configdict.integer'])
    self.assertEqual(values.test_config.unusable_config.dummy_attribute,
                     overrides['unusable_config.dummy_attribute'])
    # Attribute error.
    overrides = {'nonexistent': 'value'}
    with self.assertRaises(AttributeError):
      override_flags = _get_override_flags(overrides, '--test_config.{}={}')
      _parse_flags('./program {} {}'.format(config_flag, override_flags))

    # "Did you mean" messages.
    overrides = {'integre': 2}
    with self.assertRaisesRegex(AttributeError, 'Did you.*integer.*'):
      override_flags = _get_override_flags(overrides, '--test_config.{}={}')
      _parse_flags('./program {} {}'.format(config_flag, override_flags))

    overrides = {'referecne': 2}
    with self.assertRaisesRegex(AttributeError, 'Did you.*reference.*'):
      override_flags = _get_override_flags(overrides, '--test_config.{}={}')
      _parse_flags('./program {} {}'.format(config_flag, override_flags))

  # This test adds new flags, so use FlagSaver to make it hermetic.
  @flagsaver.flagsaver
  def testIsConfigFile(self):
    config_flags.DEFINE_config_file('is_a_config_flag')
    flags.DEFINE_integer('not_a_config_flag', -1, '')

    self.assertTrue(
        config_flags.is_config_flag(flags.FLAGS['is_a_config_flag']))
    self.assertFalse(
        config_flags.is_config_flag(flags.FLAGS['not_a_config_flag']))

  # This test adds new flags, so use FlagSaver to make it hermetic.
  @flagsaver.flagsaver
  def testModuleName(self):
    config_flags.DEFINE_config_file('flag')
    argv_0 = './program'
    _parse_flags(argv_0)
    self.assertIn(flags.FLAGS['flag'],
                  flags.FLAGS.flags_by_module_dict()[argv_0])

  def testFlagOrder(self):
    with self.assertRaisesWithLiteralMatch(
        config_flags.FlagOrderError,
        ('Found --test_config.int=1 in argv before a value for --test_config '
         'was specified')):
      _parse_flags('./program --test_config.int=1 '
                   '--test_config={}'.format(_TEST_CONFIG_FILE))

  @flagsaver.flagsaver
  def testOverrideValues(self):
    config_flags.DEFINE_config_file('config')
    with self.assertRaisesWithLiteralMatch(config_flags.UnparsedFlagError,
                                           'The flag has not been parsed yet'):
      flags.FLAGS['config'].override_values  # pylint: disable=pointless-statement

    original_float = -1.0
    original_dictfloat = -2.0
    config = ml_collections.ConfigDict({
        'integer': -1,
        'float': original_float,
        'dict': {
            'float': original_dictfloat
        }
    })
    integer_override = 0
    dictfloat_override = 1.1
    values = _parse_flags('./program --test_config={} --test_config.integer={} '
                          '--test_config.dict.float={}'.format(
                              _TEST_CONFIG_FILE, integer_override,
                              dictfloat_override))

    config.update_from_flattened_dict(
        config_flags.get_override_values(values['test_config']))
    self.assertEqual(config['integer'], integer_override)
    self.assertEqual(config['float'], original_float)
    self.assertEqual(config['dict']['float'], dictfloat_override)

  @parameterized.named_parameters(
      ('ConfigFile1', _TEST_CONFIG_FILE),
      ('ConfigFile2', _CONFIGDICT_CONFIG_FILE),
      ('ParameterisedConfigFile', _PARAMETERISED_CONFIG_FILE + ':type_a'),
  )
  def testConfigPath(self, config_file):
    """Test access to saved config file path."""
    values = _parse_flags('./program --test_config={}'.format(config_file))
    self.assertEqual(config_flags.get_config_filename(values['test_config']),
                     config_file)


def _simple_config():
  config = ml_collections.ConfigDict()
  config.foo = 3
  return config


class ConfigDictFlagTest(_ConfigFlagTestCase, absltest.TestCase):
  """Tests DEFINE_config_dict.

  DEFINE_config_dict reuses a lot of code in DEFINE_config_file so the tests
  here are mostly sanity checks.
  """

  def testBasicUsage(self):
    values = _parse_flags('./program', config=_simple_config())
    self.assertIn('test_config', values)
    self.assert_equal_configs(_simple_config(), values.test_config)

  def testChangingLockedConfigRaisesAnError(self):
    values = _parse_flags(
        './program', config=_simple_config(), lock_config=True)
    with self.assertRaisesRegex(AttributeError, 'config is locked'):
      values.test_config.new_foo = 20

  def testChangingUnlockedConfig(self):
    values = _parse_flags(
        './program', config=_simple_config(), lock_config=False)
    values.test_config.new_foo = 20

  def testNonConfigDictAsConfig(self):
    non_config_dict = dict(a=1, b=2)
    with self.assertRaisesRegex(TypeError, 'should be a ConfigDict'):
      _parse_flags('./program', config=non_config_dict)

  def testOverridingAttribute(self):
    new_foo = 10
    values = _parse_flags(
        './program --test_config.foo={}'.format(new_foo),
        config=_simple_config())
    self.assertNotEqual(new_foo, _simple_config().foo)
    self.assertEqual(new_foo, values.test_config.foo)

  def testOverridingMainConfigFlagRaisesAnError(self):
    with self.assertRaisesRegex(flags.IllegalFlagValueError,
                                'Overriding test_config is not allowed'):
      _parse_flags('./program --test_config=bad_input',
                   config=_simple_config())

  def testOverridesSerialize(self):
    all_types_config = ml_collections.ConfigDict()
    all_types_config.type_bool = False
    all_types_config.type_bytes = b'bytes'
    all_types_config.type_float = 1.0
    all_types_config.type_int = 1
    all_types_config.type_str = 'str'
    all_types_config.type_ustr = u'ustr'
    all_types_config.type_tuple = (False, b'bbytes', 1.0, 1, 'str', u'ustr',)
    # Change values via the command line:
    command_line = (
        './program'
        ' --test_config.type_bool=True'
        ' --test_config.type_float=10'
        ' --test_config.type_int=10'
        ' --test_config.type_str=str_commandline'
        ' --test_config.type_tuple="(\'tuple_str\', 10)"'
    )
    if six.PY3:  # Not supported in py2; type_bytes is never supported.
      command_line += ' --test_config.type_ustr=ustr_commandline'
    values = _parse_flags(command_line, config=copy.copy(all_types_config))

    # Check we get the expected values (ie the ones defined above not the ones
    # defined in the config itself)
    get_parser = lambda name: values._flags()[name].parser.parse
    get_serializer = lambda name: values._flags()[name].serializer.serialize
    serialize_parse = (
        lambda name, value: get_parser(name)(get_serializer(name)(value)))

    with self.subTest('bool'):
      self.assertNotEqual(values.test_config.type_bool,
                          all_types_config['type_bool'])
      self.assertEqual(values.test_config.type_bool, True)
      self.assertEqual(values.test_config.type_bool,
                       serialize_parse('test_config.type_bool',
                                       values.test_config.type_bool))

    with self.subTest('float'):
      self.assertNotEqual(values.test_config.type_float,
                          all_types_config['type_float'])
      self.assertEqual(values.test_config.type_float, 10.)
      self.assertEqual(values.test_config.type_float,
                       serialize_parse('test_config.type_float',
                                       values.test_config.type_float))

    with self.subTest('int'):
      self.assertNotEqual(values.test_config.type_int,
                          all_types_config['type_int'])
      self.assertEqual(values.test_config.type_int, 10)
      self.assertEqual(values.test_config.type_int,
                       serialize_parse('test_config.type_int',
                                       values.test_config.type_int))

    with self.subTest('str'):
      self.assertNotEqual(values.test_config.type_str,
                          all_types_config['type_str'])
      self.assertEqual(values.test_config.type_str, 'str_commandline')
      self.assertEqual(values.test_config.type_str,
                       serialize_parse('test_config.type_str',
                                       values.test_config.type_str))

    with self.subTest('ustr'):
      if six.PY3:
        self.assertNotEqual(values.test_config.type_ustr,
                            all_types_config['type_ustr'])
        self.assertEqual(values.test_config.type_ustr, u'ustr_commandline')
        self.assertEqual(values.test_config.type_ustr,
                         serialize_parse('test_config.type_ustr',
                                         values.test_config.type_ustr))

    with self.subTest('tuple'):
      self.assertNotEqual(values.test_config.type_tuple,
                          all_types_config['type_tuple'])
      self.assertEqual(values.test_config.type_tuple, ('tuple_str', 10))
      self.assertEqual(values.test_config.type_tuple,
                       serialize_parse('test_config.type_tuple',
                                       values.test_config.type_tuple))


def main():
  absltest.main()


if __name__ == '__main__':
  main()
ml_collections-0.1.1/ml_collections/config_flags/tests/mini_config.py0000640000175000017500000000205314174507605025476 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""Dummy Config file."""


class MiniConfig(object):
  """Just a dummy config."""

  def __init__(self):
    self.dict = {}
    self.field = False

  def __getitem__(self, key):
    return self.dict[key]

  def __contains__(self, key):
    return key in self.dict

  def __setitem__(self, key, value):
    self.dict[key] = value


def get_config():
  cfg = MiniConfig()
  cfg['entry_with_collision'] = False
  cfg.entry_with_collision = True
  return cfg
ml_collections-0.1.1/ml_collections/config_flags/tests/valueerror_config.py0000640000175000017500000000224614174507605026734 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""Config file that raises ValueError on import.

When loading the configuration file as a flag, the flags library catches
ValueError exceptions, recasts them as an IllegalFlagValueError and rethrows
(b/63877430). The rethrow does not include the stack trace from the original
exception, so we manually add the stack trace in configflags.parse().

This is tested in `ConfigFlagTest.testValueError` in
`config_overriding_test.py`.
"""


def value_error_function():
  raise ValueError('This is a ValueError.')


def get_config():
  return {'item': value_error_function()}
ml_collections-0.1.1/ml_collections/config_flags/tests/configdict_config.py0000640000175000017500000000635014174507605026657 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""ConfigDict config file."""

import ml_collections


class UnusableConfig(object):
  """Test against code assuming the semantics of attributes (such as `lock`).
  This class is to test that the flags implementation does not assume the
  semantics of attributes. This is to avoid code such as:

  ```python
  if hasattr(obj, 'lock'):
    obj.lock()
  ```

  which will fail if `obj` has an attribute `lock` that does not behave in the
  way we expect.

  This class only has unusable attributes. There are two
  exceptions for which this class behaves normally:
  * Python's special functions which start and end with a double underscore.
  * `dummy_attribute`, an attribute used to test the class.

  For other attributes, both `hasattr(obj, attr)` and `callable(obj.attr)` will
  return True. Accessing `obj.attr` will return a function which takes no
  arguments and raises an AttributeError when called. For example, the `lock`
  example above will raise an AttributeError. The only valid action on
  attributes is assignment, e.g.

  ```python
  obj = UnusableConfig()
  obj.attr = 1
  ```

  In which case the attribute will keep its assigned value and become usable.
  """

  def __init__(self):
    self._dummy_attribute = 1

  def __getattr__(self, attribute):
    """Get an arbitrary attribute.

    Returns a function which takes no arguments and raises an AttributeError,
    except for Python special functions in which case an AttributeError is
    directly raised.

    Args:
      attribute: A string representing the attribute's name.
    Returns:
      A function which raises an AttributeError when called.
    Raises:
      AttributeError: when the attribute is a Python special function starting
          and ending with a double underscore.
    """
    if attribute.startswith("__") and attribute.endswith("__"):
      raise AttributeError("UnusableConfig does not contain entry {}.".
                           format(attribute))

    def raise_attribute_error_fun():
      raise AttributeError(
          "{} is not a usable attribute of UnusableConfig".format(
              attribute))

    return raise_attribute_error_fun

  @property
  def dummy_attribute(self):
    return self._dummy_attribute

  @dummy_attribute.setter
  def dummy_attribute(self, value):
    self._dummy_attribute = value


def get_config():
  """Returns a ConfigDict.
  Used for tests."""
  cfg = ml_collections.ConfigDict()
  cfg.integer = 1
  cfg.reference = ml_collections.FieldReference(1)
  cfg.list = [1, 2, 3]
  cfg.nested_list = [[1, 2, 3]]
  cfg.nested_configdict = ml_collections.ConfigDict()
  cfg.nested_configdict.integer = 1
  cfg.unusable_config = UnusableConfig()
  return cfg
ml_collections-0.1.1/ml_collections/config_flags/tests/ioerror_config.py0000640000175000017500000000212214174507605026220 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""Config file that raises IOError on import.

The flags library tries to load configuration files in a few different ways.
For this it relies on catching IOError exceptions of the type "File not
found" and ignoring them to continue trying with a different loading method.
But we need to ensure that other types of IOError exceptions are propagated
correctly (b/63165566).

This is tested in `ConfigFlagTest.testIOError` in `config_overriding_test.py`.
"""

raise IOError('This is an IOError.')
ml_collections-0.1.1/ml_collections/config_flags/tests/parameterised_config.py0000640000175000017500000000204514174507605027370 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""Config file where `get_config` takes a string argument."""

import ml_collections


def get_config(config_string):
  """A config which takes an extra string argument."""
  possible_configs = {
      'type_a': ml_collections.ConfigDict({
          'thing_a': 23,
          'thing_b': 42,
      }),
      'type_b': ml_collections.ConfigDict({
          'thing_a': 19,
          'thing_c': 65,
      }),
  }
  return possible_configs[config_string]
ml_collections-0.1.1/ml_collections/config_flags/tests/dataclass_overriding_test.py0000640000175000017500000000434614174507605030452 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for config_flags used in conjunction with DEFINE_config_dataclass."""

import shlex
import sys
from typing import Mapping, Optional, Sequence

from absl import flags
from absl.testing import absltest
import dataclasses
from ml_collections import config_flags


#####
# Simple dummy configuration classes.
@dataclasses.dataclass
class MyModelConfig:
  foo: int
  bar: Sequence[str]
  baz: Optional[Mapping[str, str]] = None


@dataclasses.dataclass
class MyConfig:
  my_model: MyModelConfig
  baseline_model: MyModelConfig


_CONFIG = MyConfig(
    my_model=MyModelConfig(
        foo=3,
        bar=['a', 'b'],
        baz={'foo': 'bar'},
    ),
    baseline_model=MyModelConfig(
        foo=55,
        bar=['c', 'd'],
    ),
)

# Define the flag.
_CONFIG_FLAG = config_flags.DEFINE_config_dataclass('config', _CONFIG)


class TypedConfigFlagsTest(absltest.TestCase):

  def test_instance(self):
    config = _CONFIG_FLAG.value
    self.assertIsInstance(config, MyConfig)
    self.assertEqual(config.my_model, _CONFIG.my_model)
    self.assertEqual(_CONFIG, config)

  def test_flag_overrides(self):
    # Set up some flag overrides.
    old_argv = list(sys.argv)
    sys.argv = shlex.split(
        './program foo.py --test_config.baseline_model.foo=99')
    flag_values = flags.FlagValues()

    # Define a config dataclass flag.
    test_config = config_flags.DEFINE_config_dataclass(
        'test_config', _CONFIG, flag_values=flag_values)

    # Inject the flag overrides.
    flag_values(sys.argv)
    sys.argv = old_argv

    # Did the value get overridden?
    self.assertEqual(test_config.value.baseline_model.foo, 99)


if __name__ == '__main__':
  absltest.main()
ml_collections-0.1.1/ml_collections/config_flags/tests/fieldreference_config.py0000640000175000017500000000157614174507605027515 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python 3
"""Config file with field references."""

import ml_collections
from ml_collections.config_dict import config_dict


def get_config():
  cfg = ml_collections.ConfigDict()
  cfg.ref = ml_collections.FieldReference(123)
  cfg.ref_nodefault = config_dict.placeholder(int)
  return cfg
ml_collections-0.1.1/ml_collections/config_flags/__init__.py0000640000175000017500000000200014174507605023602 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""Config flags module."""

from .config_flags import DEFINE_config_dataclass
from .config_flags import DEFINE_config_dict
from .config_flags import DEFINE_config_file
from .config_flags import get_config_filename
from .config_flags import get_override_values

__all__ = (
    "DEFINE_config_dataclass",
    "DEFINE_config_dict",
    "DEFINE_config_file",
    "get_config_filename",
    "get_override_values",
)
ml_collections-0.1.1/ml_collections/config_flags/examples/0000750000175000017500000000000014174510450023306 5ustar nileshnilesh
ml_collections-0.1.1/ml_collections/config_flags/examples/define_config_dataclass_basic.py0000640000175000017500000000247014174507605031633 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
r"""Example of basic DEFINE_config_dataclass usage.

To run this example:
python define_config_dataclass_basic.py -- --my_config.field1=8 \
  --my_config.nested.field=2.1 --my_config.tuple='(1, 2, (1, 2))'
"""

from typing import Any, Mapping, Sequence

from absl import app
import dataclasses
from ml_collections.config_flags import config_flags


@dataclasses.dataclass
class MyConfig:
  field1: int
  field2: str
  nested: Mapping[str, Any]
  tuple: Sequence[int]


config = MyConfig(
    field1=1,
    field2='tom',
    nested={'field': 2.23},
    tuple=(1, 2, 3),
)

_CONFIG = config_flags.DEFINE_config_dataclass('my_config', config)


def main(_):
  print(_CONFIG.value)


if __name__ == '__main__':
  app.run(main)
ml_collections-0.1.1/ml_collections/config_flags/examples/config.py0000640000175000017500000000163014174507605025136 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python 3
"""Defines a method which returns an instance of ConfigDict."""

import ml_collections


def get_config():
  config = ml_collections.ConfigDict()
  config.field1 = 1
  config.field2 = 'tom'
  config.nested = ml_collections.ConfigDict()
  config.nested.field = 2.23
  config.tuple = (1, 2, 3)
  return config
ml_collections-0.1.1/ml_collections/config_flags/examples/define_config_dict_basic.py0000640000175000017500000000234314174507605030616 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
r"""Example of basic DEFINE_config_dict usage.

To run this example:
python define_config_dict_basic.py -- --my_config_dict.field1=8 \
  --my_config_dict.nested.field=2.1 --my_config_dict.tuple='(1, 2, (1, 2))'
"""

from absl import app

import ml_collections
from ml_collections.config_flags import config_flags

config = ml_collections.ConfigDict()
config.field1 = 1
config.field2 = 'tom'
config.nested = ml_collections.ConfigDict()
config.nested.field = 2.23
config.tuple = (1, 2, 3)

_CONFIG = config_flags.DEFINE_config_dict('my_config_dict', config)


def main(_):
  print(_CONFIG.value)


if __name__ == '__main__':
  app.run(main)
ml_collections-0.1.1/ml_collections/config_flags/examples/examples_test.py0000640000175000017500000000303214174507605026544 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Tests for config_flags examples.

Ensures that define_config_dict_basic and define_config_file_basic run
successfully.
"""

from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
from ml_collections.config_flags.examples import define_config_dict_basic
from ml_collections.config_flags.examples import define_config_file_basic

FLAGS = flags.FLAGS


class ConfigDictExamplesTest(absltest.TestCase):

  def test_define_config_dict_basic(self):
    define_config_dict_basic.main([])

  @flagsaver.flagsaver
  def test_define_config_file_basic(self):
    FLAGS.my_config = 'ml_collections/config_flags/examples/config.py'
    define_config_file_basic.main([])

  @flagsaver.flagsaver
  def test_define_config_file_parameterised(self):
    FLAGS.my_config = 'ml_collections/config_flags/examples/parameterised_config.py:linear'
    define_config_file_basic.main([])


if __name__ == '__main__':
  absltest.main()
ml_collections-0.1.1/ml_collections/config_flags/examples/parameterised_config.py0000640000175000017500000000250414174507605030044 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""Defines a parameterized method which returns a config depending on input."""

import ml_collections


def get_config(config_string):
  """Return an instance of ConfigDict depending on `config_string`."""
  possible_structures = {
      'linear': ml_collections.ConfigDict({
          'model_constructor': 'snt.Linear',
          'model_config': ml_collections.ConfigDict({
              'output_size': 42,
          })
      }),
      'lstm': ml_collections.ConfigDict({
          'model_constructor': 'snt.LSTM',
          'model_config': ml_collections.ConfigDict({
              'hidden_size': 108,
          })
      })
  }

  return possible_structures[config_string]
ml_collections-0.1.1/ml_collections/config_flags/examples/define_config_file_basic.py0000640000175000017500000000257314174507605030617 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
# pylint: disable=line-too-long
r"""Example of basic DEFINE_config_file usage.
To run this example with basic config file:
python define_config_file_basic.py -- \
  --my_config=ml_collections/config_flags/examples/config.py \
  --my_config.field1=8 --my_config.nested.field=2.1 \
  --my_config.tuple='(1, 2, (1, 2))'

To run this example with parameterised config file:
python define_config_file_basic.py -- \
  --my_config=ml_collections/config_flags/examples/parameterised_config.py:linear \
  --my_config.model_config.output_size=256
"""
# pylint: enable=line-too-long

from absl import app

from ml_collections.config_flags import config_flags

_CONFIG = config_flags.DEFINE_config_file('my_config')


def main(_):
  print(_CONFIG.value)


if __name__ == '__main__':
  app.run(main)
ml_collections-0.1.1/ml_collections/__init__.py0000640000175000017500000000162714174507605021177 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""ML Collections is a library of Python collections designed for ML use cases."""

from ml_collections.config_dict import ConfigDict
from ml_collections.config_dict import FieldReference
from ml_collections.config_dict import FrozenConfigDict

__all__ = ("ConfigDict", "FieldReference", "FrozenConfigDict")
ml_collections-0.1.1/docs/0000750000175000017500000000000014174510450014771 5ustar nileshnilesh
ml_collections-0.1.1/docs/conf.py0000640000175000017500000000546714174507605016315 0ustar nileshnilesh
# Copyright 2021 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
import os
import sys
sys.path.insert(0, os.path.abspath('..'))

# -- Project information -----------------------------------------------------

project = 'ml_collections'
copyright = '2020, The ML Collection Authors'
author = 'The ML Collection Authors'

# The full version, including alpha/beta/rc tags
release = '0.1.0'

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.intersphinx',
    'sphinx.ext.mathjax',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'nbsphinx',
    'recommonmark',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

autosummary_generate = True

master_doc = 'index'

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']