ml_collections-0.1.0/PKG-INFO
Metadata-Version: 2.1
Name: ml_collections
Version: 0.1.0
Summary: ML Collections is a library of Python collections designed for ML usecases.
Home-page: https://github.com/google/ml_collections
Author: ML Collections Authors
Author-email: ml-collections@google.com
License: Apache 2.0
Description:

# ML Collections

ML Collections is a library of Python Collections designed for ML use cases.

## ConfigDict

The two classes called `ConfigDict` and `FrozenConfigDict` are "dict-like" data structures with dot access to nested elements. Together, they are intended to be the main way of expressing configurations of experiments and models.

This document describes example usage of `ConfigDict`, `FrozenConfigDict`, and `FieldReference`.

### Features

* Dot-based access to fields.
* Locking mechanism to prevent spelling mistakes.
* Lazy computation.
* `FrozenConfigDict` class, which is immutable and hashable.
* Type safety.
* "Did you mean" functionality.
* Human readable printing (with valid references and cycles), using valid YAML format.
* Fields can be passed as keyword arguments using the `**` operator.
* There are two exceptions to the strong type-safety of the ConfigDict. `int` values can be passed in to fields of type `float`. In such a case, the value is type-converted to a `float` before being stored. Similarly, all string types (including Unicode strings) can be stored in fields of type `str` or `unicode`.
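The type rules in the last bullet, and the `**` bullet above it, can be illustrated with a small standalone sketch. This is plain Python and is not how ML Collections implements `ConfigDict`. The class name `TypedFields` and the function `train` are hypothetical, and only the `int`-to-`float` exception is modelled.

```python
class TypedFields:
  """Hypothetical stand-in that mimics the type rules listed above."""

  def __init__(self):
    # Bypass our own __setattr__ while creating the backing store.
    object.__setattr__(self, '_fields', {})

  def __setattr__(self, name, value):
    fields = self._fields
    if name in fields and fields[name] is not None:
      expected = type(fields[name])
      if expected is float and type(value) is int:
        value = float(value)  # The exception: int is converted to float.
      elif not isinstance(value, expected):
        raise TypeError(
            'Cannot assign {} to field {!r} of type {}'.format(
                type(value).__name__, name, expected.__name__))
    fields[name] = value

  def __getattr__(self, name):
    # Only called when normal lookup fails, i.e. for stored fields.
    try:
      return self._fields[name]
    except KeyError:
      raise AttributeError(name)

  # keys() and __getitem__ are what the ** operator needs.
  def keys(self):
    return self._fields.keys()

  def __getitem__(self, key):
    return self._fields[key]


def train(learning_rate, steps):
  return learning_rate * steps


cfg = TypedFields()
cfg.learning_rate = 0.1
cfg.steps = 100

cfg.learning_rate = 1  # Accepted and stored as a float.
try:
  cfg.steps = 'many'  # Rejected: the field already holds an int.
except TypeError as e:
  error_message = str(e)

result = train(**cfg)  # Fields are passed as keyword arguments.
```

The first assignment to a field fixes its type, which is why later assignments can be checked without any separate schema.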

### Basic Usage

```python
import ml_collections

cfg = ml_collections.ConfigDict()
cfg.float_field = 12.6
cfg.integer_field = 123
cfg.another_integer_field = 234
cfg.nested = ml_collections.ConfigDict()
cfg.nested.string_field = 'tom'

print(cfg.integer_field)  # Prints 123.
print(cfg['integer_field'])  # Prints 123 as well.

try:
  cfg.integer_field = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)

cfg.float_field = 12  # Works: `int` values can be assigned to `float` fields.
cfg.nested.string_field = u'bob'  # `str` fields can store Unicode strings.

print(cfg)
```

### FrozenConfigDict

A `FrozenConfigDict` is an immutable, hashable type of `ConfigDict`:

```python
import ml_collections

initial_dictionary = {
    'int': 1,
    'list': [1, 2],
    'tuple': (1, 2, 3),
    'set': {1, 2, 3, 4},
    'dict_tuple_list': {'tuple_list': ([1, 2], 3)}
}

cfg = ml_collections.ConfigDict(initial_dictionary)
frozen_dict = ml_collections.FrozenConfigDict(initial_dictionary)

print(frozen_dict.tuple)  # Prints tuple (1, 2, 3)
print(frozen_dict.list)  # Prints tuple (1, 2)
print(frozen_dict.set)  # Prints frozenset {1, 2, 3, 4}
print(frozen_dict.dict_tuple_list.tuple_list[0])  # Prints tuple (1, 2)

frozen_cfg = ml_collections.FrozenConfigDict(cfg)
print(frozen_cfg == frozen_dict)  # True
print(hash(frozen_cfg) == hash(frozen_dict))  # True

try:
  frozen_dict.int = 2  # Raises AttributeError as FrozenConfigDict is immutable.
except AttributeError as e:
  print(e)

# Converting between `FrozenConfigDict` and `ConfigDict`:
thawed_frozen_cfg = ml_collections.ConfigDict(frozen_dict)
print(thawed_frozen_cfg == cfg)  # True
frozen_cfg_to_cfg = frozen_dict.as_configdict()
print(frozen_cfg_to_cfg == cfg)  # True
```

### FieldReferences and placeholders

A `FieldReference` is useful for having multiple fields use the same value. It can also be used for [lazy computation](#lazy-computation).

You can use `placeholder()` as a shortcut to create a `FieldReference` (field) with a `None` default value.
This is useful if a program uses optional configuration fields.

```python
import ml_collections
from ml_collections.config_dict import config_dict

placeholder = ml_collections.FieldReference(0)
cfg = ml_collections.ConfigDict()
cfg.placeholder = placeholder
cfg.optional = config_dict.placeholder(int)
cfg.nested = ml_collections.ConfigDict()
cfg.nested.placeholder = placeholder

try:
  cfg.optional = 'tom'  # Raises TypeError as this field is an integer.
except TypeError as e:
  print(e)

cfg.optional = 1555  # Works fine.
cfg.placeholder = 1  # Changes the value of both placeholder and
                     # nested.placeholder fields.

print(cfg)
```

Note that the indirection provided by `FieldReference`s will be lost if accessed through a `ConfigDict`.

```python
import ml_collections

placeholder = ml_collections.FieldReference(0)
cfg = ml_collections.ConfigDict()
cfg.field1 = placeholder
cfg.field2 = placeholder  # This field will be tied to cfg.field1.
cfg.field3 = cfg.field1  # This will just be an int field initialized to 0.
```

### Lazy computation

Using a `FieldReference` in a standard operation (addition, subtraction, multiplication, etc.) will return another `FieldReference` that points to the original's value. You can use `FieldReference.get()` to execute the operations and get the reference's computed value, and `FieldReference.set()` to change the original reference's value.

```python
import ml_collections

ref = ml_collections.FieldReference(1)
print(ref.get())  # Prints 1

add_ten = ref.get() + 10  # ref.get() is an integer and so is add_ten
add_ten_lazy = ref + 10  # add_ten_lazy is a FieldReference - NOT an integer

print(add_ten)  # Prints 11
print(add_ten_lazy.get())  # Prints 11 because ref's value is 1

# Addition is lazily computed for FieldReferences so changing ref will change
# the value that add_ten_lazy evaluates to.
ref.set(5)
print(add_ten)  # Prints 11
print(add_ten_lazy.get())  # Prints 15 because ref's value is 5
```

If a `FieldReference` has `None` as its original value, or any operation has an argument of `None`, then the lazy computation will evaluate to `None`.

We can also use fields in a `ConfigDict` in lazy computation. In this case a field will only be lazily evaluated if `ConfigDict.get_ref()` is used to get it.

```python
import ml_collections

config = ml_collections.ConfigDict()
config.reference_field = ml_collections.FieldReference(1)
config.integer_field = 2
config.float_field = 2.5

# No lazy evaluation because we didn't use get_ref()
config.no_lazy = config.integer_field * config.float_field

# This will lazily evaluate ONLY config.integer_field
config.lazy_integer = config.get_ref('integer_field') * config.float_field

# This will lazily evaluate ONLY config.float_field
config.lazy_float = config.integer_field * config.get_ref('float_field')

# This will lazily evaluate BOTH config.integer_field and config.float_field
config.lazy_both = (config.get_ref('integer_field') *
                    config.get_ref('float_field'))

config.integer_field = 3
print(config.no_lazy)  # Prints 5.0 - It uses integer_field's original value

print(config.lazy_integer)  # Prints 7.5

config.float_field = 3.5
print(config.lazy_float)  # Prints 7.0
print(config.lazy_both)  # Prints 10.5
```

#### Changing lazily computed values

Lazily computed values in a ConfigDict can be overridden in the same way as regular values. The reference to the `FieldReference` used for the lazy computation will be lost and all computations downstream in the reference graph will use the new value.

```python
import ml_collections

config = ml_collections.ConfigDict()
config.reference = 1
config.reference_0 = config.get_ref('reference') + 10
config.reference_1 = config.get_ref('reference') + 20
config.reference_1_0 = config.get_ref('reference_1') + 100

print(config.reference)  # Prints 1.
print(config.reference_0)  # Prints 11.
print(config.reference_1)  # Prints 21.
print(config.reference_1_0)  # Prints 121.

config.reference_1 = 30

print(config.reference)  # Prints 1 (unchanged).
print(config.reference_0)  # Prints 11 (unchanged).
print(config.reference_1)  # Prints 30.
print(config.reference_1_0)  # Prints 130.
```

#### Cycles

You cannot create cycles using references. Fortunately [the only way](#changing-lazily-computed-values) to create a cycle is by assigning a computed field to one that *is not* the result of computation. This is forbidden:

```python
import ml_collections
from ml_collections.config_dict import config_dict

config = ml_collections.ConfigDict()
config.integer_field = 1
config.bigger_integer_field = config.get_ref('integer_field') + 10

try:
  # Raises a MutabilityError because setting config.integer_field would
  # cause a cycle.
  config.integer_field = config.get_ref('bigger_integer_field') + 2
except config_dict.MutabilityError as e:
  print(e)
```

### Advanced usage

Here are some more advanced examples showing lazy computation with different operators and data types.

```python
import ml_collections

config = ml_collections.ConfigDict()
config.float_field = 12.6
config.integer_field = 123
config.list_field = [0, 1, 2]

config.float_multiply_field = config.get_ref('float_field') * 3
print(config.float_multiply_field)  # Prints 37.8

config.float_field = 10.0
print(config.float_multiply_field)  # Prints 30.0

config.longer_list_field = config.get_ref('list_field') + [3, 4, 5]
print(config.longer_list_field)  # Prints [0, 1, 2, 3, 4, 5]

config.list_field = [-1]
print(config.longer_list_field)  # Prints [-1, 3, 4, 5]

# Both operands can be references
config.ref_subtraction = (
    config.get_ref('float_field') - config.get_ref('integer_field'))
print(config.ref_subtraction)  # Prints -113.0

config.integer_field = 10
print(config.ref_subtraction)  # Prints 0.0
```

### Equality checking

You can use `==` and `.eq_as_configdict()` to check equality among `ConfigDict` and `FrozenConfigDict` objects.

```python
import ml_collections

dict_1 = {'list': [1, 2]}
dict_2 = {'list': (1, 2)}
cfg_1 = ml_collections.ConfigDict(dict_1)
frozen_cfg_1 = ml_collections.FrozenConfigDict(dict_1)
frozen_cfg_2 = ml_collections.FrozenConfigDict(dict_2)

# True because FrozenConfigDict converts lists to tuples
print(frozen_cfg_1.items() == frozen_cfg_2.items())
# False because == distinguishes the underlying difference
print(frozen_cfg_1 == frozen_cfg_2)

# False because == distinguishes these types
print(frozen_cfg_1 == cfg_1)
# But eq_as_configdict() treats both as ConfigDict, so these are True:
print(frozen_cfg_1.eq_as_configdict(cfg_1))
print(cfg_1.eq_as_configdict(frozen_cfg_1))
```

### Equality checking with lazy computation

Equality checks compare the computed values. Two configs compare equal even if their computations differ, as long as those computations produce the same values.

```python
import ml_collections

cfg_1 = ml_collections.ConfigDict()
cfg_1.a = 1
cfg_1.b = cfg_1.get_ref('a') + 2

cfg_2 = ml_collections.ConfigDict()
cfg_2.a = 1
cfg_2.b = cfg_2.get_ref('a') * 3

# True because all computed values are the same
print(cfg_1 == cfg_2)
```

### Locking and copying

Here is an example with `lock()` and `deepcopy()`:

```python
import copy
import ml_collections

cfg = ml_collections.ConfigDict()
cfg.integer_field = 123

# Locking prohibits the addition and deletion of fields but allows
# modification of existing values.
cfg.lock()
try:
  # The field name is misspelled on purpose, so this would add a new field.
  cfg.integer_fiel = 124  # Raises AttributeError and suggests valid field.
except AttributeError as e:
  print(e)
with cfg.unlocked():
  cfg.integer_field = 1555  # Works fine too.

# Get a copy of the config dict.
new_cfg = copy.deepcopy(cfg)
new_cfg.integer_field = -123  # Works fine.

print(cfg)
```

### Dictionary attributes and initialization

```python
import ml_collections

referenced_dict = {'inner_float': 3.14}
d = {
    'referenced_dict_1': referenced_dict,
    'referenced_dict_2': referenced_dict,
    'list_containing_dict': [{'key': 'value'}],
}

# We can initialize on a dictionary
cfg = ml_collections.ConfigDict(d)

# Reference structure is preserved
print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2))  # True

# And the dict attributes have been converted to ConfigDict
print(type(cfg.referenced_dict_1))  # ConfigDict

# However, the initialization does not look inside of lists, so dicts inside
# lists are not converted to ConfigDict
print(type(cfg.list_containing_dict[0]))  # dict
```

### More Examples

TODO(mohitreddy): Add links for examples.

For more examples, take a look at `ml_collections/config_dict/examples/`.

For examples and gotchas specifically about initializing a ConfigDict, see `ml_collections/config_dict/examples/config_dict_initialization.py`.

## Config Flags

This library adds flag definitions to `absl.flags` to handle config files. It does not wrap `absl.flags`, so if using any standard flag definitions alongside config file flags, users must also import `absl.flags`.

Currently, this module adds two new flag types, namely `DEFINE_config_file`, which accepts a path to a Python file that generates a configuration, and `DEFINE_config_dict`, which accepts a configuration directly. Configurations are dict-like structures (see [ConfigDict](#configdict)) whose nested elements can be overridden using special command-line flags. See the examples below for more details.

### Usage

Use `ml_collections.config_flags` alongside `absl.flags`.
For example:

`script.py`:

```python
from absl import app
from absl import flags

from ml_collections.config_flags import config_flags

FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('my_config')

def main(_):
  print(FLAGS.my_config)

if __name__ == '__main__':
  app.run(main)
```

`config.py`:

```python
# Note that this is a valid Python script.
# get_config() can return an arbitrary dict-like object. However, it is advised
# to use ml_collections.ConfigDict.
# See ml_collections/config_dict/examples/config_dict_basic.py

import ml_collections

def get_config():
  config = ml_collections.ConfigDict()
  config.field1 = 1
  config.field2 = 'tom'
  config.nested = ml_collections.ConfigDict()
  config.nested.field = 2.23
  config.tuple = (1, 2, 3)
  return config
```

Now, after running:

```bash
python script.py -- --my_config=config.py \
                    --my_config.field1=8 \
                    --my_config.nested.field=2.1 \
                    --my_config.tuple='(1, 2, (1, 2))'
```

we get:

```
field1: 8
field2: tom
nested:
  field: 2.1
tuple: !!python/tuple
- 1
- 2
- !!python/tuple
  - 1
  - 2
```

Usage of `DEFINE_config_dict` is similar to `DEFINE_config_file`; the main difference is that the configuration is defined in `script.py` instead of in a separate file.

`script.py`:

```python
from absl import app
from absl import flags

import ml_collections
from ml_collections.config_flags import config_flags

config = ml_collections.ConfigDict()
config.field1 = 1
config.field2 = 'tom'
config.nested = ml_collections.ConfigDict()
config.nested.field = 2.23
config.tuple = (1, 2, 3)

FLAGS = flags.FLAGS
config_flags.DEFINE_config_dict('my_config', config)

def main(_):
  print(FLAGS.my_config)

if __name__ == '__main__':
  app.run(main)
```

`config_file` flags are compatible with the command-line flag syntax. All the following options are supported for non-boolean values in configurations:

* `-(-)config.field=value`
* `-(-)config.field value`

Options for boolean values are slightly different:

* `-(-)config.boolean_field`: set boolean value to True.
* `-(-)noconfig.boolean_field`: set boolean value to False.
* `-(-)config.boolean_field=value`: `value` is `true`, `false`, `True` or `False`.

Note that `-(-)config.boolean_field value` is not supported.

### Parameterising the get_config() function

It's sometimes useful to be able to pass parameters into `get_config`, and change what is returned based on this configuration. One example is if you are grid searching over parameters which have a different hierarchical structure - the flag needs to be present in the resulting ConfigDict. It would be possible to include the union of all possible leaf values in your ConfigDict, but this produces a confusing config result as you have to remember which parameters will actually have an effect and which won't.

A better system is to pass some configuration, indicating which structure of ConfigDict should be returned. An example is the following config file:

```python
import ml_collections

def get_config(config_string):
  possible_structures = {
      'linear': ml_collections.ConfigDict({
          'model_constructor': 'snt.Linear',
          'model_config': ml_collections.ConfigDict({
              'output_size': 42,
          }),
      }),
      'lstm': ml_collections.ConfigDict({
          'model_constructor': 'snt.LSTM',
          'model_config': ml_collections.ConfigDict({
              'hidden_size': 108,
          })
      })
  }

  return possible_structures[config_string]
```

The value of `config_string` will be anything that is to the right of the first colon in the config file path, if one exists. If no colon exists, no value is passed to `get_config` (producing a TypeError if `get_config` expects a value).

The above example can be run like:

```bash
python script.py -- --config=path_to_config.py:linear \
                    --config.model_config.output_size=256
```

or like:

```bash
python script.py -- --config=path_to_config.py:lstm \
                    --config.model_config.hidden_size=512
```

### Additional features

* Loads any valid Python script which defines a `get_config()` function returning any Python object.
* Automatic locking of the loaded object, if the loaded object defines a callable `.lock()` method. * Supports command-line overriding of arbitrarily nested values in dict-like objects (with key/attribute based getters/setters) of the following types: * `types.IntType` (integer) * `types.FloatType` (float) * `types.BooleanType` (bool) * `types.StringType` (string) * `types.TupleType` (tuple) * Overriding is type safe. * Overriding of `TupleType` can be done by passing in the `tuple` as a string (see the example in the [Usage](#usage) section). * The overriding `tuple` object can be of a different size and have different types than the original. Nested tuples are also supported. Platform: UNKNOWN Classifier: Development Status :: 4 - Beta Classifier: Intended Audience :: Developers Classifier: Intended Audience :: Science/Research Classifier: License :: OSI Approved :: Apache Software License Classifier: Programming Language :: Python Classifier: Topic :: Scientific/Engineering Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence Classifier: Topic :: Software Development :: Libraries Classifier: Topic :: Software Development :: Libraries :: Python Modules Requires-Python: >=2.6 Description-Content-Type: text/markdown ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/README.md0000644162426002575230000004224600000000000020763 0ustar00mohitreddyprimarygroup00000000000000# ML Collections ML Collections is a library of Python Collections designed for ML use cases. ## ConfigDict The two classes called `ConfigDict` and `FrozenConfigDict` are "dict-like" data structures with dot access to nested elements. Together, they are supposed to be used as a main way of expressing configurations of experiments and models. This document describes example usage of `ConfigDict`, `FrozenConfigDict`, `FieldReference`. ### Features * Dot-based access to fields. 
* Locking mechanism to prevent spelling mistakes. * Lazy computation. * FrozenConfigDict() class which is immutable and hashable. * Type safety. * "Did you mean" functionality. * Human readable printing (with valid references and cycles), using valid YAML format. * Fields can be passed as keyword arguments using the `**` operator. * There are two exceptions to the strong type-safety of the ConfigDict. `int` values can be passed in to fields of type `float`. In such a case, the value is type-converted to a `float` before being stored. Similarly, all string types (including Unicode strings) can be stored in fields of type `str` or `unicode`. ### Basic Usage ```python import ml_collections cfg = ml_collections.ConfigDict() cfg.float_field = 12.6 cfg.integer_field = 123 cfg.another_integer_field = 234 cfg.nested = ml_collections.ConfigDict() cfg.nested.string_field = 'tom' print(cfg.integer_field) # Prints 123. print(cfg['integer_field']) # Prints 123 as well. try: cfg.integer_field = 'tom' # Raises TypeError as this field is an integer. except TypeError as e: print(e) cfg.float_field = 12 # Works: `Int` types can be assigned to `Float`. cfg.nested.string_field = u'bob' # `String` fields can store Unicode strings. 
print(cfg) ``` ### FrozenConfigDict A `FrozenConfigDict`is an immutable, hashable type of `ConfigDict`: ```python import ml_collections initial_dictionary = { 'int': 1, 'list': [1, 2], 'tuple': (1, 2, 3), 'set': {1, 2, 3, 4}, 'dict_tuple_list': {'tuple_list': ([1, 2], 3)} } cfg = ml_collections.ConfigDict(initial_dictionary) frozen_dict = ml_collections.FrozenConfigDict(initial_dictionary) print(frozen_dict.tuple) # Prints tuple (1, 2, 3) print(frozen_dict.list) # Prints tuple (1, 2) print(frozen_dict.set) # Prints frozenset {1, 2, 3, 4} print(frozen_dict.dict_tuple_list.tuple_list[0]) # Prints tuple (1, 2) frozen_cfg = ml_collections.FrozenConfigDict(cfg) print(frozen_cfg == frozen_dict) # True print(hash(frozen_cfg) == hash(frozen_dict)) # True try: frozen_dict.int = 2 # Raises TypeError as FrozenConfigDict is immutable. except AttributeError as e: print(e) # Converting between `FrozenConfigDict` and `ConfigDict`: thawed_frozen_cfg = ml_collections.ConfigDict(frozen_dict) print(thawed_frozen_cfg == cfg) # True frozen_cfg_to_cfg = frozen_dict.as_configdict() print(frozen_cfg_to_cfg == cfg) # True ``` ### FieldReferences and placeholders A `FieldReference` is useful for having multiple fields use the same value. It can also be used for [lazy computation](#lazy-computation). You can use `placeholder()` as a shortcut to create a `FieldReference` (field) with a `None` default value. This is useful if a program uses optional configuration fields. ```python import ml_collections from ml_collections.config_dict import config_dict placeholder = ml_collections.FieldReference(0) cfg = ml_collections.ConfigDict() cfg.placeholder = placeholder cfg.optional = config_dict.placeholder(int) cfg.nested = ml_collections.ConfigDict() cfg.nested.placeholder = placeholder try: cfg.optional = 'tom' # Raises Type error as this field is an integer. except TypeError as e: print(e) cfg.optional = 1555 # Works fine. 
cfg.placeholder = 1 # Changes the value of both placeholder and # nested.placeholder fields. print(cfg) ``` Note that the indirection provided by `FieldReference`s will be lost if accessed through a `ConfigDict`. ```python import ml_collections placeholder = ml_collections.FieldReference(0) cfg.field1 = placeholder cfg.field2 = placeholder # This field will be tied to cfg.field1. cfg.field3 = cfg.field1 # This will just be an int field initialized to 0. ``` ### Lazy computation Using a `FieldReference` in a standard operation (addition, subtraction, multiplication, etc...) will return another `FieldReference` that points to the original's value. You can use `FieldReference.get()` to execute the operations and get the reference's computed value, and `FieldReference.set()` to change the original reference's value. ```python import ml_collections ref = ml_collections.FieldReference(1) print(ref.get()) # Prints 1 add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 11 because ref's value is 1 # Addition is lazily computed for FieldReferences so changing ref will change # the value that is used to compute add_ten. ref.set(5) print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 15 because ref's value is 5 ``` If a `FieldReference` has `None` as its original value, or any operation has an argument of `None`, then the lazy computation will evaluate to `None`. We can also use fields in a `ConfigDict` in lazy computation. In this case a field will only be lazily evaluated if `ConfigDict.get_ref()` is used to get it. 
```python import ml_collections config = ml_collections.ConfigDict() config.reference_field = ml_collections.FieldReference(1) config.integer_field = 2 config.float_field = 2.5 # No lazy evaluatuations because we didn't use get_ref() config.no_lazy = config.integer_field * config.float_field # This will lazily evaluate ONLY config.integer_field config.lazy_integer = config.get_ref('integer_field') * config.float_field # This will lazily evaluate ONLY config.float_field config.lazy_float = config.integer_field * config.get_ref('float_field') # This will lazily evaluate BOTH config.integer_field and config.float_Field config.lazy_both = (config.get_ref('integer_field') * config.get_ref('float_field')) config.integer_field = 3 print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value print(config.lazy_integer) # Prints 7.5 config.float_field = 3.5 print(config.lazy_float) # Prints 7.0 print(config.lazy_both) # Prints 10.5 ``` #### Changing lazily computed values Lazily computed values in a ConfigDict can be overridden in the same way as regular values. The reference to the `FieldReference` used for the lazy computation will be lost and all computations downstream in the reference graph will use the new value. ```python import ml_collections config = ml_collections.ConfigDict() config.reference = 1 config.reference_0 = config.get_ref('reference') + 10 config.reference_1 = config.get_ref('reference') + 20 config.reference_1_0 = config.get_ref('reference_1') + 100 print(config.reference) # Prints 1. print(config.reference_0) # Prints 11. print(config.reference_1) # Prints 21. print(config.reference_1_0) # Prints 121. config.reference_1 = 30 print(config.reference) # Prints 1 (unchanged). print(config.reference_0) # Prints 11 (unchanged). print(config.reference_1) # Prints 30. print(config.reference_1_0) # Prints 130. ``` #### Cycles You cannot create cycles using references. 
Fortunately [the only way](#changing-lazily-computed-values) to create a cycle is by assigning a computed field to one that *is not* the result of computation. This is forbidden: ```python import ml_collections from ml_collections.config_dict import config_dict config = ml_collections.ConfigDict() config.integer_field = 1 config.bigger_integer_field = config.get_ref('integer_field') + 10 try: # Raises a MutabilityError because setting config.integer_field would # cause a cycle. config.integer_field = config.get_ref('bigger_integer_field') + 2 except config_dict.MutabilityError as e: print(e) ``` ### Advanced usage Here are some more advanced examples showing lazy computation with different operators and data types. ```python import ml_collections config = ml_collections.ConfigDict() config.float_field = 12.6 config.integer_field = 123 config.list_field = [0, 1, 2] config.float_multiply_field = config.get_ref('float_field') * 3 print(config.float_multiply_field) # Prints 37.8 config.float_field = 10.0 print(config.float_multiply_field) # Prints 30.0 config.longer_list_field = config.get_ref('list_field') + [3, 4, 5] print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5] config.list_field = [-1] print(config.longer_list_field) # Prints [-1, 3, 4, 5] # Both operands can be references config.ref_subtraction = ( config.get_ref('float_field') - config.get_ref('integer_field')) print(config.ref_subtraction) # Prints -113.0 config.integer_field = 10 print(config.ref_subtraction) # Prints 0.0 ``` ### Equality checking You can use `==` and `.eq_as_configdict()` to check equality among `ConfigDict` and `FrozenConfigDict` objects. 
```python import ml_collections dict_1 = {'list': [1, 2]} dict_2 = {'list': (1, 2)} cfg_1 = ml_collections.ConfigDict(dict_1) frozen_cfg_1 = ml_collections.FrozenConfigDict(dict_1) frozen_cfg_2 = ml_collections.FrozenConfigDict(dict_2) # True because FrozenConfigDict converts lists to tuples print(frozen_cfg_1.items() == frozen_cfg_2.items()) # False because == distinguishes the underlying difference print(frozen_cfg_1 == frozen_cfg_2) # False because == distinguishes these types print(frozen_cfg_1 == cfg_1) # But eq_as_configdict() treats both as ConfigDict, so these are True: print(frozen_cfg_1.eq_as_configdict(cfg_1)) print(cfg_1.eq_as_configdict(frozen_cfg_1)) ``` ### Equality checking with lazy computation Equality checks see if the computed values are the same. Equality is satisfied if two sets of computations are different as long as they result in the same value. ```python import ml_collections cfg_1 = ml_collections.ConfigDict() cfg_1.a = 1 cfg_1.b = cfg_1.get_ref('a') + 2 cfg_2 = ml_collections.ConfigDict() cfg_2.a = 1 cfg_2.b = cfg_2.get_ref('a') * 3 # True because all computed values are the same print(cfg_1 == cfg_2) ``` ### Locking and copying Here is an example with `lock()` and `deepcopy()`: ```python import copy import ml_collections cfg = ml_collections.ConfigDict() cfg.integer_field = 123 # Locking prohibits the addition and deletion of new fields but allows # modification of existing values. cfg.lock() try: cfg.integer_field = 124 # Raises AttributeError and suggests valid field. except AttributeError as e: print(e) with cfg.unlocked(): cfg.integer_field = 1555 # Works fine too. # Get a copy of the config dict. new_cfg = copy.deepcopy(cfg) new_cfg.integer_field = -123 # Works fine. 
print(cfg) ``` ### Dictionary attributes and initialization ```python import ml_collections referenced_dict = {'inner_float': 3.14} d = { 'referenced_dict_1': referenced_dict, 'referenced_dict_2': referenced_dict, 'list_containing_dict': [{'key': 'value'}], } # We can initialize on a dictionary cfg = ml_collections.ConfigDict(d) # Reference structure is preserved print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2)) # True # And the dict attributes have been converted to ConfigDict print(type(cfg.referenced_dict_1)) # ConfigDict # However, the initialization does not look inside of lists, so dicts inside # lists are not converted to ConfigDict print(type(cfg.list_containing_dict[0])) # dict ``` ### More Examples TODO(mohitreddy): Add links for examples. For more examples, take a look at these `ml_collections/config_dict/examples/`. For examples and gotchas specifically about initializing a ConfigDict, see `ml_collections/config_dict/examples/config_dict_initialization.py`. ## Config Flags This library adds flag definitions to `absl.flags` to handle config files. It does not wrap `absl.flags` so if using any standard flag definitions alongside config file flags, users must also import `absl.flags`. Currently, this module adds two new flag types, namely `DEFINE_config_file` which accepts a path to a Python file that generates a configuration, and `DEFINE_config_dict` which accepts a configuration directly. Configurations are dict-like structures (see [ConfigDict](#configdict)) whose nested elements can be overridden using special command-line flags. See the examples below for more details. ### Usage Use `ml_collections.config_flags` alongside `absl.flags`. 
For example: `script.py`: ```python from absl import app from absl import flags from ml_collections.config_flags import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file('my_config') def main(_): print(FLAGS.my_config) if __name__ == '__main__': app.run() ``` `config.py`: ```python # Note that this is a valid Python script. # get_config() can return an arbitrary dict-like object. However, it is advised # to use ml_collections.ConfigDict. # See ml_collections/config_dict/examples/config_dict_basic.py import ml_collections def get_config(): config = ml_collections.ConfigDict() config.field1 = 1 config.field2 = 'tom' config.nested = ml_collections.ConfigDict() config.nested.field = 2.23 config.tuple = (1, 2, 3) return config ``` Now, after running: ```bash python script.py -- --my_config=config.py \ --my_config.field1=8 \ --my_config.nested.field=2.1 \ --my_config.tuple='(1, 2, (1, 2))' ``` we get: ``` field1: 8 field2: tom nested: field: 2.1 tuple: !!python/tuple - 1 - 2 - !!python/tuple - 1 - 2 ``` Usage of `DEFINE_config_dict` is similar to `DEFINE_config_file`, the main difference is the configuration is defined in `script.py` instead of in a separate file. `script.py`: ```python from absl import app from absl import flags import ml_collections from ml_collections.config_flags import config_flags config = ml_collections.ConfigDict() config.field1 = 1 config.field2 = 'tom' config.nested = ml_collections.ConfigDict() config.nested.field = 2.23 config.tuple = (1, 2, 3) FLAGS = flags.FLAGS config_flags.DEFINE_config_dict('my_config', config) def main(_): print(FLAGS.my_config) if __name__ == '__main__': app.run() ``` `config_file` flags are compatible with the command-line flag syntax. All the following options are supported for non-boolean values in configurations: * `-(-)config.field=value` * `-(-)config.field value` Options for boolean values are slightly different: * `-(-)config.boolean_field`: set boolean value to True. 
* `-(-)noconfig.boolean_field`: set boolean value to False.
* `-(-)config.boolean_field=value`: `value` is `true`, `false`, `True` or
  `False`.

Note that `-(-)config.boolean_field value` is not supported.

### Parameterising the get_config() function

It's sometimes useful to be able to pass parameters into `get_config`, and
change what is returned based on this configuration. One example is a grid
search over parameters that have a different hierarchical structure: each
overridden flag needs to be present in the resulting ConfigDict. It would be
possible to include the union of all possible leaf values in your ConfigDict,
but this produces a confusing config result, as you have to remember which
parameters will actually have an effect and which won't.

A better system is to pass some configuration, indicating which structure of
ConfigDict should be returned. An example is the following config file:

```python
import ml_collections

def get_config(config_string):
  possible_structures = {
      'linear': ml_collections.ConfigDict({
          'model_constructor': 'snt.Linear',
          'model_config': ml_collections.ConfigDict({
              'output_size': 42,
          }),
      }),
      'lstm': ml_collections.ConfigDict({
          'model_constructor': 'snt.LSTM',
          'model_config': ml_collections.ConfigDict({
              'hidden_size': 108,
          })
      })
  }

  return possible_structures[config_string]
```

The value of `config_string` will be anything that is to the right of the first
colon in the config file path, if one exists. If no colon exists, no value is
passed to `get_config` (producing a `TypeError` if `get_config` expects a
value).

The above example can be run like:

```bash
python script.py --config=path_to_config.py:linear \
                 --config.model_config.output_size=256
```

or like:

```bash
python script.py --config=path_to_config.py:lstm \
                 --config.model_config.hidden_size=512
```

### Additional features

* Loads any valid Python script which defines a `get_config()` function
  returning any Python object.
* Automatic locking of the loaded object, if the loaded object defines a callable `.lock()` method. * Supports command-line overriding of arbitrarily nested values in dict-like objects (with key/attribute based getters/setters) of the following types: * `types.IntType` (integer) * `types.FloatType` (float) * `types.BooleanType` (bool) * `types.StringType` (string) * `types.TupleType` (tuple) * Overriding is type safe. * Overriding of `TupleType` can be done by passing in the `tuple` as a string (see the example in the [Usage](#usage) section). * The overriding `tuple` object can be of a different size and have different types than the original. Nested tuples are also supported. ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1598243892.235948 ml_collections-0.1.0/ml_collections/0000755162426002575230000000000000000000000022502 5ustar00mohitreddyprimarygroup00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/__init__.py0000644162426002575230000000162700000000000024621 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python 3 """ML Collections is a library of Python collections designed for ML usecases.""" from ml_collections.config_dict import ConfigDict from ml_collections.config_dict import FieldReference from ml_collections.config_dict import FrozenConfigDict __all__ = ("ConfigDict", "FieldReference", "FrozenConfigDict") ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1598243892.235948 ml_collections-0.1.0/ml_collections/config_dict/0000755162426002575230000000000000000000000024752 5ustar00mohitreddyprimarygroup00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/__init__.py0000644162426002575230000000260100000000000027062 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python 3 """Classes for defining configurations of experiments and models.""" from .config_dict import _Op from .config_dict import ConfigDict from .config_dict import create from .config_dict import CustomJSONEncoder from .config_dict import FieldReference from .config_dict import FrozenConfigDict from .config_dict import JSONDecodeError from .config_dict import MutabilityError from .config_dict import placeholder from .config_dict import recursive_rename from .config_dict import required_placeholder from .config_dict import RequiredValueError __all__ = ("_Op", "ConfigDict", "create", "CustomJSONEncoder", "FieldReference", "FrozenConfigDict", "JSONDecodeError", "MutabilityError", "placeholder", "recursive_rename", "required_placeholder", "RequiredValueError") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/config_dict.py0000644162426002575230000020411700000000000027601 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Classes for defining configurations of experiments and models. This file defines the classes `ConfigDict` and `FrozenConfigDict`, which are "dict-like" data structures with Lua-like access and some other nice features. Together, they are supposed to be used as a main way of expressing configurations of experiments and models. 
""" import abc import collections import contextlib import difflib import functools import inspect import json import operator from absl import logging import contextlib2 import six import yaml from yaml import representer # Workaround for https://github.com/yaml/pyyaml/issues/36. Classes that have # `abc.ABCMeta` as a metaclass are incorrectly handled as objects. This results # in the unbound `__reduce_ex__` method being called with the protocol version # as its sole argument, resulting in a `TypeError`. A solution is to add a # custom representer that represents `abc.ABCMeta` by name. representer.Representer.add_representer( data_type=abc.ABCMeta, representer=representer.Representer.represent_name) class RequiredValueError(Exception): pass class MutabilityError(Exception): pass class JSONDecodeError(Exception): pass _NoneType = type(None) def _is_type_safety_violation(value, field_type): """Helper function for type safety exceptions. This function determines whether or not assigning a value to a field violates type safety. Args: value: The value to be assigned. field_type: Type of the field that we would like to assign value to. Returns: True if assigning value to field violates type safety, False otherwise. """ # Allow None to override and be overridden by any type. if value is None or field_type == _NoneType: return False else: return not isinstance(value, field_type) def _safe_cast(value, field_type, type_safe=False): """Helper function to handle the exceptional type conversions. This function implements the following exceptions for type-checking rules: * An `int` will be converted to a `float` if overriding a `float` field. * Any string value can override a `str` or `unicode` field. The value is converted to `field_type`. * A `tuple` will be converted to a `list` if overriding a `list` field. * A `list` will be converted to a `tuple` if overriding `tuple` field. * Short and long integers are indistinguishable. 
The final value will always be a `long` if both types are present. Args: value: The value to be assigned. field_type: The type for the field that we would like to assign value to. type_safe: If True, the method will throw an error if the `value` is not of type `field_type` after safe type conversions. Returns: The converted type-safe version of the value if it is one of the cases described. Otherwise, return the value without conversion. Raises: TypeError: if types don't match after safe type conversions. """ original_value_type = type(value) # The int->float exception. if isinstance(value, int) and field_type is float: return float(value) # The unicode/string to string exception. if isinstance(value, six.string_types) and field_type in (str, six.text_type): return field_type(value) # tuple<->list conversion. JSON serialization converts lists to tuples, so # we need this to avoid errors when overriding a list field with its # deserialized version. See b/34805906 for more details. if isinstance(value, tuple) and field_type is list: return list(value) if isinstance(value, list) and field_type is tuple: return tuple(value) if isinstance(value, six.integer_types) and field_type in six.integer_types: if six.PY2 and (isinstance(value, long) or field_type is long): # Overriding an int with a long and viceversa should be possible. # https://www.python.org/dev/peps/pep-0237/ return long(value) else: # In Python 3, there is only the `int` type. return value if type_safe and _is_type_safety_violation(value, field_type): raise TypeError('{} is of type {} but should be of type {}' .format(value, str(original_value_type), str(field_type))) return value def _get_computed_value(value_or_fieldreference): if isinstance(value_or_fieldreference, FieldReference): return value_or_fieldreference.get() return value_or_fieldreference class _Op(collections.namedtuple('_Op', ['fn', 'args'])): """A named tuple representing a lazily computed op. 
  The _Op named tuple has two fields:
    fn: The function to be applied.
    args: a tuple/list of arguments that are used with the op.
  """


@functools.total_ordering
class FieldReference(object):
  """Reference to a configuration element.

  Typed configuration element that can take a None default value. Example:

      from ml_collections import config_dict

      cfg_field = config_dict.FieldReference(0)
      cfg = config_dict.ConfigDict({
          'optional': config_dict.FieldReference(None, field_type=str),
          'field': cfg_field,
          'nested': {'field': cfg_field}
      })

      with self.assertRaises(TypeError):
        cfg.optional = 10  # Raises an error because it's defined as a
                           # str field.

      cfg.field = 1  # Changes the value of both cfg.field and cfg.nested.field.
      print(cfg)

  This class also supports lazy computation. Example:

  ```python
  ref = config_dict.FieldReference(0)

  # Using ref in a standard operation returns another FieldReference. The new
  # reference ref_plus_ten will evaluate ref's value only when we call
  # ref_plus_ten.get()
  ref_plus_ten = ref + 10

  ref.set(3)  # change ref's value
  print(ref_plus_ten.get())  # Prints 13 because ref's value is 3

  ref.set(-2)  # change ref's value again
  print(ref_plus_ten.get())  # Prints 8 because ref's value is -2
  ```
  """

  def __init__(self, default, field_type=None, op=None, required=False):
    """Creates an instance of FieldReference.

    Args:
      default: Default value.
      field_type: Type for the values contained by the configuration element. If
        None the type will be inferred from the default value. This value is
        used as the second argument in calls to isinstance, so it has to follow
        that function's requirements (class, type or a tuple containing classes,
        types or tuples).
      op: An optional operation that is applied to the underlying value when
        `get()` is called.
      required: If True, the `get()` method will raise an error if the reference
        does not contain a value. This argument has no effect when a default
        value is provided. Setting this to True will raise an error if `op` is
        not None.

    Raises:
      TypeError: If field_type is not None and is different from the type of the
        default value.
      ValueError: If both the default value and field_type are None.
    """
    if field_type is None:
      if default is None:
        raise ValueError('Default value cannot be None if field_type is None')
      elif isinstance(default, FieldReference):
        field_type = default.get_type()
      else:
        field_type = type(default)
    else:
      # Early error when field_type doesn't have a structure compatible with
      # isinstance (class, type or tuple containing classes, types or tuples).
      # The easiest way to check this is to call isinstance and catch TypeError
      # exceptions.
      try:
        isinstance(None, field_type)
      except TypeError:
        raise TypeError('field_type should be a type, not {}'
                        .format(type(field_type)))

    self._field_type = field_type
    self.set(default)
    if required and op is not None:
      raise ValueError('Cannot set required to True if op is not None')
    self._required = required
    self._ops = [] if op is None else [op]

  def has_cycle(self, visited=None):
    """Finds cycles in the reference graph.

    Args:
      visited: Set containing the ids of all visited nodes in the graph. The
        default value is the empty set.

    Returns:
      True if there is a cycle in the reference graph.
    """
    visited = visited or set()
    if id(self) in visited:
      return True
    visited.add(id(self))

    # Verify the reference to the parent FieldReference doesn't introduce a
    # cycle.
    value = self._value
    if isinstance(value, FieldReference) and value.has_cycle(visited.copy()):
      return True

    # Verify references in the operator arguments don't introduce cycles.
    for op in self._ops:
      for arg in op.args:
        if isinstance(arg, FieldReference) and arg.has_cycle(visited.copy()):
          return True

    return False

  def set(self, value, type_safe=True):
    """Overwrites the value pointed to by a FieldReference.

    Args:
      value: New value.
      type_safe: Check that old and new values are of the same type.

    Raises:
      TypeError: If type_safe is true and old and new values are not of the same
        type.
      MutabilityError: If a cycle is found in the reference graph.
    """
    # Disable ops.
    self._ops = []

    if value is None:
      self._value = None
    elif isinstance(value, FieldReference):
      if type_safe and value.get_type() is not self.get_type():
        raise TypeError('Reference is of type {} but should be of type {}'
                        .format(value.get_type(), self.get_type()))
      old_value = getattr(self, '_value', None)
      self._value = value
      if self.has_cycle():
        self._value = old_value
        raise MutabilityError('Found cycle in reference graph.')
    else:
      # TODO(sergomez): Update reference type.
      self._value = _safe_cast(value, self._field_type, type_safe)

  def empty(self):
    """Returns True if the reference points to a None value."""
    return self._value is None

  def get(self):
    """Gets the value of the `FieldReference` object.

    This will dereference `_value` and apply all ops to its value.

    Returns:
      The result of applying all ops to the dereferenced value.

    Raises:
      RequiredValueError: if `required` is True and the underlying value for the
        reference is None.
    """
    if self._required and self._value is None:
      raise RequiredValueError('None value found in required reference')

    value = _get_computed_value(self._value)
    for op in self._ops:
      # Dereference any FieldReference objects
      args = [_get_computed_value(arg) for arg in op.args]
      if value is None or None in args:
        value = None
        logging.debug('Cannot apply `%s` to `None`; skipping.', op)
      else:
        value = op.fn(value, *args)
        value = _get_computed_value(value)
    return value

  def get_type(self):
    return self._field_type

  def __eq__(self, other):
    if isinstance(other, FieldReference):
      return self.get() == other.get()
    else:
      return self.get() == other

  def __le__(self, other):
    if isinstance(other, FieldReference):
      return self.get() <= other.get()
    else:
      return self.get() <= other

  # Make FieldReference unhashable (as it's mutable).
  __hash__ = None

  def _apply_op(self, fn, *args):
    args = [_safe_cast(arg, self._field_type) for arg in args]
    return FieldReference(
        self,
        field_type=self._field_type,
        op=_Op(fn, args))

  def _apply_cast_op(self, field_type):
    """Apply a cast op that changes the field_type of this FieldReference.

    `_apply_op` assumes that the `field_type` does not change after the op is
    applied whereas `_apply_cast_op` generates a FieldReference with casted
    field_type.

    Since the op is called as `fn(value, *args)`, `fn` needs to ignore `value`,
    which now contains a dummy default value of field_type.

    Args:
      field_type: data type to cast to.

    Returns:
      A new FieldReference of `field_type`.
    """
    return FieldReference(
        field_type(),  # Creates dummy default value matching `field_type`.
        field_type=field_type,
        op=_Op(lambda _, val: field_type(val),  # `fn` ignores `field_type()`.
               [self]),
    )

  def identity(self):
    return self._apply_op(lambda x: x)

  def __add__(self, other):
    return self._apply_op(operator.add, other)

  def __radd__(self, other):
    radd = functools.partial(operator.add, other)
    return self._apply_op(radd)

  def __sub__(self, other):
    return self._apply_op(operator.sub, other)

  def __rsub__(self, other):
    rsub = functools.partial(operator.sub, other)
    return self._apply_op(rsub)

  def __mul__(self, other):
    return self._apply_op(operator.mul, other)

  def __rmul__(self, other):
    rmul = functools.partial(operator.mul, other)
    return self._apply_op(rmul)

  def __div__(self, other):
    return self._apply_op(operator.div, other)

  def __rdiv__(self, other):
    rdiv = functools.partial(operator.div, other)
    return self._apply_op(rdiv)

  def __truediv__(self, other):
    return self._apply_op(operator.truediv, other)

  def __rtruediv__(self, other):
    rtruediv = functools.partial(operator.truediv, other)
    return self._apply_op(rtruediv)

  def __floordiv__(self, other):
    return self._apply_op(operator.floordiv, other)

  def __rfloordiv__(self, other):
    rfloordiv = functools.partial(operator.floordiv, other)
    return self._apply_op(rfloordiv)

  def __pow__(self,
other): return self._apply_op(operator.pow, other) def __mod__(self, other): return self._apply_op(operator.mod, other) def __and__(self, other): return self._apply_op(operator.and_, other) def __or__(self, other): return self._apply_op(operator.or_, other) def __xor__(self, other): return self._apply_op(operator.xor, other) def __neg__(self): return self._apply_op(operator.neg) def __abs__(self): return self._apply_op(operator.abs) def to_int(self): return self._apply_cast_op(int) def to_float(self): return self._apply_cast_op(float) def to_str(self): return self._apply_cast_op(str) def __setstate__(self, state): self._value = state['_value'] self._field_type = state['_field_type'] self._ops = state['_ops'] # TODO(sergomez): Remove default for _required (and potentially the whole # __setstate__ method) after June 2019. self._required = state.get('_required', False) def __nonzero__(self): raise NotImplementedError( 'FieldReference cannot be used for control flow. For boolean ' 'operations use "&" (logical "and") or "|" (logical "or").') def __bool__(self): raise NotImplementedError( 'FieldReference cannot be used for control flow. For boolean ' 'operations use "&" (logical "and") or "|" (logical "or").') def _configdict_fill_seed(seed, initial_dictionary, visit_map=None): """Fills an empty ConfigDict without copying previously visited nodes. Turns seed (an empty ConfigDict) into a ConfigDict version of initial_dictionary. Avoids infinite looping on a self-referencing initial_dictionary because if a value of initial_dictionary has been previously visited, that value is not re-converted to a ConfigDict. If a FieldReference is encountered which contains a dict or FrozenConfigDict, its contents will be converted to ConfigDict. Note: As described in the __init__() documentation, this will not replicate the structure of initial_dictionary if it contains self-references within lists, tuples, or other types. There is no warning or error in this case. 
Args: seed: Empty ConfigDict, to be filled in. initial_dictionary: The template on which seed is built. May be of type dict, ConfigDict or FrozenConfigDict. visit_map: Dictionary from memory addresses to values, storing the ConfigDict versions of dictionary values. visit_map need not contain (id(initial_dictionary), seed) as a key/value pair. Raises: TypeError: If seed is not a ConfigDict. ValueError: If seed is not an empty ConfigDict. """ # These should be impossible to raise, since the public call-site in # __init__() pass in valid input, as does this method recursively. assert isinstance(seed, ConfigDict) assert not seed visit_map = visit_map or {} visit_map[id(initial_dictionary)] = seed if isinstance(initial_dictionary, ConfigDict): iteritems = initial_dictionary.iteritems(preserve_field_references=True) else: iteritems = six.iteritems(initial_dictionary) for key, value in iteritems: if id(value) in visit_map: value = visit_map[id(value)] elif (isinstance(value, FieldReference) and value.get_type() is dict and seed.convert_dict): # If the reference is empty, we don't have to do dict -> ConfigDict # conversion. # Calling get() on an empty required reference would raise an error so we # need a special case for this. if value.empty(): pass elif id(value.get()) in visit_map: value.set(visit_map[id(value.get())], False) else: value_cd = ConfigDict(type_safe=seed.is_type_safe) _configdict_fill_seed(value_cd, value.get(), visit_map) value.set(value_cd, False) elif isinstance(value, dict) and seed.convert_dict: value_cd = ConfigDict(type_safe=seed.is_type_safe) _configdict_fill_seed(value_cd, value, visit_map) value = value_cd elif isinstance(value, FrozenConfigDict): value = ConfigDict(value) seed.__setattr__(key, value) class ConfigDict(object): # pylint: disable=line-too-long """Base class for configuration objects used in DeepMind. This is a container for configurations. It behaves similarly to Lua tables. 
  Specifically:
  - it has dot-based access as well as dict-style key access,
  - it is type safe (once a value is set one cannot change its type).

  Typical usage example:

      from ml_collections import config_dict

      cfg = config_dict.ConfigDict()
      cfg.float_field = 12.6
      cfg.integer_field = 123
      cfg.another_integer_field = 234
      cfg.nested = config_dict.ConfigDict()
      cfg.nested.string_field = 'tom'

      print(cfg)

  Config dictionaries can also be used to pass named arguments to functions:

      from ml_collections import config_dict

      def print_point(x, y):
        print("({},{})".format(x, y))

      point = config_dict.ConfigDict()
      point.x = 1
      point.y = 2
      print_point(**point)

  Note that, depending on your use case, it may be easier to use the `create`
  function in this package to construct a `ConfigDict`:

      from ml_collections.config_dict import config_dict
      point = config_dict.create(x=1, y=2)

  Unlike standard `dicts`, `ConfigDicts` also have the nice property that
  iterating over them is deterministic, in a fashion similar to
  `collections.OrderedDicts`.
  """
  # pylint: enable=line-too-long

  def __init__(
      self,
      initial_dictionary=None,
      type_safe=True,
      convert_dict=True):
    """Creates an instance of ConfigDict.

    Warning: In most cases, this faithfully reproduces the reference structure
    of initial_dictionary, even if initial_dictionary is self-referencing.
    However, unexpected behavior occurs if self-references are contained within
    list, tuple, or custom types. For example:

        d = {}
        d['a'] = d
        d['b'] = [d]
        cd = ConfigDict(d)
        cd.a    # refers to cd, type ConfigDict. Expected behavior.
        cd.b    # refers to d, type dict. Unexpected behavior.

    Warning: FieldReference values may be changed. If initial_dictionary
    contains a FieldReference with a value of type dict or FrozenConfigDict,
    that value is converted to ConfigDict.

    Args:
      initial_dictionary: May be one of the following:

        1) dict. In this case, all values of initial_dictionary that are
        dictionaries are also converted to ConfigDict.
However, dictionaries within values of non-dict type are untouched. 2) ConfigDict. In this case, all attributes are uncopied, and only the top-level object (self) is re-addressed. This is the same behavior as Python dict, list, and tuple. 3) FrozenConfigDict. In this case, initial_dictionary is converted to a ConfigDict version of the initial dictionary for the FrozenConfigDict (reversing any mutability changes FrozenConfigDict made). type_safe: If set to True, once an attribute value is assigned, its type cannot be overridden without .ignore_type() context manager (default: True). convert_dict: If set to True, all dict used as value in the ConfigDict will automatically be converted to ConfigDict (default: True). """ if isinstance(initial_dictionary, FrozenConfigDict): initial_dictionary = initial_dictionary.as_configdict() super(ConfigDict, self).__setattr__('_fields', {}) super(ConfigDict, self).__setattr__('_locked', False) super(ConfigDict, self).__setattr__('_type_safe', type_safe) super(ConfigDict, self).__setattr__('_convert_dict', convert_dict) if initial_dictionary is not None: _configdict_fill_seed(self, initial_dictionary) if isinstance(initial_dictionary, ConfigDict): super(ConfigDict, self).__setattr__('_locked', initial_dictionary.is_locked) super(ConfigDict, self).__setattr__('_type_safe', initial_dictionary.is_type_safe) @property def is_type_safe(self): """Returns True if config dict is type safe.""" return self._type_safe @property def convert_dict(self): """Returns True if it is converting dicts to ConfigDict automatically.""" return self._convert_dict def lock(self): """Locks object, preventing user from adding new fields. 

    Returns:
      self
    """
    if self.is_locked:
      return self

    super(ConfigDict, self).__setattr__('_locked', True)
    for field in self._fields:
      element = self._fields[field]
      element = _get_computed_value(element)
      if isinstance(element, ConfigDict):
        element.lock()
    return self

  @property
  def is_locked(self):
    """Returns True if object is locked."""
    return self._locked

  def unlock(self):
    """Grants user the ability to add new fields to ConfigDict.

    In most cases, the unlocked() context manager should be preferred to the
    direct use of the unlock method.

    Returns:
      self
    """
    super(ConfigDict, self).__setattr__('_locked', False)
    for element in six.itervalues(self._fields):
      element = _get_computed_value(element)
      if isinstance(element, ConfigDict) and element.is_locked:
        element.unlock()
    return self

  def get(self, key, default=None):
    """Returns value if key is present, or a user defined value otherwise."""
    try:
      return self[key]
    except KeyError:
      return default

  # TODO(sergomez): replace this with get_oneway_ref. The first step is to log
  # usage patterns of this. How many users are overriding the value of the
  # reference returned by this and expect the referenced field to change too?
  def get_ref(self, key):
    """Returns a FieldReference initialized on key's value."""
    field = self._fields[key]
    if field is None:
      raise ValueError('Cannot create reference to a field whose value is None')
    if not isinstance(field, FieldReference):
      field = FieldReference(field)
      with self.ignore_type():
        self[key] = field
    return field

  def get_oneway_ref(self, key):
    """Returns a one-way FieldReference.

    Example:
    ```
    cfg = ml_collections.ConfigDict(dict(a=1))
    cfg.b = cfg.get_oneway_ref('a')

    cfg.a = 2
    print(cfg.b)  # 2

    cfg.b = 3
    print(cfg.a)  # 2 (would have been 3 if using get_ref())
    print(cfg.b)  # 3
    ```

    Args:
      key: Key for field we want to reference.
    """
    # Using the result of applying an operation on the reference means that
    # calling set() on this object won't propagate the new value up the
    # reference chain.
return self.get_ref(key).identity() def items(self, preserve_field_references=False): """Returns list of dictionary key, value pairs, sorted by key. Args: preserve_field_references: (bool) Whether to preserve FieldReferences if the ConfigDict has them. By default, False: any FieldReferences will be resolved in the result. Returns: The key, value pairs in the config, sorted by key. """ if preserve_field_references: return six.iteritems(self._ordered_fields) else: return [(k, self[k]) for k in self._ordered_fields] @property def _ordered_fields(self): """Returns ordered dict shallow cast of _fields member.""" return collections.OrderedDict(sorted(self._fields.items())) def iteritems(self, preserve_field_references=False): """Deterministically iterates over dictionary key, value pairs. Args: preserve_field_references: (bool) Whether to preserve FieldReferences if the ConfigDict has them. By default, False: any FieldReferences will be resolved in the result. Yields: The key, value pairs in the config, sorted by key. """ for k in self._ordered_fields: if preserve_field_references: yield k, self._fields[k] else: yield k, self[k] def _ensure_mutability(self, attribute): if attribute in dir(super(ConfigDict, self)): raise KeyError('{} cannot be overridden.'.format(attribute)) def __setattr__(self, attribute, value): try: self._ensure_mutability(attribute) self[attribute] = value except KeyError as e: raise AttributeError(e) def __delattr__(self, attribute): try: self._ensure_mutability(attribute) del self[attribute] except KeyError as e: raise AttributeError(e) def __getattr__(self, attribute): try: return self[attribute] except KeyError as e: raise AttributeError(e) def __setitem__(self, key, value): if '.' 
in key: raise ValueError('ConfigDict does not accept dots in field names, but ' 'the key {} contains one.'.format(key)) if self.is_locked and key not in self._fields: error_msg = ('Key "{}" does not exist and cannot be added since the ' 'config is locked') raise KeyError( self._generate_did_you_mean_message(key, error_msg.format(key))) if key in self._fields: field = self._fields[key] try: if isinstance(field, FieldReference): field.set(value, self._type_safe) return # Skip type checking if the value is a FieldReference of the same type. if (not isinstance(value, FieldReference) or value.get_type() is not type(field)): if isinstance(value, dict) and self._convert_dict: value = ConfigDict(value, self._type_safe) value = _safe_cast(value, type(field), self._type_safe) except TypeError as e: raise TypeError('Could not override field \'{}\' (reference). {}' .format(key, str(e))) if isinstance(value, dict) and self._convert_dict: value = ConfigDict(value, self._type_safe) elif isinstance(value, FieldReference): # TODO(sergomez): We should consider using value.get_type(). ref_type = _NoneType if value.empty() else type(value.get()) if ref_type is dict or ref_type is FrozenConfigDict: value_cd = ConfigDict(value.get(), self._type_safe) value.set(value_cd, False) self._fields[key] = value def _generate_did_you_mean_message(self, request, message=''): matches = difflib.get_close_matches(request, self.keys(), 1, 0.75) if matches: if message: message += '\n' message += 'Did you mean "{}" instead of "{}"?'.format(matches[0], request) return message def __delitem__(self, key): if self.is_locked: raise KeyError('This ConfigDict is locked, you have to unlock it before ' 'trying to delete a field.') if '.' in key: # As per the check in __setitem__ above, keys cannot contain dots. # Hence, we can use dots to do recursive calls. 
key, rest = key.split('.', 1) del self[key][rest] return try: del self._fields[key] except KeyError as e: raise KeyError(self._generate_did_you_mean_message(key, str(e))) def __getitem__(self, key): if '.' in key: # As per the check in __setitem__ above, keys cannot contain dots. # Hence, we can use dots to do recursive calls. key, rest = key.split('.', 1) return self[key][rest] try: field = self._fields[key] if isinstance(field, FieldReference): return field.get() else: return field except KeyError as e: raise KeyError(self._generate_did_you_mean_message(key, str(e))) def __contains__(self, key): return key in self._fields def __repr__(self): return yaml.dump(self.to_dict(preserve_field_references=True), default_flow_style=False) def __str__(self): return yaml.dump(self.to_dict()) def keys(self): """Returns the sorted list of all the keys defined in a config.""" return list(self._ordered_fields.keys()) def iterkeys(self): """Deterministically iterates over dictionary keys, in sorted order.""" return six.iterkeys(self._ordered_fields) def values(self, preserve_field_references=False): """Returns the list of all values in a config, sorted by their keys. Args: preserve_field_references: (bool) Whether to preserve FieldReferences if the ConfigDict has them. By default, False: any FieldReferences will be resolved in the result. Returns: The values in the config, sorted by their corresponding keys. """ if preserve_field_references: return list(self._ordered_fields.values()) else: return [self[k] for k in self._ordered_fields] def itervalues(self, preserve_field_references=False): """Deterministically iterates over values in a config, sorted by their keys. Args: preserve_field_references: (bool) Whether to preserve FieldReferences if the ConfigDict has them. By default, False: any FieldReferences will be resolved in the result. Yields: The values in the config, sorted by their corresponding keys. 
""" for k in self._ordered_fields: if preserve_field_references: yield self._fields[k] else: yield self[k] def __dir__(self): return self.keys() + dir(ConfigDict) def __len__(self): return self._ordered_fields.__len__() def __iter__(self): return self._ordered_fields.__iter__() def __eq__(self, other): """Override the default Equals behavior.""" if isinstance(other, self.__class__): same = self._fields == other._fields same &= self._locked == other.is_locked same &= self._type_safe == other.is_type_safe return same return False def __ne__(self, other): """Define a non-equality test.""" return not self.__eq__(other) def eq_as_configdict(self, other): """Type-invariant equals. This is like `__eq__`, except it does not distinguish `FrozenConfigDict` from `ConfigDict`. For example: cd = ConfigDict() fcd = FrozenConfigDict() fcd.eq_as_configdict(cd) # Returns True Args: other: Object to compare self to. Returns: same: `True` if `self == other` after conversion to `ConfigDict`. """ if isinstance(other, ConfigDict): return ConfigDict(self) == ConfigDict(other) else: return False # Make ConfigDict unhashable __hash__ = None def to_yaml(self, **kwargs): """Returns a YAML representation of the object. Args: **kwargs: Keyword arguments for yaml.dump. Returns: YAML representation of the object. """ return yaml.dump(self, **kwargs) def _json_dumps_wrapper(self, **kwargs): """Wrapper for json.dumps() method. Produces a more informative error message when there is a problem with string encodings in the ConfigDict. Args: **kwargs: Keyword arguments for json.dumps. Returns: JSON representation of the object. Raises: JSONDecodeError: If there is a problem with string encodings. """ try: return json.dumps(self._fields, **kwargs) except UnicodeDecodeError as error: # Re-raise exception with more informative error message. new_message = ( 'Decoding error. Make sure all strings in your ConfigDict use ASCII-' 'compatible encodings. 
See ' 'https://docs.python.org/2.7/howto/unicode.html#the-unicode-type ' 'for details. Original error message: {}'.format(error)) raise JSONDecodeError(new_message) def to_json(self, json_encoder_cls=None, **kwargs): """Returns a JSON representation of the object, fails if there is a cycle. Args: json_encoder_cls: An optional JSON encoder class to customize JSON serialization. **kwargs: Keyword arguments for json.dumps. They cannot contain "cls" as this method specifies it on its own. Returns: JSON representation of the object. Raises: TypeError: If self contains set, frozenset, custom type fields or any other objects that are not JSON serializable. """ json_encoder_cls = json_encoder_cls or CustomJSONEncoder return self._json_dumps_wrapper(cls=json_encoder_cls, **kwargs) def to_json_best_effort(self, **kwargs): """Returns a best effort JSON representation of the object. Tries to serialize objects not inherently supported by JSON encoder. This may result in the configdict being partially serialized, skipping the unserializable bits. Ensures that no errors are thrown. Fails if there is a cycle. Args: **kwargs: Keyword arguments for json.dumps. They cannot contain "cls" as this method specifies it on its own. Returns: JSON representation of the object. """ return self._json_dumps_wrapper(cls=_BestEffortCustomJSONEncoder, **kwargs) def to_dict(self, visit_map=None, preserve_field_references=False): """Converts ConfigDict to regular dict recursively with valid references. By default, the output dict will not contain FieldReferences, any present in the ConfigDict will be resolved. However, if `preserve_field_references` is True, the output dict will contain FieldReferences where the original ConfigDict has them. They will not be the same as the ConfigDict's, and their ops will be applied and dropped. Note: As with __eq__() and __init__(), this may not behave as expected on a ConfigDict with self-references contained in lists, tuples, or custom types. 
Args: visit_map: A mapping from object ids to their dict representation. Method is recursive in nature, and it will call ".to_dict(visit_map)" on each encountered object, unless it is already in visit_map. preserve_field_references: (bool) Whether the output dict should have FieldReferences if the ConfigDict has them. By default, False: any FieldReferences will be resolved and the result will go to the dict. Returns: Dictionary with the same values and references structure as a calling ConfigDict. """ visit_map = visit_map or {} dict_self = {} visit_map[id(self)] = dict_self for key in self: if (isinstance(self._fields[key], FieldReference) and preserve_field_references): reference = self._fields[key] value = reference.get() else: value = self[key] reference = value if id(reference) in visit_map: dict_self[key] = visit_map[id(reference)] elif isinstance(value, ConfigDict): if isinstance(reference, FieldReference): # Create a new reference of type dict instead of ConfigDict. old_reference = reference reference = FieldReference({}, dict) visit_map[id(old_reference)] = reference reference.set(value.to_dict(visit_map, preserve_field_references)) else: reference = value.to_dict(visit_map, preserve_field_references) dict_self[key] = reference else: if isinstance(reference, FieldReference): # Create a new reference to put in the new dict, which will be # reused whenever we find the same original reference. # Notice that ops are lost in the copy, but they are applied when # we do old_reference.get(). old_reference = reference # Disable type safety since value in the field reference might have # been previously set with type safety disabled (e.g. ignore_type # context, as in b/119393923). 
reference = FieldReference(None, old_reference.get_type()) reference.set(old_reference.get(), type_safe=False) visit_map[id(old_reference)] = reference dict_self[key] = reference return dict_self def copy_and_resolve_references(self, visit_map=None): """Returns a ConfigDict copy with FieldReferences replaced by values. If the object is a FrozenConfigDict, the copy returned is also a FrozenConfigDict. However, note that FrozenConfigDict should already have FieldReferences resolved to values, so this method effectively produces a deep copy. Note: As with __eq__() and __init__(), this may not behave as expected on a ConfigDict with self-references contained in lists, tuples, or custom types. Args: visit_map: A mapping from ConfigDict object ids to their copy. Method is recursive in nature, and it will call ".copy_and_resolve_references(visit_map)" on each encountered object, unless it is already in visit_map. Returns: ConfigDict copy with previous FieldReferences replaced by values. """ visit_map = visit_map or {} config_dict_copy = self.__class__() visit_map[id(self)] = config_dict_copy for key, value in six.iteritems(self._fields): if isinstance(value, FieldReference): value = value.get() if id(value) in visit_map: value = visit_map[id(value)] elif isinstance(value, ConfigDict): value = value.copy_and_resolve_references(visit_map) if isinstance(self, FrozenConfigDict): config_dict_copy._frozen_setattr( # pylint:disable=protected-access key, value) else: config_dict_copy[key] = value super(ConfigDict, config_dict_copy).__setattr__( '_locked', self.is_locked) super(ConfigDict, config_dict_copy).__setattr__( '_type_safe', self.is_type_safe) return config_dict_copy def __setstate__(self, state): """Recreates ConfigDict from its dict representation.""" self.__init__() super(ConfigDict, self).__setattr__('_type_safe', state['_type_safe']) super(ConfigDict, self).__setattr__('_convert_dict', state.get('_convert_dict', True)) for field in state['_fields']: self[field] = 
state['_fields'][field] if state['_locked']: self.lock() @contextlib.contextmanager def unlocked(self): """Context manager which temporarily unlocks a ConfigDict.""" was_locked = self._locked if was_locked: self.unlock() try: yield self finally: if was_locked: self.lock() @contextlib.contextmanager def ignore_type(self): """Context manager which temporarily turns off type safety recursively.""" original_type_safety = self._type_safe managers = [] visited = set() fields = list(six.iteritems(self._fields)) while fields: field = fields.pop() if id(field) in visited: continue visited.add(id(field)) if isinstance(field, ConfigDict): managers.append(field.ignore_type) # Recursively add elements in nested collections. elif isinstance(field, collections.Mapping): fields.extend(six.iteritems(field)) elif isinstance(field, (collections.Sequence, collections.Set)): fields.extend(field) super(ConfigDict, self).__setattr__('_type_safe', False) try: with contextlib2.ExitStack() as stack: for manager in managers: stack.enter_context(manager()) yield self finally: super(ConfigDict, self).__setattr__('_type_safe', original_type_safety) def get_type(self, key): """Returns type of the field associated with a key.""" # We access the field outside of the if/else statement to raise in all cases # AttributeErrors (potentially including "did you mean" messages) for # non-existent keys. field = self.__getattr__(key) if isinstance(self._fields[key], FieldReference): return self._fields[key].get_type() else: return type(field) def update(self, *other, **kwargs): """Update values based on matching keys in another dict-like object. Mimics the built-in dict's update method: iterates over a given mapping object and adds/overwrites keys with the given mapping's values for those keys. Differs from dict.update in that it operates recursively on existing keys that are already a ConfigDict (i.e. 
calls their update() on the corresponding value from other), and respects the ConfigDict's type safety status. If keyword arguments are specified, the ConfigDict is updated with those key/value pairs. Args: *other: A (single) dict-like container, e.g. a dict or ConfigDict. **kwargs: Additional keyword arguments to update the ConfigDict. Raises: TypeError: if more than one value for `other` is specified. """ if len(other) > 1: raise TypeError( 'update expected at most 1 arguments, got {}'.format(len(other))) for other in other + (kwargs,): iteritems_kwargs = {} if isinstance(other, ConfigDict): iteritems_kwargs['preserve_field_references'] = True for key, value in six.iteritems(other, **iteritems_kwargs): if key not in self: self[key] = value elif isinstance(self._fields[key], ConfigDict): self[key].update(other[key]) elif (isinstance(self._fields[key], FieldReference) and isinstance(value, FieldReference)): # Updating FieldReferences from FieldReferences is not allowed. # One option could be to just grab the value from `other` and try to # update the reference in `self` using that. But that could result in # losing links between fields in `other`. # Example: # other = ConfigDict(dict(a=1)) # other.b = other.get_ref('a') # this = ConfigDict(dict(a=2)) # this.c = this.get_ref('a') # # # Say we update `this` with `other`. The links between fields # # in `other` could be lost in `this`. # this.update(other) # # # It is unclear what `this.b` should be when `this.a` is updated. # this.a = 10 # # this.b? raise TypeError('Cannot update a FieldReference from another ' 'FieldReference') else: self[key] = value def update_from_flattened_dict(self, flattened_dict, strip_prefix=''): """In-place updates values taken from a flattened dict. 
    This allows a potentially nested source `ConfigDict` of the following form:

        cfg = ConfigDict({
            'a': 1,
            'b': {
                'c': {
                    'd': 2
                }
            }
        })

    to be updated from a dict containing paths navigating to child items, of the
    following form:

        updates = {
            'a': 2,
            'b.c.d': 3
        }

    This filters `flattened_dict` to only contain paths starting with
    `strip_prefix` and strips the prefix when applying the update.

    For example, suppose the following values are returned as flags:

        flags = {
            'flag1': x,
            'flag2': y,
            'config': 'some_file.py',
            'config.a.b': 1,
            'config.a.c': 2
        }

        config = ConfigDict({
            'a': {
                'b': 0,
                'c': 0
            }
        })

        config.update_from_flattened_dict(flags, 'config.')

    Then we will now have:

        config = ConfigDict({
            'a': {
                'b': 1,
                'c': 2
            }
        })

    Args:
      flattened_dict: A mapping (key path) -> value.
      strip_prefix: A prefix to be stripped from each key path. If specified,
        only paths starting with `strip_prefix` will be processed.

    Raises:
      KeyError: if any of the key paths can't be found.
    """
    if strip_prefix:
      interesting_items = {
          key: value
          for key, value in six.iteritems(flattened_dict)
          if key.startswith(strip_prefix)
      }
    else:
      interesting_items = flattened_dict

    # Keep track of any children that we want to update. Make sure that we
    # recurse into each one only once.
    children_to_update = set()

    for full_key, value in six.iteritems(interesting_items):
      key = full_key[len(strip_prefix):] if strip_prefix else full_key

      if '.' in key:
        # If the path is hierarchical, we'll need to tell the first component
        # to update itself.
        child = key.split('.')[0]
        if child in self:
          if isinstance(self[child], ConfigDict):
            children_to_update.add(child)
          else:
            raise KeyError('Key "{}" cannot be updated as "{}" is not a '
                           'ConfigDict.'.format(full_key, strip_prefix + child))
        else:
          raise KeyError('Key "{}" cannot be set as "{}" was not found.'
                         .format(full_key, strip_prefix + child))
      else:
        self[key] = value

    for child in children_to_update:
      child_strip_prefix = strip_prefix + child + '.'
self[child].update_from_flattened_dict(interesting_items, child_strip_prefix) def _frozenconfigdict_valid_input(obj, ancestor_list=None): """Raises error if obj is NOT a valid input for FrozenConfigDict. Args: obj: Object to check. In the first call (with ancestor_list = None), obj should be of type ConfigDict. During recursion, it may be any type except dict. ancestor_list: List of ancestors of obj in the attribute/element structure, used to detect reference cycles in obj. Raises: ValueError: If obj is an invalid input for FrozenConfigDict, i.e. if it contains a dict within a list/tuple or contains a reference cycle. Also if obj is a dict, which means it wasn't already converted to ConfigDict. """ ancestor_list = ancestor_list or [] # Dicts must be converted to ConfigDict before _frozenconfigdict_valid_input() assert not isinstance(obj, dict) if id(obj) in ancestor_list: raise ValueError('Bad FrozenConfigDict initialization: Cannot contain a ' 'cycle in its reference structure.') ancestor_list.append(id(obj)) if isinstance(obj, ConfigDict): for value in obj.values(): _frozenconfigdict_valid_input(value, ancestor_list) elif isinstance(obj, FieldReference): _frozenconfigdict_valid_input(obj.get(), ancestor_list) elif isinstance(obj, (list, tuple)): for element in obj: if isinstance(element, (dict, ConfigDict, FieldReference)): raise ValueError('Bad FrozenConfigDict initialization: Cannot ' 'contain a dict, ConfigDict, or FieldReference ' 'within a list or tuple.') _frozenconfigdict_valid_input(element, ancestor_list) ancestor_list.pop() def _tuple_to_immutable(value, visit_map): """Convert tuple to fully immutable tuple. Args: value: Tuple to be made fully immutable (including its elements). visit_map: As used elsewhere. See _frozenconfigdict_fill_seed() documentation. Must not contain id(value) as a key (if it does, an immutable version of value already exists). 
Returns: immutable_value: Immutable version of value, created with minimal copying (for example, if a value contains no mutable elements, it is returned untouched). same_value: Whether the same value was returned untouched, i.e. with the same memory address. Boolean. visit_map: Updated visit_map Raises: TypeError: If one of the following: 1) value is not a tuple. 2) value contains a dict, ConfigDict, or FieldReference. If it does, value is an invalid attribute of FrozenConfigDict, and this should have been caught in valid_input at initialization. ValueError: id(value) is in visit_map. """ # Ensure there are no cycles assert id(value) not in visit_map value_copy = [] same_value = True for element in value: # Sanity check: element cannot be dict, ConfigDict, or FieldReference assert not isinstance(element, (dict, ConfigDict, FieldReference)) if isinstance(element, (list, tuple, set)): new_element, uncopied_element, visit_map = _convert_to_immutable( element, visit_map) same_value &= uncopied_element value_copy.append(new_element) else: value_copy.append(element) if same_value: return value, True, visit_map else: return tuple(value_copy), False, visit_map def _convert_to_immutable(value, visit_map): """Convert Python built-in type to immutable, copying if necessary. Args: value: To be made immutable type (including its elements). Must have type list, tuple, or set. visit_map: As used elsewhere. See _frozenconfigdict_fill_seed() documentation. Returns: immutable_value: Immutable version of value, created with minimal copying. same_value: Whether the same value was returned untouched, i.e. with the same memory address. Boolean. visit_map: Updated visit_map. Raises: TypeError: If value is an invalid type (not a list, tuple, or set). 
""" value_id = id(value) if value_id in visit_map: return visit_map[value_id], True, visit_map same_value = False if isinstance(value, set): immutable_value = frozenset(value) elif isinstance(value, tuple): immutable_value, same_value, visit_map = _tuple_to_immutable( value, visit_map) elif isinstance(value, list): immutable_value, _, visit_map = _tuple_to_immutable(tuple(value), visit_map) else: # Type-check the input assert False visit_map[value_id] = immutable_value return immutable_value, same_value, visit_map def _frozenconfigdict_fill_seed(seed, initial_configdict, visit_map=None): """Fills an empty FrozenConfigDict without copying previously visited nodes. Turns seed (an empty FrozenConfigDict) into a FrozenConfigDict version of initial_configdict. Avoids duplicating nodes of initial_configdict because if a value of initial_configdict has been previously visited, that value is not re-converted to a FrozenConfigDict. If a FieldReference is encountered which contains a dict, its contents will be converted to FrozenConfigDict. Note: As described in the __init__() documentation, this will not replicate the structure of initial_configdict if it contains self-references within lists, tuples, or other types. There is no warning or error in this case. Args: seed: Empty FrozenConfigDict, to be filled in. initial_configdict: The template on which seed is built. Must be of type ConfigDict. visit_map: Dictionary from memory addresses to values, storing the FrozenConfigDict versions of dictionary values. Lists which have been converted to tuples and sets to frozensets are also stored in visit_map to preserve the reference structure of initial_configdict. visit_map need not contain (id(initial_configdict), seed) as a key/value pair. Raises: ValueError: If one of the following, both of which can never happen in practice: 1) seed is not an empty FrozenConfigDict. 2) initial_configdict contains a FieldReference. 
""" # These should be impossible to raise, since the public call-site in # __init__() pass in valid input, as does this method recursively. assert isinstance(seed, FrozenConfigDict) assert not seed # This is where we define self._configdict for the FrozenConfigDict() class. # It is defined here instead of in FrozenConfigDict.__init__() because we must # fill in an empty FrozenConfigDict but do not want to have an unexpected # signature for FrozenConfigDict.__init__() by passing it initial_configdict. object.__setattr__(seed, '_configdict', initial_configdict) visit_map = visit_map or {} visit_map[id(initial_configdict)] = seed for key, value in six.iteritems(initial_configdict): # This should never be raised due to elimination of references by # ConfigDict's iteritems if isinstance(value, FieldReference): raise ValueError('Trying to initialize a FrozenConfigDict value with ' 'a FieldReference. This should never happen, please ' 'file a bug.') if id(value) in visit_map: value = visit_map[id(value)] elif (isinstance(value, ConfigDict) and not isinstance(value, FrozenConfigDict)): value_frozenconfigdict = FrozenConfigDict(type_safe=value.is_type_safe) _frozenconfigdict_fill_seed(value_frozenconfigdict, value, visit_map) value = value_frozenconfigdict seed._frozen_setattr(key, value, # pylint:disable=protected-access visit_map) class FrozenConfigDict(ConfigDict): """Immutable and hashable type of ConfigDict. See ConfigDict() documentation above for details and usage. FrozenConfigDict is fully immutable. It contains no lists or sets (at initialization, lists and sets are converted to tuples and frozensets). The only potential sources of mutability are attributes with custom types, which are not touched. It is recommended to convert a ConfigDict to FrozenConfigDict after construction if possible. """ def __init__(self, initial_dictionary=None, type_safe=True): """Creates an instance of FrozenConfigDict. Lists and sets are copied into tuples and frozensets. 
However, copying is kept to a minimum so tuples, frozensets, and other immutable types are not copied unless they contain mutable types. Prohibited initial_dictionary structures: initial_dictionary may not contain any lists or tuples with dictionary, ConfigDict, or FieldReference elements, or else an error is raised at initialization. It also may not contain loops in the reference structure, i.e. the reference structure must be a Directed Acyclic Graph. This includes loops in list-element and tuple-element references. initial_dictionary's reference structure need not be a tree. Warning: Unexpected behavior may occur with types other than Python's built-in types. See ConfigDict() documentation for details. Warning: As with ConfigDict, FieldReference values may be changed. If initial_dictionary contains a FieldReference with a value of type dict or ConfigDict, that value will be converted to FrozenConfigDict. Args: initial_dictionary: May be one of the following: 1) dict. In this case all values of initial_dictionary that are dictionaries are also converted to FrozenConfigDict. If there are dictionaries contained in lists or tuples, an error is raised. 2) ConfigDict. In this case all ConfigDict attributes are also converted to FrozenConfigDict. 3) FrozenConfigDict. In this case all attributes are uncopied, and only the top-level object (self) is re-addressed. type_safe: See ConfigDict documentation. Note that this only matters if the FrozenConfigDict is converted to ConfigDict at some point. 
""" super(FrozenConfigDict, self).__init__() initial_configdict = ConfigDict(initial_dictionary=initial_dictionary, type_safe=type_safe) _frozenconfigdict_valid_input(initial_configdict) # This will define the self._configdict attribute _frozenconfigdict_fill_seed(self, initial_configdict) object.__setattr__(self, '_locked', initial_configdict.is_locked) object.__setattr__(self, '_type_safe', initial_configdict.is_type_safe) def _frozen_setattr(self, key, value, visit_map=None): """Sets attribute, analogous to __setattr__(). Args: key: Name of the attribute to set. value: Value of the attribute to set. visit_map: Dictionary from memory addresses to values, storing the FrozenConfigDict versions of value's elements. Lists which have been converted to tuples and sets to frozensets are also stored in visit_map. Returns: visit_map: Updated visit_map. Raises: ValueError: If there is a dot in key, or value contains dicts inside lists or tuples. Also if key is already an attribute, since redefining an attribute is prohibited for FrozenConfigDict. AttributeError: If key is protected (such as '_type_safe' and '_locked'). """ visit_map = visit_map or {} # These should always pass because of conversion to ConfigDict at # initialization self._ensure_mutability(key) assert '.' not in key if key in self._fields: raise ValueError('Cannot redefine attribute of FrozenConfigDict.') if isinstance(value, (list, tuple, set)): immutable_value, _, visit_map = _convert_to_immutable(value, visit_map) self._fields[key] = immutable_value else: self._fields[key] = value return visit_map def __setstate__(self, state): """Recreates FrozenConfigDict from its dict representation.""" self.__init__(state['_configdict']) def __setattr__(self, attribute, value): raise AttributeError('FrozenConfigDict is immutable. Cannot call ' '__setattr__().') def __delattr__(self, attribute): raise AttributeError('FrozenConfigDict is immutable. 
Cannot call ' '__delattr__().') def __setitem__(self, attribute, value): raise AttributeError('FrozenConfigDict is immutable. Cannot call ' '__setitem__().') def __delitem__(self, attribute): raise AttributeError('FrozenConfigDict is immutable. Cannot call ' '__delitem__().') def lock(self): raise AttributeError('FrozenConfigDict is immutable. Cannot call lock().') def unlock(self): raise AttributeError('FrozenConfigDict is immutable. Cannot call unlock().') def __hash__(self): """Computes hash. The hash depends not only on the immutable aspects of the FrozenConfigDict, but on the types of the initial_dictionary at initialization (i.e. on the _configdict attribute). For example, in the following, hash(frozen_1) will not equal hash(frozen_2): d_1 = {'a': (1, )} d_2 = {'a': [1]} frozen_1 = FrozenConfigDict(d_1) frozen_2 = FrozenConfigDict(d_2) Note: This implementation relies on the particulars of the FrozenConfigDict structure. For example, the fact that lists and tuples cannot contain dicts or ConfigDicts is crucial, as is the fact that any loop in the reference structure is prohibited. Note: Due to hash randomization, this hash will likely differ in different Python sessions. For comparisons across sessions, please directly use equality of the serialization. For more, see https://bugs.python.org/issue13703 Returns: frozenconfigdict_hash: The hash value. Raises: TypeError: self contains an unhashable type. 
""" def value_hash(value): """Hashes a single value.""" if isinstance(value, set): return hash((hash(frozenset(value)), 1)) elif isinstance(value, (list, tuple)): value_hash_list = [isinstance(value, list)] for element in value: element_hash = value_hash(element) value_hash_list.append(element_hash) return hash(tuple(value_hash_list)) elif isinstance(value, FieldReference): return value_hash(value.get()) else: try: return hash(value) except TypeError: raise TypeError('FrozenConfigDict contains unhashable type ' '{}'.format(type(value))) fields_hash = 0 for key, value in six.iteritems(self._fields): if isinstance(value, FrozenConfigDict): fields_hash += hash((hash(key), hash(value))) else: # Use self._configdict value to ensure attending to mutability fields_hash += hash((hash(key), value_hash(self._configdict._fields[key]))) frozenconfigdict_hash = hash((fields_hash, self.is_locked, self.is_type_safe)) return frozenconfigdict_hash def __eq__(self, other): """Override default Equals behavior. Like __hash__(), this pays attention to the type of initial_dictionary. See __hash__() documentation for details. Warning: This distinguishes FrozenConfigDict from ConfigDict. For example: cd = ConfigDict() fcd = FrozenConfigDict() fcd.__eq__(cd) # Returns False Args: other: Object to compare self to. Returns: same: Boolean self == other. """ if isinstance(other, FrozenConfigDict): return ConfigDict(self) == ConfigDict(other) else: return False def as_configdict(self): return self._configdict class CustomJSONEncoder(json.JSONEncoder): """JSON encoder for ConfigDict and FieldReference. The encoder throws an exception for non-supported types. """ def default(self, obj): if isinstance(obj, FieldReference): return obj.get() elif isinstance(obj, ConfigDict): return obj._fields else: raise TypeError('{} is not JSON serializable. 
Instead use ' 'ConfigDict.to_json_best_effort()'.format(type(obj))) class _BestEffortCustomJSONEncoder(CustomJSONEncoder): """Best effort JSON encoder for ConfigDict. The encoder tries to serialize non-supported types (doesn't throw exceptions). """ def default(self, obj): try: return super(_BestEffortCustomJSONEncoder, self).default(obj) except TypeError: if isinstance(obj, set): return sorted(list(obj)) elif inspect.isfunction(obj): return 'function {}'.format(obj.__name__) elif (hasattr(obj, '__dict__') and obj.__dict__ and not inspect.isclass(obj)): # Instantiated object's variables return dict(obj.__dict__) elif hasattr(obj, '__str__'): return 'unserializable object: {}'.format(obj) else: return 'unserializable object of type: {}'.format(type(obj)) def create(**kwargs): """Creates a `ConfigDict` with the given named arguments as key-value pairs. This allows for simple dictionaries whose elements can be accessed directly using field access: from ml_collections.config_dict import config_dict point = config_dict.create(x=1, y=2) print(point.x, point.y) This is particularly useful for compactly writing nested configurations: config = config_dict.create( data=config_dict.create( game='freeway', frame_size=100), model=config_dict.create(num_hidden=1000)) The reason for the existence of this function is that it simplifies the code required for the majority of the use cases of `ConfigDict`, compared to using either `ConfigDict` or `namedtuple`s. Examples of such use cases include training script configuration, and returning multiple named values. Args: **kwargs: key-value pairs to be stored in the `ConfigDict`. Returns: A `ConfigDict` containing the key-value pairs in `kwargs`. """ return ConfigDict(initial_dictionary=kwargs) # TODO(sergomez): make placeholders required by default. def placeholder(field_type, required=False): """Defines an entry in a ConfigDict that has no value yet. 
  Example:
    config = config_dict.create(
        batch_size=config_dict.placeholder(int),
        frame_shape=config_dict.placeholder(tf.TensorShape))

  Args:
    field_type: type of value.
    required: If True, the placeholder will raise an error on access if the
      underlying value hasn't been set.

  Returns:
    A `FieldReference` with value None and the given type.
  """
  return FieldReference(None, field_type=field_type, required=required)


def required_placeholder(field_type):
  """Defines an entry in a ConfigDict with unknown but required value.

  Example:
    config = config_dict.create(
        batch_size=config_dict.required_placeholder(int))

    try:
      print(config.batch_size)
    except RequiredValueError:
      pass

    config.batch_size = 10
    print(config.batch_size)  # 10

  Args:
    field_type: type of value.

  Returns:
    A `FieldReference` with value None and the given type.
  """
  return placeholder(field_type, required=True)


def recursive_rename(conf, old_name, new_name):
  """Returns copy of conf with old_name recursively replaced by new_name.

  This is not done in place: conf is left unchanged and a new ConfigDict with
  the changes is returned. This is useful if the name of a parameter has been
  changed in code but you need to load an old config.

  Example usage:
    updated_conf = config_dict.recursive_rename(conf, "config", "kwargs")

  Args:
    conf: a ConfigDict which needs updating
    old_name: the name used in the ConfigDict which is out of sync with the code
    new_name: the name used in the code

  Returns:
    A ConfigDict which is a copy of conf but with all instances of old_name
    replaced with new_name.
""" new_conf = ConfigDict() for name, c in conf.items(): if isinstance(c, ConfigDict): new_c = recursive_rename(c, old_name, new_name) else: new_c = c if name == old_name: setattr(new_conf, new_name, new_c) else: setattr(new_conf, name, new_c) return new_conf ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1598243892.235948 ml_collections-0.1.0/ml_collections/config_dict/examples/0000755162426002575230000000000000000000000026570 5ustar00mohitreddyprimarygroup00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/examples/config.py0000644162426002575230000001043300000000000030410 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of a config file using ConfigDict. The idea of this configuration file is to show a typical use case of ConfigDict, as well as its limitations. This also exemplifies a self-referencing ConfigDict. """ import copy import ml_collections def _get_flat_config(): """Helper to generate simple config without references.""" # The suggested way to create a ConfigDict() is to call its constructor # and assign all relevant fields. config = ml_collections.ConfigDict() # In order to add new attributes you can just use . notation, like with any # python object. 
They will be tracked by ConfigDict, and you get type checking # etc. for free. config.integer = 23 config.float = 2.34 config.string = 'james' config.bool = True # It is possible to assign dictionaries to ConfigDict and they will be # automatically and recursively wrapped with ConfigDict. However, make sure # that the dict you are assigning does not use internal references/cycles as # this is not supported. Instead, create the dicts explicitly as demonstrated # by get_config(). But note that this operation makes an element-by-element # copy of your original dict. # Also note that the recursive wrapping on input dictionaries with ConfigDict # does not extend through non-dictionary types (including basic Python types # and custom classes). This causes unexpected behavior most commonly if a # value is a list of dictionaries, so avoid giving ConfigDict such inputs. config.dict = { 'integer': 1, 'float': 3.14, 'string': 'mark', 'bool': False, 'dict': { 'float': 5 } } return config def get_config(): """Returns a ConfigDict instance describing a complex config. Returns: A ConfigDict instance with the structure: ``` CONFIG-+-- integer |-- float |-- string |-- bool |-- dict +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- float | |-- object +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- float | |-- object_copy +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- integer | |-- float | |-- string | |-- bool | |-- dict +-- float | |-- object_reference [reference pointing to CONFIG-+--object] ``` """ config = _get_flat_config() config.object = _get_flat_config() # References work just fine, so you will be able to override both # values at the same time. The rule is the same as for python objects, # everything that is mutable is passed as a reference, thus it will not work # with assigning integers or strings, but will work just fine with # ConfigDicts. 
# WARNING: Each time you assign a dictionary as a value it will create a new # instance of ConfigDict in memory, thus it will be a copy of the original # dict and not a reference to the original. config.object_reference = config.object # ConfigDict supports deepcopying. config.object_copy = copy.deepcopy(config.object) return config ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/examples/config_dict_advanced.py0000644162426002575230000000771500000000000033251 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of ConfigDict usage. This example includes loading a ConfigDict in FLAGS, locking it, type safety, iteration over fields, checking for a particular field, unpacking with `**`, and loading dictionary from string representation. """ from absl import app from absl import flags from ml_collections.config_flags import config_flags import yaml FLAGS = flags.FLAGS config_flags.DEFINE_config_file( 'my_config', default='ml_collections/config_dict/examples/config.py') def dummy_function(string, **unused_kwargs): return 'Hello {}'.format(string) def print_section(name): print() print() print('-' * len(name)) print(name.upper()) print('-' * len(name)) print() def main(_): # Config is already loaded in FLAGS.my_config due to the logic hidden # in app.run(). 
config = FLAGS.my_config print_section('Printing config.') print(FLAGS.my_config) # Config is of our type ConfigDict. print('Type of the config {}'.format(type(FLAGS.my_config))) # By default it is locked, thus you cannot add new fields. # This prevents you from misspelling your attribute name. print_section('Locking.') print('config.is_locked={}'.format(config.is_locked)) try: config.object.new_field = -3 except AttributeError as e: print(e) # There is also a "did you mean" feature! try: config.object.floet = -3. except AttributeError as e: print(e) # However if you want to modify it you can always unlock. print_section('Unlocking.') with config.unlocked(): config.object.new_field = -3 print('config.object.new_field={}'.format(config.object.new_field)) # By default config is also type-safe, so you cannot change the type of any # field. print_section('Type safety.') try: config.float = 'jerry' except TypeError as e: print(e) config.float = -1.2 print('config.float={}'.format(config.float)) # NoneType is ignored by type safety and can both override and be overridden. config.float = None config.float = -1.2 # You can temporarily turn type safety off. with config.ignore_type(): config.float = 'tom' print('config.float={}'.format(config.float)) config.float = 2.3 print('config.float={}'.format(config.float)) # You can use ConfigDict as a regular dict in many typical use-cases: # Iteration over fields: print_section('Iteration over fields.') for field in config: print('config has field "{}"'.format(field)) # Checking if it contains a particular field using the "in" operator. print_section('Checking for a particular field.') for field in ('float', 'non_existing'): if field in config: print('"{}" is in config'.format(field)) else: print('"{}" is not in config'.format(field)) # Using ** unrolling to pass the config to a function as named arguments.
print_section('Unpacking with **') print(dummy_function(**config)) # You can even load a dictionary (notice it is not ConfigDict anymore) from # a yaml string representation of ConfigDict. # Note: __repr__ (not __str__) is the recommended representation, as it # preserves FieldReferences and placeholders. print_section('Loading dictionary from string representation.') dictionary = yaml.load(repr(config)) print('dict["object_reference"]["dict"]["dict"]["float"]={}'.format( dictionary['object_reference']['dict']['dict']['float'])) if __name__ == '__main__': app.run(main) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/examples/config_dict_basic.py0000644162426002575230000000270500000000000032557 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of basic ConfigDict usage. This example shows the most basic usage of ConfigDict, including type safety. For examples of more features, see example_advanced. """ from absl import app import ml_collections def main(_): cfg = ml_collections.ConfigDict() cfg.float_field = 12.6 cfg.integer_field = 123 cfg.another_integer_field = 234 cfg.nested = ml_collections.ConfigDict() cfg.nested.string_field = 'tom' print(cfg.integer_field) # Prints 123. print(cfg['integer_field']) # Prints 123 as well. 
try: cfg.integer_field = 'tom' # Raises TypeError as this field is an integer. except TypeError as e: print(e) cfg.float_field = 12 # Works: `int` types can be assigned to `float`. cfg.nested.string_field = u'bob' # `String` fields can store Unicode strings. print(cfg) if __name__ == '__main__': app.run(main) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/examples/config_dict_initialization.py0000644162426002575230000000630300000000000034523 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of initialization features and gotchas in a ConfigDict. 
""" import copy from absl import app import ml_collections def print_section(name): print() print() print('-' * len(name)) print(name.upper()) print('-' * len(name)) print() def main(_): inner_dict = {'list': [1, 2], 'tuple': (1, 2, [3, 4], (5, 6))} example_dict = { 'string': 'tom', 'int': 2, 'list': [1, 2], 'set': {1, 2}, 'tuple': (1, 2), 'ref': ml_collections.FieldReference({'int': 0}), 'inner_dict_1': inner_dict, 'inner_dict_2': inner_dict } print_section('Initializing on dictionary.') # ConfigDict can be initialized on example_dict example_cd = ml_collections.ConfigDict(example_dict) # Dictionary fields are also converted to ConfigDict print(type(example_cd.inner_dict_1)) # And the reference structure is preserved print(id(example_cd.inner_dict_1) == id(example_cd.inner_dict_2)) print_section('Initializing on ConfigDict.') # ConfigDict can also be initialized on a ConfigDict example_cd_cd = ml_collections.ConfigDict(example_cd) # Yielding the same result: print(example_cd == example_cd_cd) # Note that the memory addresses are different print(id(example_cd) == id(example_cd_cd)) # The memory addresses of the attributes are not the same because of the # FieldReference, which gets removed on the second initialization list_to_ids = lambda x: [id(element) for element in x] print( set(list_to_ids(list(example_cd.values()))) == set( list_to_ids(list(example_cd_cd.values())))) print_section('Initializing on self-referencing dictionary.') # Initialization works on a self-referencing dict self_ref_dict = copy.deepcopy(example_dict) self_ref_dict['self'] = self_ref_dict self_ref_cd = ml_collections.ConfigDict(self_ref_dict) # And the reference structure is replicated print(id(self_ref_cd) == id(self_ref_cd.self)) print_section('Unexpected initialization behavior.') # ConfigDict initialization doesn't look inside lists, so doesn't convert a # dict in a list to ConfigDict dict_in_list_in_dict = {'list': [{'troublemaker': 0}]} dict_in_list_in_dict_cd = 
ml_collections.ConfigDict(dict_in_list_in_dict) print(type(dict_in_list_in_dict_cd.list[0])) # This can cause the reference structure to not be replicated referred_dict = {'key': 'value'} bad_reference = {'referred_dict': referred_dict, 'list': [referred_dict]} bad_reference_cd = ml_collections.ConfigDict(bad_reference) print(id(bad_reference_cd.referred_dict) == id(bad_reference_cd.list[0])) if __name__ == '__main__': app.run(main) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/examples/config_dict_lock.py0000644162426002575230000000245700000000000032432 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of ConfigDict usage of lock. This example shows the roles and scopes of ConfigDict's lock(). """ from absl import app import ml_collections def main(_): cfg = ml_collections.ConfigDict() cfg.integer_field = 123 # Locking prohibits the addition and deletion of new fields but allows # modification of existing values. Locking happens automatically during # loading through flags. cfg.lock() try: cfg.intagar_field = 124 # Raises AttributeError and suggests valid field. except AttributeError as e: print(e) cfg.integer_field = -123 # Works fine. with cfg.unlocked(): cfg.intagar_field = 1555 # Works fine too.
print(cfg) if __name__ == '__main__': app.run(main) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/examples/config_dict_placeholder.py0000644162426002575230000000335000000000000033755 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of placeholder fields in a ConfigDict. This example shows how ConfigDict placeholder fields work. For a more complete example of ConfigDict features, see config_dict_advanced. """ from absl import app import ml_collections def main(_): placeholder = ml_collections.FieldReference(0) cfg = ml_collections.ConfigDict() cfg.placeholder = placeholder cfg.optional = ml_collections.FieldReference(0, field_type=int) cfg.nested = ml_collections.ConfigDict() cfg.nested.placeholder = placeholder try: cfg.optional = 'tom' # Raises TypeError as this field is an integer. except TypeError as e: print(e) cfg.optional = 1555 # Works fine. cfg.placeholder = 1 # Changes the value of both placeholder and # nested.placeholder fields. # Note that the indirection provided by FieldReferences will be lost if # accessed through a ConfigDict: placeholder = ml_collections.FieldReference(0) cfg.field1 = placeholder cfg.field2 = placeholder # This field will be tied to cfg.field1. cfg.field3 = cfg.field1 # This will just be an int field initialized to 0.
print(cfg) if __name__ == '__main__': app.run(main) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/examples/examples_test.py0000644162426002575230000000340200000000000032016 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Tests for ConfigDict examples. Ensures that config_dict_basic, config_dict_initialization, config_dict_lock, config_dict_placeholder, field_reference, frozen_config_dict run successfully.
""" from absl.testing import absltest from absl.testing import parameterized from ml_collections.config_dict.examples import config_dict_advanced from ml_collections.config_dict.examples import config_dict_basic from ml_collections.config_dict.examples import config_dict_initialization from ml_collections.config_dict.examples import config_dict_lock from ml_collections.config_dict.examples import config_dict_placeholder from ml_collections.config_dict.examples import field_reference from ml_collections.config_dict.examples import frozen_config_dict class ConfigDictExamplesTest(parameterized.TestCase): @parameterized.parameters(config_dict_advanced, config_dict_basic, config_dict_initialization, config_dict_lock, config_dict_placeholder, field_reference, frozen_config_dict) def testScriptRuns(self, example_name): example_name.main(None) if __name__ == '__main__': absltest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/examples/field_reference.py0000644162426002575230000001132400000000000032244 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of FieldReference usage. This shows how to use FieldReferences for lazy computation. 
""" from absl import app import ml_collections from ml_collections.config_dict import config_dict def lazy_computation(): """Simple example of lazy computation with `configdict.FieldReference`.""" ref = ml_collections.FieldReference(1) print(ref.get()) # Prints 1 add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 11 because ref's value is 1 # Addition is lazily computed for FieldReferences so changing ref will change # the value that is used to compute add_ten. ref.set(5) print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 15 because ref's value is 5 def change_lazy_computation(): """Overriding lazily computed values.""" config = ml_collections.ConfigDict() config.reference = 1 config.reference_0 = config.get_ref('reference') + 10 config.reference_1 = config.get_ref('reference') + 20 config.reference_1_0 = config.get_ref('reference_1') + 100 print(config.reference) # Prints 1. print(config.reference_0) # Prints 11. print(config.reference_1) # Prints 21. print(config.reference_1_0) # Prints 121. config.reference_1 = 30 print(config.reference) # Prints 1 (unchanged). print(config.reference_0) # Prints 11 (unchanged). print(config.reference_1) # Prints 30. print(config.reference_1_0) # Prints 130. def create_cycle(): """Creates a cycle within a ConfigDict.""" config = ml_collections.ConfigDict() config.integer_field = 1 config.bigger_integer_field = config.get_ref('integer_field') + 10 try: # Raises a MutabilityError because setting config.integer_field would # cause a cycle. 
config.integer_field = config.get_ref('bigger_integer_field') + 2 except config_dict.MutabilityError as e: print(e) def lazy_configdict(): """Example usage of lazy computation with ConfigDict.""" config = ml_collections.ConfigDict() config.reference_field = ml_collections.FieldReference(1) config.integer_field = 2 config.float_field = 2.5 # No lazy evaluations because we didn't use get_ref() config.no_lazy = config.integer_field * config.float_field # This will lazily evaluate ONLY config.integer_field config.lazy_integer = config.get_ref('integer_field') * config.float_field # This will lazily evaluate ONLY config.float_field config.lazy_float = config.integer_field * config.get_ref('float_field') # This will lazily evaluate BOTH config.integer_field and config.float_field config.lazy_both = (config.get_ref('integer_field') * config.get_ref('float_field')) config.integer_field = 3 print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value print(config.lazy_integer) # Prints 7.5 config.float_field = 3.5 print(config.lazy_float) # Prints 7.0 print(config.lazy_both) # Prints 10.5 def lazy_configdict_advanced(): """Advanced lazy computation with ConfigDict.""" # FieldReferences can be used with ConfigDict as well config = ml_collections.ConfigDict() config.float_field = 12.6 config.integer_field = 123 config.list_field = [0, 1, 2] config.float_multiply_field = config.get_ref('float_field') * 3 print(config.float_multiply_field) # Prints 37.8 config.float_field = 10.0 print(config.float_multiply_field) # Prints 30.0 config.longer_list_field = config.get_ref('list_field') + [3, 4, 5] print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5] config.list_field = [-1] print(config.longer_list_field) # Prints [-1, 3, 4, 5] # Both operands can be references config.ref_subtraction = ( config.get_ref('float_field') - config.get_ref('integer_field')) print(config.ref_subtraction) # Prints -113.0 config.integer_field = 10 print(config.ref_subtraction) #
Prints 0.0 def main(argv=()): del argv # Unused. lazy_computation() lazy_configdict() change_lazy_computation() create_cycle() lazy_configdict_advanced() if __name__ == '__main__': app.run(main) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/examples/frozen_config_dict.py0000644162426002575230000000576700000000000033014 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Example of basic FrozenConfigDict usage. This example shows the most basic usage of FrozenConfigDict, highlighting the differences between FrozenConfigDict and ConfigDict and including converting between the two.
""" from absl import app import ml_collections def print_section(name): print() print() print('-' * len(name)) print(name.upper()) print('-' * len(name)) print() def main(_): print_section('Attribute Types.') cfg = ml_collections.ConfigDict() cfg.int = 1 cfg.list = [1, 2, 3] cfg.tuple = (1, 2, 3) cfg.set = {1, 2, 3} cfg.frozenset = frozenset({1, 2, 3}) cfg.dict = { 'nested_int': 4, 'nested_list': [4, 5, 6], 'nested_tuple': ([4], 5, 6), } print('Types of cfg fields:') print('list: ', type(cfg.list)) # List print('set: ', type(cfg.set)) # Set print('nested_list: ', type(cfg.dict.nested_list)) # List print('nested_tuple[0]: ', type(cfg.dict.nested_tuple[0])) # List frozen_cfg = ml_collections.FrozenConfigDict(cfg) print('\nTypes of FrozenConfigDict(cfg) fields:') print('list: ', type(frozen_cfg.list)) # Tuple print('set: ', type(frozen_cfg.set)) # Frozenset print('nested_list: ', type(frozen_cfg.dict.nested_list)) # Tuple print('nested_tuple[0]: ', type(frozen_cfg.dict.nested_tuple[0])) # Tuple cfg_from_frozen = ml_collections.ConfigDict(frozen_cfg) print('\nTypes of ConfigDict(FrozenConfigDict(cfg)) fields:') print('list: ', type(cfg_from_frozen.list)) # List print('set: ', type(cfg_from_frozen.set)) # Set print('nested_list: ', type(cfg_from_frozen.dict.nested_list)) # List print('nested_tuple[0]: ', type(cfg_from_frozen.dict.nested_tuple[0])) # List print('\nCan use FrozenConfigDict.as_configdict() to convert to ConfigDict:') print(cfg_from_frozen == frozen_cfg.as_configdict()) # True print_section('Immutability.') try: frozen_cfg.new_field = 1 # Raises AttributeError because of immutability. 
except AttributeError as e: print(e) print_section('"==" and eq_as_configdict().') # FrozenConfigDict.__eq__() is not type-invariant with respect to ConfigDict print(frozen_cfg == cfg) # False # FrozenConfigDict.eq_as_configdict() is type-invariant with respect to # ConfigDict print(frozen_cfg.eq_as_configdict(cfg)) # True # .eq_as_configdict() is also a method of ConfigDict print(cfg.eq_as_configdict(frozen_cfg)) # True if __name__ == '__main__': app.run(main) ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1598243892.235948 ml_collections-0.1.0/ml_collections/config_dict/tests/0000755162426002575230000000000000000000000026114 5ustar00mohitreddyprimarygroup00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_dict/tests/config_dict_test.py0000644162426002575230000013431000000000000031777 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# Lint as: python3 """Tests for ml_collections.ConfigDict.""" import abc import collections as python_collections import json import pickle import sys from absl.testing import absltest from absl.testing import parameterized import ml_collections from ml_collections.config_dict import config_dict import mock import six import yaml _TEST_FIELD = {'int': 0} _TEST_DICT = { 'float': 2.34, 'string': 'tom', 'int': 2, 'list': [1, 2], 'dict': { 'float': -1.23, 'int': 23 }, } def _test_function(): pass # Having ABCMeta as a metaclass shouldn't break yaml serialization. class _TestClass(six.with_metaclass(abc.ABCMeta, object)): def __init__(self): self.variable_1 = 1 self.variable_2 = '2' _test_object = _TestClass() class _TestClassNoStr(): # pylint: disable=old-style-class pass _TEST_DICT_BEST_EFFORT = dict(_TEST_DICT) _TEST_DICT_BEST_EFFORT.update({ 'unserializable': _TestClass, 'unserializable_no_str': _TestClassNoStr, 'function': _test_function, 'object': _test_object, 'set': {1, 2, 3} }) # This is how we expect the _TEST_DICT to look after we change the name float to # double using the function configdict.recursive_rename _TEST_DICT_CHANGE_FLOAT_NAME = { 'double': 2.34, 'string': 'tom', 'int': 2, 'list': [1, 2], 'dict': { 'double': -1.23, 'int': 23 }, } def _get_test_dict(): test_dict = dict(_TEST_DICT) field = ml_collections.FieldReference(_TEST_FIELD) test_dict['ref'] = field test_dict['ref2'] = field return test_dict def _get_test_dict_best_effort(): test_dict = dict(_TEST_DICT_BEST_EFFORT) field = ml_collections.FieldReference(_TEST_FIELD) test_dict['ref'] = field test_dict['ref2'] = field return test_dict def _get_test_config_dict(): return ml_collections.ConfigDict(_get_test_dict()) def _get_test_config_dict_best_effort(): return ml_collections.ConfigDict(_get_test_dict_best_effort()) _JSON_TEST_DICT = ('{"dict": {"float": -1.23, "int": 23},' ' "float": 2.34,' ' "int": 2,' ' "list": [1, 2],' ' "ref": {"int": 0},' ' "ref2": {"int": 0},' ' "string": "tom"}') if 
six.PY2: _DICT_TYPE = "!!python/name:__builtin__.dict ''" _UNSERIALIZABLE_MSG = "unserializable object of type: " else: _DICT_TYPE = "!!python/name:builtins.dict ''" _UNSERIALIZABLE_MSG = ( "unserializable object: ") _TYPES = { 'dict_type': _DICT_TYPE, 'configdict_type': '!!python/object:ml_collections.config_dict.config_dict' '.ConfigDict', 'fieldreference_type': '!!python/object:ml_collections.config_dict' '.config_dict.FieldReference' } _JSON_BEST_EFFORT_TEST_DICT = ( '{"dict": {"float": -1.23, "int": 23},' ' "float": 2.34,' ' "function": "function _test_function",' ' "int": 2,' ' "list": [1, 2],' ' "object": {"variable_1": 1, "variable_2": "2"},' ' "ref": {"int": 0},' ' "ref2": {"int": 0},' ' "set": [1, 2, 3],' ' "string": "tom",' ' "unserializable": "unserializable object: ' '",' ' "unserializable_no_str": "%s"}') % _UNSERIALIZABLE_MSG _REPR_TEST_DICT = """ dict: float: -1.23 int: 23 float: 2.34 int: 2 list: - 1 - 2 ref: &id001 {fieldreference_type} _field_type: {dict_type} _ops: [] _required: false _value: int: 0 ref2: *id001 string: tom """.format(**_TYPES) _STR_TEST_DICT = """ dict: {float: -1.23, int: 23} float: 2.34 int: 2 list: [1, 2] ref: &id001 {int: 0} ref2: *id001 string: tom """ _STR_NESTED_TEST_DICT = """ dict: {float: -1.23, int: 23} float: 2.34 int: 2 list: [1, 2] nested_dict: float: -1.23 int: 23 nested_dict: float: -1.23 int: 23 non_nested_dict: {float: -1.23, int: 23} nested_list: - 1 - 2 - [3, 4, 5] - 6 ref: &id001 {int: 0} ref2: *id001 string: tom """ class ConfigDictTest(parameterized.TestCase): """Tests ConfigDict in config flags library.""" def assertEqualConfigs(self, cfg, dictionary): """Asserts recursive equality of config and a dictionary.""" self.assertEqual(cfg.to_dict(), dictionary) def testCreating(self): """Tests basic config creation.""" cfg = ml_collections.ConfigDict() cfg.field = 2.34 self.assertEqual(cfg.field, 2.34) def testDir(self): """Test that dir() works correctly on config.""" cfg = ml_collections.ConfigDict() 
cfg.field = 2.34 self.assertIn('field', dir(cfg)) self.assertIn('lock', dir(cfg)) def testFromDictConstruction(self): """Tests creation of config from existing dictionary.""" cfg = ml_collections.ConfigDict(_TEST_DICT) self.assertEqualConfigs(cfg, _TEST_DICT) def testOverridingValues(self): """Tests basic values overriding.""" cfg = ml_collections.ConfigDict() cfg.field = 2.34 self.assertEqual(cfg.field, 2.34) cfg.field = -2.34 self.assertEqual(cfg.field, -2.34) def testDictAttributeTurnsIntoConfigDict(self): """Tests that dicts in a ConfigDict turn to ConfigDicts (recursively).""" cfg = ml_collections.ConfigDict(_TEST_DICT) # Test conversion to dict on creation. self.assertIsInstance(cfg.dict, ml_collections.ConfigDict) # Test conversion to dict on setting attribute. new_dict = {'inside_dict': {'inside_key': 0}} cfg.new_dict = new_dict self.assertIsInstance(cfg.new_dict, ml_collections.ConfigDict) self.assertIsInstance(cfg.new_dict.inside_dict, ml_collections.ConfigDict) self.assertEqual(cfg.new_dict.to_dict(), new_dict) def testOverrideExceptions(self): """Test the `int` and unicode-string exceptions to overriding. ConfigDict forces strong type-checking with two exceptions. The first is that `int` values can be stored to fields of type `float`. And secondly, all string types can be stored in fields of type `str` or `unicode`. """ cfg = ml_collections.ConfigDict() # Test that overriding 'float' fields with int works. cfg.float_field = 2.34 cfg.float_field = 2 self.assertEqual(cfg.float_field, 2.0) # Test that overriding with Unicode strings works. cfg.string_field = '42' cfg.string_field = u'42' self.assertEqual(cfg.string_field, '42') # Test that overriding a Unicode field with a `str` type works. cfg.unicode_string_field = u'42' cfg.unicode_string_field = '42' self.assertEqual(cfg.unicode_string_field, u'42') # Test that overriding a list with a tuple works. 
cfg.tuple_field = [1, 2, 3] cfg.tuple_field = (1, 2) self.assertEqual(cfg.tuple_field, [1, 2]) # Test that overriding a tuple with a list works. cfg.list_field = [23, 42] cfg.list_field = (8, 9, 10) self.assertEqual(cfg.list_field, [8, 9, 10]) # Test that int <-> long conversions work. int_value = 1 # In Python 2, int(very large number) returns a long long_value = int(1e100) cfg.int_field = int_value cfg.int_field = long_value self.assertEqual(cfg.int_field, long_value) if sys.version_info.major == 2: expected = long else: expected = int self.assertIsInstance(cfg.int_field, expected) cfg.long_field = long_value cfg.long_field = int_value self.assertEqual(cfg.long_field, int_value) self.assertIsInstance(cfg.long_field, expected) def testOverrideFieldReference(self): """Test overriding with FieldReference objects.""" cfg = ml_collections.ConfigDict() cfg.field_1 = 'field_1' cfg.field_2 = 'field_2' # Override using a FieldReference. cfg.field_1 = ml_collections.FieldReference('override_1') # Override FieldReference field using another FieldReference. cfg.field_1 = ml_collections.FieldReference('override_2') # Override using empty FieldReference. cfg.field_2 = ml_collections.FieldReference(None, field_type=str) # Override FieldReference field using string. cfg.field_2 = 'field_2' # Check a TypeError is raised when using FieldReference's with wrong type. with self.assertRaises(TypeError): cfg.field_2 = ml_collections.FieldReference(1) with self.assertRaises(TypeError): cfg.field_2 = ml_collections.FieldReference(None, field_type=int) def testTypeSafe(self): """Tests type safe checking.""" cfg = _get_test_config_dict() with self.assertRaisesRegex(TypeError, 'field \'float\''): cfg.float = 'tom' # Test that float cannot be assigned to int. 
with self.assertRaisesRegex(TypeError, 'field \'int\''): cfg.int = 12.8 with self.assertRaisesRegex(TypeError, 'field \'string\''): cfg.string = -123 with self.assertRaisesRegex(TypeError, 'field \'float\''): cfg.dict.float = 'string' # Ensure None is ignored by type safety cfg.string = None cfg.string = 'tom' def testIgnoreType(self): cfg = ml_collections.ConfigDict({ 'string': 'This is a string', 'float': 3.0, 'list': [ml_collections.ConfigDict({'float': 1.0})], 'tuple': [ml_collections.ConfigDict({'float': 1.0})], 'dict': { 'float': 1.0 } }) with cfg.ignore_type(): cfg.string = -123 cfg.float = 'string' cfg.list[0].float = 'string' cfg.tuple[0].float = 'string' cfg.dict.float = 'string' def testTypeUnsafe(self): """Tests lack of type safe checking.""" cfg = ml_collections.ConfigDict(_get_test_dict(), type_safe=False) cfg.float = 'tom' cfg.string = -123 cfg.int = 12.8 def testLocking(self): """Tests lock mechanism.""" cfg = ml_collections.ConfigDict() cfg.field = 2 cfg.dict_field = {'float': 1.23, 'integer': 3} cfg.ref = ml_collections.FieldReference( ml_collections.ConfigDict({'integer': 0})) cfg.lock() cfg.field = -4 with self.assertRaises(AttributeError): cfg.new_field = 2 with self.assertRaises(AttributeError): cfg.dict_field.new_field = -1.23 with self.assertRaises(AttributeError): cfg.ref.new_field = 1 with self.assertRaises(AttributeError): del cfg.field def testUnlocking(self): """Tests unlock mechanism.""" cfg = ml_collections.ConfigDict() cfg.dict_field = {'float': 1.23, 'integer': 3} cfg.ref = ml_collections.FieldReference( ml_collections.ConfigDict({'integer': 0})) cfg.lock() with cfg.unlocked(): cfg.new_field = 2 cfg.dict_field.new_field = -1.23 cfg.ref.new_field = 1 def testGetMethod(self): """Tests get method.""" cfg = _get_test_config_dict() self.assertEqual(cfg.get('float', -1), cfg.float) self.assertEqual(cfg.get('ref', -1), cfg.ref) self.assertEqual(cfg.get('another_key', -1), -1) self.assertIsNone(cfg.get('another_key')) def 
testItemsMethod(self):
    """Tests items method."""
    cfg = _get_test_config_dict()
    self.assertEqual(dict(**cfg), dict(cfg.items()))
    items = cfg.items()
    self.assertEqual(len(items), len(_get_test_dict()))

    for entry in _TEST_DICT.items():
      if isinstance(entry[1], dict):
        entry = (entry[0], ml_collections.ConfigDict(entry[1]))
      self.assertIn(entry, items)

    self.assertIn(('ref', cfg.ref), items)
    self.assertIn(('ref2', cfg.ref2), items)
    ind_ref = items.index(('ref', cfg.ref))
    ind_ref2 = items.index(('ref2', cfg.ref2))
    self.assertIs(items[ind_ref][1], items[ind_ref2][1])

    cfg = ml_collections.ConfigDict()
    self.assertEqual(dict(**cfg), dict(cfg.items()))

    # Test that items are sorted
    self.assertEqual(sorted(dict(**cfg).items()), cfg.items())

  def testGetItemRecursively(self):
    """Tests getting items recursively (e.g., config['a.b'])."""
    cfg = _get_test_config_dict()
    self.assertEqual(cfg['dict.float'], -1.23)
    self.assertEqual('%(dict.int)i' % cfg, '23')

  def testIterItemsMethod(self):
    """Tests iteritems method."""
    cfg = _get_test_config_dict()
    self.assertEqual(dict(**cfg), dict(cfg.iteritems()))
    cfg = ml_collections.ConfigDict()
    self.assertEqual(dict(**cfg), dict(cfg.iteritems()))

  def testIterKeysMethod(self):
    """Tests iterkeys method."""
    some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'}
    cfg = ml_collections.ConfigDict(some_dict)
    self.assertEqual(set(six.iterkeys(some_dict)), set(six.iterkeys(cfg)))
    # Test that keys are sorted
    for k_ref, k in zip(sorted(six.iterkeys(cfg)), six.iterkeys(cfg)):
      self.assertEqual(k_ref, k)

  def testKeysMethod(self):
    """Tests keys method."""
    some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'}
    cfg = ml_collections.ConfigDict(some_dict)
    self.assertEqual(set(some_dict.keys()), set(cfg.keys()))
    # Test that keys are sorted
    for k_ref, k in zip(sorted(cfg.keys()), cfg.keys()):
      self.assertEqual(k_ref, k)

  def testLenMethod(self):
    """Tests len method."""
    some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'}
    cfg = ml_collections.ConfigDict(some_dict)
    self.assertLen(cfg,
len(some_dict)) def testIterValuesMethod(self): """Tests itervalues method.""" some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'} cfg = ml_collections.ConfigDict(some_dict) self.assertEqual(set(six.itervalues(some_dict)), set(six.itervalues(cfg))) # Test that items are sorted for k_ref, v in zip(sorted(six.iterkeys(cfg)), six.itervalues(cfg)): self.assertEqual(cfg[k_ref], v) def testValuesMethod(self): """Tests values method.""" some_dict = {'x1': 32, 'x2': 5.2, 'x3': 'str'} cfg = ml_collections.ConfigDict(some_dict) self.assertEqual(set(some_dict.values()), set(cfg.values())) # Test that items are sorted for k_ref, v in zip(sorted(cfg.keys()), cfg.values()): self.assertEqual(cfg[k_ref], v) def testIterValuesResolvesReferences(self): """Tests itervalues FieldReference resolution.""" cfg = ml_collections.ConfigDict({'x1': 32, 'x2': 5.2, 'x3': 'str'}) ref = ml_collections.FieldReference(0) cfg['x4'] = ref for v in cfg.itervalues(): self.assertNotIsInstance(v, ml_collections.FieldReference) self.assertIn(ref, cfg.itervalues(preserve_field_references=True)) def testValuesResolvesReferences(self): """Tests values FieldReference resolution.""" cfg = ml_collections.ConfigDict({'x1': 32, 'x2': 5.2, 'x3': 'str'}) ref = ml_collections.FieldReference(0) cfg['x4'] = ref for v in cfg.values(): self.assertNotIsInstance(v, ml_collections.FieldReference) self.assertIn(ref, cfg.values(preserve_field_references=True)) def testIterItemsResolvesReferences(self): """Tests iteritems FieldReference resolution.""" cfg = ml_collections.ConfigDict({'x1': 32, 'x2': 5.2, 'x3': 'str'}) ref = ml_collections.FieldReference(0) cfg['x4'] = ref for _, v in cfg.iteritems(): self.assertNotIsInstance(v, ml_collections.FieldReference) self.assertIn(('x4', ref), cfg.iteritems(preserve_field_references=True)) def testItemsResolvesReferences(self): """Tests items FieldReference resolution.""" cfg = ml_collections.ConfigDict({'x1': 32, 'x2': 5.2, 'x3': 'str'}) ref = ml_collections.FieldReference(0) cfg['x4'] = 
ref for _, v in cfg.items(): self.assertNotIsInstance(v, ml_collections.FieldReference) self.assertIn(('x4', ref), cfg.items(preserve_field_references=True)) def testEquals(self): """Tests __eq__ and __ne__ methods.""" some_dict = { 'float': 1.23, 'integer': 3, 'list': [1, 2], 'dict': { 'a': {}, 'b': 'string' } } cfg = ml_collections.ConfigDict(some_dict) cfg_other = ml_collections.ConfigDict(some_dict) self.assertEqual(cfg, cfg_other) self.assertEqual(ml_collections.ConfigDict(cfg), cfg_other) cfg_other.float = 3 self.assertNotEqual(cfg, cfg_other) cfg_other.float = cfg.float cfg_other.list = ['a', 'b'] self.assertNotEqual(cfg, cfg_other) cfg_other.list = cfg.list cfg_other.lock() self.assertNotEqual(cfg, cfg_other) cfg_other.unlock() cfg_other = ml_collections.ConfigDict(some_dict, type_safe=False) self.assertNotEqual(cfg, cfg_other) cfg = ml_collections.ConfigDict(some_dict) # References that have the same id should be equal (even if self-references) cfg_other = ml_collections.ConfigDict(some_dict) cfg_other.me = cfg cfg.me = cfg self.assertEqual(cfg, cfg_other) cfg = ml_collections.ConfigDict(some_dict) cfg.me = cfg self.assertEqual(cfg, cfg) # Self-references that do not have the same id loop infinitely cfg_other = ml_collections.ConfigDict(some_dict) cfg_other.me = cfg_other # Temporarily disable coverage trace while testing runtime is exceeded trace_func = sys.gettrace() sys.settrace(None) with self.assertRaises(RuntimeError): cfg == cfg_other # pylint:disable=pointless-statement sys.settrace(trace_func) def testEqAsConfigDict(self): """Tests .eq_as_configdict() method.""" cfg_1 = _get_test_config_dict() cfg_2 = _get_test_config_dict() cfg_2.added_field = 3.14159 cfg_self_ref = _get_test_config_dict() cfg_self_ref.self_ref = cfg_self_ref frozen_cfg_1 = ml_collections.FrozenConfigDict(cfg_1) frozen_cfg_2 = ml_collections.FrozenConfigDict(cfg_2) self.assertTrue(cfg_1.eq_as_configdict(cfg_1)) self.assertTrue(cfg_1.eq_as_configdict(frozen_cfg_1)) 
self.assertTrue(frozen_cfg_1.eq_as_configdict(cfg_1)) self.assertTrue(frozen_cfg_1.eq_as_configdict(frozen_cfg_1)) self.assertFalse(cfg_1.eq_as_configdict(cfg_2)) self.assertFalse(cfg_1.eq_as_configdict(frozen_cfg_2)) self.assertFalse(frozen_cfg_1.eq_as_configdict(cfg_self_ref)) self.assertFalse(frozen_cfg_1.eq_as_configdict(frozen_cfg_2)) self.assertFalse(cfg_self_ref.eq_as_configdict(cfg_1)) def testHash(self): some_dict = {'float': 1.23, 'integer': 3} cfg = ml_collections.ConfigDict(some_dict) with self.assertRaisesRegex(TypeError, 'unhashable type'): hash(cfg) # Ensure Python realizes ConfigDict is not hashable. self.assertNotIsInstance(cfg, python_collections.Hashable) def testDidYouMeanFeature(self): """Tests 'did you mean' suggestions.""" cfg = ml_collections.ConfigDict() cfg.learning_rate = 0.01 cfg.lock() with self.assertRaisesRegex(AttributeError, 'Did you mean.*learning_rate.*'): _ = cfg.laerning_rate with cfg.unlocked(): with self.assertRaisesRegex(AttributeError, 'Did you mean.*learning_rate.*'): del cfg.laerning_rate with self.assertRaisesRegex(AttributeError, 'Did you mean.*learning_rate.*'): cfg.laerning_rate = 0.02 self.assertEqual(cfg.learning_rate, 0.01) with self.assertRaises(AttributeError): _ = self.laerning_rate def testReferences(self): """Tests assigning references in the dict.""" cfg = _get_test_config_dict() cfg.dict_ref = cfg.dict self.assertEqual(cfg.dict_ref, cfg.dict) def testPreserveReferences(self): """Tests that initializing with another ConfigDict preserves references.""" cfg = _get_test_config_dict() # In the original, "ref" and "ref2" are the same FieldReference self.assertIs(cfg.get_ref('ref'), cfg.get_ref('ref2')) # Create a copy from the original cfg2 = ml_collections.ConfigDict(cfg) # If the refs had not been preserved, get_ref would create a new # reference for each call self.assertIs(cfg2.get_ref('ref'), cfg2.get_ref('ref2')) self.assertIs(cfg2.ref, cfg2.ref2) # the values are also the same object def testUnpacking(self): 
"""Tests ability to pass ConfigDict instance with ** operator.""" cfg = ml_collections.ConfigDict() cfg.x = 2 def foo(x): return x + 3 self.assertEqual(foo(**cfg), 5) def testUnpackingWithFieldReference(self): """Tests ability to pass ConfigDict instance with ** operator.""" cfg = ml_collections.ConfigDict() cfg.x = ml_collections.FieldReference(2) def foo(x): return x + 3 self.assertEqual(foo(**cfg), 5) def testReadingIncorrectField(self): """Tests whether accessing non-existing fields raises an exception.""" cfg = ml_collections.ConfigDict() with self.assertRaises(AttributeError): _ = cfg.non_existing_field with self.assertRaises(KeyError): _ = cfg['non_existing_field'] def testIteration(self): """Tests whether one can iterate over ConfigDict.""" cfg = ml_collections.ConfigDict() for i in range(10): cfg['field{}'.format(i)] = 'field{}'.format(i) for field in cfg: self.assertEqual(cfg[field], getattr(cfg, field)) def testDeYaml(self): """Tests YAML deserialization.""" cfg = _get_test_config_dict() deyamled = yaml.load(cfg.to_yaml()) self.assertEqual(cfg, deyamled) def testJSONConversion(self): """Tests JSON serialization.""" cfg = _get_test_config_dict() self.assertEqual( cfg.to_json(sort_keys=True).strip(), _JSON_TEST_DICT.strip()) cfg = _get_test_config_dict_best_effort() with self.assertRaises(TypeError): cfg.to_json() def testJSONConversionCustomEncoder(self): """Tests JSON serialization with custom encoder.""" cfg = _get_test_config_dict() encoder = json.JSONEncoder() mock_encoder_cls = mock.MagicMock() mock_encoder_cls.return_value = encoder with mock.patch.object(encoder, 'default') as mock_default: mock_default.return_value = '' cfg.to_json(json_encoder_cls=mock_encoder_cls) mock_default.assert_called() def testJSONConversionBestEffort(self): """Tests JSON serialization.""" # Check that best effort option doesn't break default functionality cfg = _get_test_config_dict() self.assertEqual( cfg.to_json_best_effort(sort_keys=True).strip(), 
_JSON_TEST_DICT.strip())

    cfg_best_effort = _get_test_config_dict_best_effort()
    self.assertEqual(
        cfg_best_effort.to_json_best_effort(sort_keys=True).strip(),
        _JSON_BEST_EFFORT_TEST_DICT.strip())

  def testReprConversion(self):
    """Tests repr conversion."""
    cfg = _get_test_config_dict()
    self.assertEqual(repr(cfg).strip(), _REPR_TEST_DICT.strip())

  def testLoadFromRepr(self):
    cfg_dict = ml_collections.ConfigDict()
    field = ml_collections.FieldReference(1)
    cfg_dict.r1 = field
    cfg_dict.r2 = field

    cfg_load = yaml.load(repr(cfg_dict))

    # Test FieldReferences are preserved
    cfg_load['r1'].set(2)
    self.assertEqual(cfg_load['r1'].get(), cfg_load['r2'].get())

  def testStrConversion(self):
    """Tests conversion to str."""
    cfg = _get_test_config_dict()
    # Verify str(cfg) doesn't raise errors.
    _ = str(cfg)

    test_dict_2 = _get_test_dict()
    test_dict_2['nested_dict'] = {
        'float': -1.23,
        'int': 23,
        'nested_dict': {
            'float': -1.23,
            'int': 23,
            'non_nested_dict': {
                'float': -1.23,
                'int': 233,
            },
        },
        'nested_list': [1, 2, [3, 44, 5], 6],
    }
    cfg_2 = ml_collections.ConfigDict(test_dict_2)
    # Demonstrate that dot-access works.
    cfg_2.nested_dict.nested_dict.non_nested_dict.int = 23
    cfg_2.nested_dict.nested_list[2][1] = 4
    # Verify str(cfg_2) doesn't raise errors.
    _ = str(cfg_2)

  def testDotInField(self):
    """Tests that creating a field whose name contains a dot raises an error."""
    cfg = ml_collections.ConfigDict()

    with self.assertRaises(ValueError):
      cfg['invalid.name'] = 2.3

  def testToDictConversion(self):
    """Tests whether references are correctly handled when calling to_dict."""
    cfg = ml_collections.ConfigDict()
    field = ml_collections.FieldReference('a string')
    cfg.dict = {
        'float': 2.3,
        'integer': 1,
        'field_ref1': field,
        'field_ref2': field
    }
    cfg.ref = cfg.dict
    cfg.self_ref = cfg

    pure_dict = cfg.to_dict()
    self.assertEqual(type(pure_dict), dict)

    self.assertIs(pure_dict, pure_dict['self_ref'])
    self.assertIs(pure_dict['dict'], pure_dict['ref'])

    # Ensure ConfigDict has been converted to dict.
self.assertEqual(type(pure_dict['dict']), dict) # Ensure FieldReferences are not preserved, by default. self.assertNotIsInstance(pure_dict['dict']['field_ref1'], ml_collections.FieldReference) self.assertNotIsInstance(pure_dict['dict']['field_ref2'], ml_collections.FieldReference) self.assertEqual(pure_dict['dict']['field_ref1'], field.get()) self.assertEqual(pure_dict['dict']['field_ref2'], field.get()) pure_dict_with_refs = cfg.to_dict(preserve_field_references=True) self.assertEqual(type(pure_dict_with_refs), dict) self.assertEqual(type(pure_dict_with_refs['dict']), dict) self.assertIsInstance(pure_dict_with_refs['dict']['field_ref1'], ml_collections.FieldReference) self.assertIsInstance(pure_dict_with_refs['dict']['field_ref2'], ml_collections.FieldReference) self.assertIs(pure_dict_with_refs['dict']['field_ref1'], pure_dict_with_refs['dict']['field_ref2']) # Ensure FieldReferences in the dict are not the same as the FieldReferences # in the original ConfigDict. self.assertIsNot(pure_dict_with_refs['dict']['field_ref1'], cfg.dict['field_ref1']) def testToDictTypeUnsafe(self): """Tests interaction between ignore_type() and to_dict().""" cfg = ml_collections.ConfigDict() cfg.string = ml_collections.FieldReference(None, field_type=str) with cfg.ignore_type(): cfg.string = 1 self.assertEqual(1, cfg.to_dict(preserve_field_references=True)['string']) def testCopyAndResolveReferences(self): """Tests the .copy_and_resolve_references() method.""" cfg = ml_collections.ConfigDict() field = ml_collections.FieldReference('a string') int_field = ml_collections.FieldReference(5) cfg.dict = { 'float': 2.3, 'integer': 1, 'field_ref1': field, 'field_ref2': field, 'field_ref_int1': int_field, 'field_ref_int2': int_field + 5, 'placeholder': config_dict.placeholder(str), 'cfg': ml_collections.ConfigDict({ 'integer': 1, 'int_field': int_field }) } cfg.ref = cfg.dict cfg.self_ref = cfg cfg_resolved = cfg.copy_and_resolve_references() for field, value in [('float', 2.3), ('integer', 
1), ('field_ref1', 'a string'), ('field_ref2', 'a string'), ('field_ref_int1', 5), ('field_ref_int2', 10), ('placeholder', None)]: self.assertEqual(getattr(cfg_resolved.dict, field), value) for field, value in [('integer', 1), ('int_field', 5)]: self.assertEqual(getattr(cfg_resolved.dict.cfg, field), value) self.assertIs(cfg_resolved, cfg_resolved['self_ref']) self.assertIs(cfg_resolved['dict'], cfg_resolved['ref']) def testCopyAndResolveReferencesConfigTypes(self): """Tests that .copy_and_resolve_references() handles special types.""" cfg_type_safe = ml_collections.ConfigDict() int_field = ml_collections.FieldReference(5) cfg_type_safe.field_ref1 = int_field cfg_type_safe.field_ref2 = int_field + 5 cfg_type_safe.lock() cfg_type_safe_locked_resolved = cfg_type_safe.copy_and_resolve_references() self.assertTrue(cfg_type_safe_locked_resolved.is_locked) self.assertTrue(cfg_type_safe_locked_resolved.is_type_safe) cfg = ml_collections.ConfigDict(type_safe=False) cfg.field_ref1 = int_field cfg.field_ref2 = int_field + 5 cfg_resolved = cfg.copy_and_resolve_references() self.assertFalse(cfg_resolved.is_locked) self.assertFalse(cfg_resolved.is_type_safe) cfg.lock() cfg_locked_resolved = cfg.copy_and_resolve_references() self.assertTrue(cfg_locked_resolved.is_locked) self.assertFalse(cfg_locked_resolved.is_type_safe) for resolved in [ cfg_type_safe_locked_resolved, cfg_resolved, cfg_locked_resolved ]: self.assertEqual(resolved.field_ref1, 5) self.assertEqual(resolved.field_ref2, 10) frozen_cfg = ml_collections.FrozenConfigDict(cfg_type_safe) frozen_cfg_resolved = frozen_cfg.copy_and_resolve_references() for resolved in [frozen_cfg, frozen_cfg_resolved]: self.assertEqual(resolved.field_ref1, 5) self.assertEqual(resolved.field_ref2, 10) self.assertIsInstance(resolved, ml_collections.FrozenConfigDict) def testInitConfigDict(self): """Tests initializing a ConfigDict on a ConfigDict.""" cfg = _get_test_config_dict() cfg_2 = ml_collections.ConfigDict(cfg) self.assertIsNot(cfg_2, 
cfg) self.assertIs(cfg_2.float, cfg.float) self.assertIs(cfg_2.dict, cfg.dict) # Ensure ConfigDict fields are initialized as is dict_with_cfg_field = {'cfg': cfg} cfg_3 = ml_collections.ConfigDict(dict_with_cfg_field) self.assertIs(cfg_3.cfg, cfg) # Now ensure it works with locking and type_safe cfg_4 = ml_collections.ConfigDict(cfg, type_safe=False) cfg_4.lock() self.assertEqual(cfg_4, ml_collections.ConfigDict(cfg_4)) def testInitReferenceStructure(self): """Ensures initialization preserves reference structure.""" x = [1, 2, 3] self_ref_dict = { 'float': 2.34, 'test_dict_1': _TEST_DICT, 'test_dict_2': _TEST_DICT, 'list': x } self_ref_dict['self'] = self_ref_dict self_ref_dict['self_fr'] = ml_collections.FieldReference(self_ref_dict) self_ref_cd = ml_collections.ConfigDict(self_ref_dict) self.assertIs(self_ref_cd.test_dict_1, self_ref_cd.test_dict_2) self.assertIs(self_ref_cd, self_ref_cd.self) self.assertIs(self_ref_cd, self_ref_cd.self_fr) self.assertIs(self_ref_cd.list, x) self.assertEqual(self_ref_cd, self_ref_cd.self) self_ref_cd.self.int = 1 self.assertEqual(self_ref_cd.int, 1) self_ref_cd_2 = ml_collections.ConfigDict(self_ref_cd) self.assertIsNot(self_ref_cd_2, self_ref_cd) self.assertIs(self_ref_cd_2.self, self_ref_cd_2) self.assertIs(self_ref_cd_2.test_dict_1, self_ref_cd.test_dict_1) def testInitFieldReference(self): """Tests initialization with FieldReferences.""" test_dict = dict(x=1, y=1) # Reference to a dict. reference = ml_collections.FieldReference(test_dict) cfg = ml_collections.ConfigDict() cfg.reference = reference self.assertIsInstance(cfg.reference, ml_collections.ConfigDict) self.assertEqual(test_dict['x'], cfg.reference.x) self.assertEqual(test_dict['y'], cfg.reference.y) # Reference to a ConfigDict. 
test_configdict = ml_collections.ConfigDict(test_dict) reference = ml_collections.FieldReference(test_configdict) cfg = ml_collections.ConfigDict() cfg.reference = reference test_configdict.x = 2 self.assertEqual(test_configdict.x, cfg.reference.x) self.assertEqual(test_configdict.y, cfg.reference.y) # Reference to a reference. reference_int = ml_collections.FieldReference(0) reference = ml_collections.FieldReference(reference_int) cfg = ml_collections.ConfigDict() cfg.reference = reference reference_int.set(1) self.assertEqual(reference_int.get(), cfg.reference) def testDeletingFields(self): """Tests whether it is possible to delete fields.""" cfg = ml_collections.ConfigDict() cfg.field1 = 123 cfg.field2 = 123 self.assertIn('field1', cfg) self.assertIn('field2', cfg) del cfg.field1 self.assertNotIn('field1', cfg) self.assertIn('field2', cfg) del cfg.field2 self.assertNotIn('field2', cfg) with self.assertRaises(AttributeError): del cfg.keys with self.assertRaises(KeyError): del cfg['keys'] def testDeletingNestedFields(self): """Tests whether it is possible to delete nested fields.""" cfg = ml_collections.ConfigDict({ 'a': { 'aa': [1, 2], }, 'b': { 'ba': { 'baa': 2, 'bab': 3, }, 'bb': {1, 2, 3}, }, }) self.assertIn('a', cfg) self.assertIn('aa', cfg.a) self.assertIn('baa', cfg.b.ba) del cfg['a.aa'] self.assertIn('a', cfg) self.assertNotIn('aa', cfg.a) del cfg['a'] self.assertNotIn('a', cfg) del cfg['b.ba.baa'] self.assertIn('ba', cfg.b) self.assertIn('bab', cfg.b.ba) self.assertNotIn('baa', cfg.b.ba) del cfg['b.ba'] self.assertNotIn('ba', cfg.b) self.assertIn('bb', cfg.b) with self.assertRaises(AttributeError): del cfg.keys with self.assertRaises(KeyError): del cfg['keys'] def testSetAttr(self): """Tests whether it is possible to override an attribute.""" cfg = ml_collections.ConfigDict() with self.assertRaises(AttributeError): cfg.__setattr__('__class__', 'abc') def testPickling(self): """Tests whether ConfigDict can be pickled and unpickled.""" cfg = 
_get_test_config_dict() cfg.lock() pickle_cfg = pickle.loads(pickle.dumps(cfg)) self.assertTrue(pickle_cfg.is_locked) self.assertIsInstance(pickle_cfg, ml_collections.ConfigDict) self.assertEqual(str(cfg), str(pickle_cfg)) def testPlaceholder(self): """Tests whether FieldReference works correctly as a placeholder.""" cfg_element = ml_collections.FieldReference(0) cfg = ml_collections.ConfigDict({ 'element': cfg_element, 'nested': { 'element': cfg_element } }) # Type mismatch. with self.assertRaises(TypeError): cfg.element = 'string' cfg.element = 1 self.assertEqual(cfg.element, cfg.nested.element) def testOptional(self): """Tests whether FieldReference works correctly as an optional field.""" # Type mismatch at construction. with self.assertRaises(TypeError): ml_collections.FieldReference(0, field_type=str) # None default and field_type. with self.assertRaises(ValueError): ml_collections.FieldReference(None) cfg = ml_collections.ConfigDict({ 'default': ml_collections.FieldReference(0), }) cfg.default = 1 self.assertEqual(cfg.default, 1) def testOptionalNoDefault(self): """Tests optional field with no default value.""" cfg = ml_collections.ConfigDict({ 'nodefault': ml_collections.FieldReference(None, field_type=str), }) # Type mismatch with field with no default value. with self.assertRaises(TypeError): cfg.nodefault = 1 cfg.nodefault = 'string' self.assertEqual(cfg.nodefault, 'string') def testGetType(self): """Tests whether types are correct for FieldReference fields.""" cfg = ml_collections.ConfigDict() cfg.integer = 123 cfg.ref = ml_collections.FieldReference(123) cfg.ref_nodefault = ml_collections.FieldReference(None, field_type=int) self.assertEqual(cfg.get_type('integer'), int) self.assertEqual(cfg.get_type('ref'), int) self.assertEqual(cfg.get_type('ref_nodefault'), int) # Check errors in case of misspelled key. 
with self.assertRaisesRegex(AttributeError, 'Did you.*ref_nodefault.*'): cfg.get_type('ref_nodefualt') with self.assertRaisesRegex(AttributeError, 'Did you.*integer.*'): cfg.get_type('integre') class ConfigDictUpdateTest(absltest.TestCase): def testUpdateSimple(self): """Tests updating from one ConfigDict to another.""" first = ml_collections.ConfigDict() first.x = 5 first.y = 'truman' first.q = 2.0 second = ml_collections.ConfigDict() second.x = 9 second.y = 'wilson' second.z = 'washington' first.update(second) self.assertEqual(first.x, 9) self.assertEqual(first.y, 'wilson') self.assertEqual(first.z, 'washington') self.assertEqual(first.q, 2.0) def testUpdateNothing(self): """Tests updating a ConfigDict with no arguments.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 cfg.update() self.assertLen(cfg, 2) self.assertEqual(cfg.x, 5) self.assertEqual(cfg.y, 9) def testUpdateFromDict(self): """Tests updating a ConfigDict from a dict.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 cfg.update({'x': 6, 'z': 2}) self.assertEqual(cfg.x, 6) self.assertEqual(cfg.y, 9) self.assertEqual(cfg.z, 2) def testUpdateFromKwargs(self): """Tests updating a ConfigDict from kwargs.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 cfg.update(x=6, z=2) self.assertEqual(cfg.x, 6) self.assertEqual(cfg.y, 9) self.assertEqual(cfg.z, 2) def testUpdateFromDictAndKwargs(self): """Tests updating a ConfigDict from a dict and kwargs.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 cfg.update({'x': 4, 'z': 2}, x=6) self.assertEqual(cfg.x, 6) # kwarg overrides value from dict self.assertEqual(cfg.y, 9) self.assertEqual(cfg.z, 2) def testUpdateFromMultipleDictTypeError(self): """Tests that updating a ConfigDict from two dicts raises a TypeError.""" cfg = ml_collections.ConfigDict() cfg.x = 5 cfg.y = 9 with self.assertRaisesRegex(TypeError, 'update expected at most 1 arguments, got 2'): cfg.update({'x': 4}, {'z': 2}) def testUpdateNested(self): """Tests updating a 
ConfigDict from a nested dict.""" cfg = ml_collections.ConfigDict() cfg.subcfg = ml_collections.ConfigDict() cfg.p = 5 cfg.q = 6 cfg.subcfg.y = 9 cfg.update({'p': 4, 'subcfg': {'y': 10, 'z': 5}}) self.assertEqual(cfg.p, 4) self.assertEqual(cfg.q, 6) self.assertEqual(cfg.subcfg.y, 10) self.assertEqual(cfg.subcfg.z, 5) def _assert_associated(self, cfg1, cfg2, key): self.assertEqual(cfg1[key], cfg2[key]) cfg1[key] = 1 cfg2[key] = 2 self.assertEqual(cfg1[key], 2) cfg1[key] = 3 self.assertEqual(cfg2[key], 3) def testUpdateFieldReference(self): """Tests updating to/from FieldReference fields.""" # Updating FieldReference... ref = ml_collections.FieldReference(1) cfg = ml_collections.ConfigDict(dict(a=ref, b=ref)) # from value. cfg.update(ml_collections.ConfigDict(dict(a=2))) self.assertEqual(cfg.a, 2) self.assertEqual(cfg.b, 2) # from FieldReference. error_message = 'Cannot update a FieldReference from another FieldReference' with self.assertRaisesRegex(TypeError, error_message): cfg.update( ml_collections.ConfigDict(dict(a=ml_collections.FieldReference(2)))) with self.assertRaisesRegex(TypeError, error_message): cfg.update( ml_collections.ConfigDict(dict(b=ml_collections.FieldReference(2)))) # Updating empty ConfigDict with FieldReferences. ref = ml_collections.FieldReference(1) cfg_from = ml_collections.ConfigDict(dict(a=ref, b=ref)) cfg = ml_collections.ConfigDict() cfg.update(cfg_from) self._assert_associated(cfg, cfg_from, 'a') self._assert_associated(cfg, cfg_from, 'b') # Updating values with FieldReferences. 
ref = ml_collections.FieldReference(1) cfg_from = ml_collections.ConfigDict(dict(a=ref, b=ref)) cfg = ml_collections.ConfigDict(dict(a=2, b=3)) cfg.update(cfg_from) self._assert_associated(cfg, cfg_from, 'a') self._assert_associated(cfg, cfg_from, 'b') def testUpdateFromFlattened(self): cfg = ml_collections.ConfigDict({'a': 1, 'b': {'c': {'d': 2}}}) updates = {'a': 2, 'b.c.d': 3} cfg.update_from_flattened_dict(updates) self.assertEqual(cfg.a, 2) self.assertEqual(cfg.b.c.d, 3) def testUpdateFromFlattenedWithPrefix(self): cfg = ml_collections.ConfigDict({'a': 1, 'b': {'c': {'d': 2}}}) updates = {'a': 2, 'b.c.d': 3} cfg.b.update_from_flattened_dict(updates, 'b.') self.assertEqual(cfg.a, 1) self.assertEqual(cfg.b.c.d, 3) def testUpdateFromFlattenedNotFound(self): cfg = ml_collections.ConfigDict({'a': 1, 'b': {'c': {'d': 2}}}) updates = {'a': 2, 'b.d.e': 3} with self.assertRaisesRegex( KeyError, 'Key "b.d.e" cannot be set as "b.d" was not found.'): cfg.update_from_flattened_dict(updates) def testUpdateFromFlattenedWrongType(self): cfg = ml_collections.ConfigDict({'a': 1, 'b': {'c': {'d': 2}}}) updates = {'a.b.c': 2} with self.assertRaisesRegex( KeyError, 'Key "a.b.c" cannot be updated as "a" is not a ConfigDict.'): cfg.update_from_flattened_dict(updates) def testUpdateFromFlattenedTupleListConversion(self): cfg = ml_collections.ConfigDict({ 'a': 1, 'b': { 'c': { 'd': (1, 2, 3, 4, 5), } } }) updates = { 'b.c.d': [2, 4, 6, 8], } cfg.update_from_flattened_dict(updates) self.assertIsInstance(cfg.b.c.d, tuple) self.assertEqual(cfg.b.c.d, (2, 4, 6, 8)) def testDecodeError(self): # ConfigDict containing two strings with incompatible encodings. 
cfg = ml_collections.ConfigDict({ 'dill': pickle.dumps(_test_function, protocol=pickle.HIGHEST_PROTOCOL), 'unicode': u'unicode string' }) expected_error = config_dict.JSONDecodeError if six.PY2 else TypeError with self.assertRaises(expected_error): cfg.to_json() def testConvertDict(self): """Test automatic conversion, or not, of dict to ConfigDict.""" cfg = ml_collections.ConfigDict() cfg.a = dict(b=dict(c=0)) self.assertIsInstance(cfg.a, ml_collections.ConfigDict) self.assertIsInstance(cfg.a.b, ml_collections.ConfigDict) cfg = ml_collections.ConfigDict(convert_dict=False) cfg.a = dict(b=dict(c=0)) self.assertNotIsInstance(cfg.a, ml_collections.ConfigDict) self.assertIsInstance(cfg.a, dict) self.assertIsInstance(cfg.a['b'], dict) def testConvertDictInInitialValue(self): """Test automatic conversion, or not, of dict to ConfigDict.""" initial_dict = dict(a=dict(b=dict(c=0))) cfg = ml_collections.ConfigDict(initial_dict) self.assertIsInstance(cfg.a, ml_collections.ConfigDict) self.assertIsInstance(cfg.a.b, ml_collections.ConfigDict) cfg = ml_collections.ConfigDict(initial_dict, convert_dict=False) self.assertNotIsInstance(cfg.a, ml_collections.ConfigDict) self.assertIsInstance(cfg.a, dict) self.assertIsInstance(cfg.a['b'], dict) def testConvertDictTypeCompat(self): """Test that automatic conversion to ConfigDict doesn't trigger type errors.""" cfg = ml_collections.ConfigDict() cfg.a = {} self.assertIsInstance(cfg.a, ml_collections.ConfigDict) # This checks that dict to configdict casting doesn't produce type mismatch. cfg.a = {} def testYamlNoConvert(self): """Test deserialisation from YAML without convert dict. This checks backward compatibility of deserialisation. """ cfg = ml_collections.ConfigDict(dict(a=1)) self.assertTrue(yaml.load(cfg.to_yaml())._convert_dict) def testRecursiveRename(self): """Test recursive_rename. The dictionary should be the same but with the specified name changed. 
""" cfg = ml_collections.ConfigDict(_TEST_DICT) new_cfg = config_dict.recursive_rename(cfg, 'float', 'double') # Check that the new config has float changed to double as we expect self.assertEqual(new_cfg.to_dict(), _TEST_DICT_CHANGE_FLOAT_NAME) # Check that the original config is unchanged self.assertEqual(cfg.to_dict(), _TEST_DICT) def testGetOnewayRef(self): cfg = config_dict.create(a=1) cfg.b = cfg.get_oneway_ref('a') cfg.a = 2 self.assertEqual(2, cfg.b) cfg.b = 3 self.assertEqual(2, cfg.a) self.assertEqual(3, cfg.b) class CreateTest(absltest.TestCase): def testBasic(self): config = config_dict.create(a=1, b='b') dct = {'a': 1, 'b': 'b'} self.assertEqual(config.to_dict(), dct) def testNested(self): config = config_dict.create( data=config_dict.create(game='freeway'), model=config_dict.create(num_hidden=1000)) dct = {'data': {'game': 'freeway'}, 'model': {'num_hidden': 1000}} self.assertEqual(config.to_dict(), dct) class PlaceholderTest(absltest.TestCase): def testBasic(self): config = config_dict.create(a=1, b=config_dict.placeholder(int)) self.assertEqual(config.to_dict(), {'a': 1, 'b': None}) config.b = 5 self.assertEqual(config.to_dict(), {'a': 1, 'b': 5}) def testTypeChecking(self): config = config_dict.create(a=1, b=config_dict.placeholder(int)) with self.assertRaises(TypeError): config.b = 'chutney' def testRequired(self): config = config_dict.create(a=config_dict.required_placeholder(int)) ref = config.get_ref('a') with self.assertRaises(config_dict.RequiredValueError): config.a # pylint: disable=pointless-statement with self.assertRaises(config_dict.RequiredValueError): config.to_dict() with self.assertRaises(config_dict.RequiredValueError): ref.get() config.a = 10 self.assertEqual(config.to_dict(), {'a': 10}) self.assertEqual(str(config), yaml.dump({'a': 10})) # Reset to None and check we still get an error. 
    config.a = None
    with self.assertRaises(config_dict.RequiredValueError):
      config.a  # pylint: disable=pointless-statement

    # Set to a different value using the reference obtained calling get_ref().
    ref.set(5)
    self.assertEqual(config.to_dict(), {'a': 5})
    self.assertEqual(str(config), yaml.dump({'a': 5}))

    # dict placeholder.
    test_dict = {'field': 10}
    config = config_dict.create(
        a=config_dict.required_placeholder(dict),
        b=ml_collections.FieldReference(test_dict.copy()))
    # ConfigDict initialization converts dict to ConfigDict.
    self.assertEqual(test_dict, config.b.to_dict())
    config.a = test_dict
    self.assertEqual(test_dict, config.a)


class CycleTest(absltest.TestCase):

  def testCycle(self):
    config = config_dict.create(a=1)
    config.b = config.get_ref('a') + config.get_ref('a')
    self.assertFalse(config.get_ref('b').has_cycle())
    with self.assertRaises(config_dict.MutabilityError):
      config.a = config.get_ref('a')
    with self.assertRaises(config_dict.MutabilityError):
      config.a = config.get_ref('b')


if __name__ == '__main__':
  absltest.main()

# File: ml_collections-0.1.0/ml_collections/config_dict/tests/field_reference_test.py
# Copyright 2020 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for ml_collections.FieldReference."""

import operator

from absl.testing import absltest
from absl.testing import parameterized
import ml_collections
from ml_collections.config_dict import config_dict


class FieldReferenceTest(parameterized.TestCase):

  def _test_binary_operator(self,
                            initial_value,
                            other_value,
                            op,
                            true_value,
                            new_initial_value,
                            new_true_value,
                            assert_fn=None):
    """Helper for testing binary operators.

    Generally speaking this checks that:
      1. `op(initial_value, other_value) COMP true_value`
      2. `op(new_initial_value, other_value) COMP new_true_value`
    where `COMP` is the comparison function defined by `assert_fn`.

    Args:
      initial_value: Initial value for the `FieldReference`, this is the first
        argument for the binary operator.
      other_value: The second argument for the binary operator.
      op: The binary operator.
      true_value: The expected output of the binary operator.
      new_initial_value: The value that the `FieldReference` is changed to.
      new_true_value: The expected output of the binary operator after the
        `FieldReference` has changed.
      assert_fn: Function used to check the output values.
    """
    if assert_fn is None:
      assert_fn = self.assertEqual

    ref = ml_collections.FieldReference(initial_value)
    new_ref = op(ref, other_value)
    assert_fn(new_ref.get(), true_value)

    config = ml_collections.ConfigDict()
    config.a = initial_value
    config.b = other_value
    config.result = op(config.get_ref('a'), config.b)
    assert_fn(config.result, true_value)

    config.a = new_initial_value
    assert_fn(config.result, new_true_value)

  def _test_unary_operator(self,
                           initial_value,
                           op,
                           true_value,
                           new_initial_value,
                           new_true_value,
                           assert_fn=None):
    """Helper for testing unary operators.

    Generally speaking this checks that:
      1. `op(initial_value) COMP true_value`
      2. `op(new_initial_value) COMP new_true_value`
    where `COMP` is the comparison function defined by `assert_fn`.

    Args:
      initial_value: Initial value for the `FieldReference`, this is the
        argument for the unary operator.
      op: The unary operator.
      true_value: The expected output of the unary operator.
      new_initial_value: The value that the `FieldReference` is changed to.
      new_true_value: The expected output of the unary operator after the
        `FieldReference` has changed.
      assert_fn: Function used to check the output values.
    """
    if assert_fn is None:
      assert_fn = self.assertEqual

    ref = ml_collections.FieldReference(initial_value)
    new_ref = op(ref)
    assert_fn(new_ref.get(), true_value)

    config = ml_collections.ConfigDict()
    config.a = initial_value
    config.result = op(config.get_ref('a'))
    assert_fn(config.result, true_value)

    config.a = new_initial_value
    assert_fn(config.result, new_true_value)

  def testBasic(self):
    ref = ml_collections.FieldReference(1)
    self.assertEqual(ref.get(), 1)

  def testGetRef(self):
    config = ml_collections.ConfigDict()
    config.a = 1.
    config.b = config.get_ref('a') + 10
    config.c = config.get_ref('b') + 10
    self.assertEqual(config.c, 21.0)

  def testFunction(self):

    def fn(x):
      return x + 5

    config = ml_collections.ConfigDict()
    config.a = 1
    config.b = fn(config.get_ref('a'))
    config.c = fn(config.get_ref('b'))

    self.assertEqual(config.b, 6)
    self.assertEqual(config.c, 11)
    config.a = 2
    self.assertEqual(config.b, 7)
    self.assertEqual(config.c, 12)

  def testCycles(self):
    config = ml_collections.ConfigDict()
    config.a = 1.
config.b = config.get_ref('a') + 10 config.c = config.get_ref('b') + 10 self.assertEqual(config.b, 11.0) self.assertEqual(config.c, 21.0) # Introduce a cycle with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'): config.a = config.get_ref('c') - 1.0 # Introduce a cycle on second operand with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'): config.a = ml_collections.FieldReference(5.0) + config.get_ref('c') # We can create multiple FieldReferences that all point to the same object l = [0] config = ml_collections.ConfigDict() config.a = l config.b = l config.c = config.get_ref('a') + ['c'] config.d = config.get_ref('b') + ['d'] self.assertEqual(config.c, [0, 'c']) self.assertEqual(config.d, [0, 'd']) # Make sure nothing was mutated self.assertEqual(l, [0]) self.assertEqual(config.c, [0, 'c']) config.a = [1] config.b = [2] self.assertEqual(l, [0]) self.assertEqual(config.c, [1, 'c']) self.assertEqual(config.d, [2, 'd']) @parameterized.parameters( { 'initial_value': 1, 'other_value': 2, 'true_value': 3, 'new_initial_value': 10, 'new_true_value': 12 }, { 'initial_value': 2.0, 'other_value': 2.5, 'true_value': 4.5, 'new_initial_value': 3.7, 'new_true_value': 6.2 }, { 'initial_value': 'hello, ', 'other_value': 'world!', 'true_value': 'hello, world!', 'new_initial_value': 'foo, ', 'new_true_value': 'foo, world!' 
}, { 'initial_value': ['hello'], 'other_value': ['world'], 'true_value': ['hello', 'world'], 'new_initial_value': ['foo'], 'new_true_value': ['foo', 'world'] }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5.0), 'true_value': 15.0, 'new_initial_value': 12, 'new_true_value': 17.0 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 7.0, 'true_value': None, 'new_initial_value': 12, 'new_true_value': 19.0 }, { 'initial_value': 5.0, 'other_value': config_dict.placeholder(float), 'true_value': None, 'new_initial_value': 8.0, 'new_true_value': None }, { 'initial_value': config_dict.placeholder(str), 'other_value': 'tail', 'true_value': None, 'new_initial_value': 'head', 'new_true_value': 'headtail' }) def testAdd(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.add, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 5, 'other_value': 3, 'true_value': 2, 'new_initial_value': -1, 'new_true_value': -4 }, { 'initial_value': 2.0, 'other_value': 2.5, 'true_value': -0.5, 'new_initial_value': 12.3, 'new_true_value': 9.8 }, { 'initial_value': set(['hello', 123, 4.5]), 'other_value': set([123]), 'true_value': set(['hello', 4.5]), 'new_initial_value': set([123]), 'new_true_value': set([]) }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5.0), 'true_value': 5.0, 'new_initial_value': 12, 'new_true_value': 7.0 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 7.0, 'true_value': None, 'new_initial_value': 12, 'new_true_value': 5.0 }) def testSub(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.sub, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 1, 'other_value': 2, 
'true_value': 2, 'new_initial_value': 3, 'new_true_value': 6 }, { 'initial_value': 2.0, 'other_value': 2.5, 'true_value': 5.0, 'new_initial_value': 3.5, 'new_true_value': 8.75 }, { 'initial_value': ['hello'], 'other_value': 3, 'true_value': ['hello', 'hello', 'hello'], 'new_initial_value': ['foo'], 'new_true_value': ['foo', 'foo', 'foo'] }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5.0), 'true_value': 50.0, 'new_initial_value': 1, 'new_true_value': 5.0 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 7.0, 'true_value': None, 'new_initial_value': 12, 'new_true_value': 84.0 }) def testMul(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.mul, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'other_value': 2, 'true_value': 1.5, 'new_initial_value': 10, 'new_true_value': 5.0 }, { 'initial_value': 2.0, 'other_value': 2.5, 'true_value': 0.8, 'new_initial_value': 6.3, 'new_true_value': 2.52 }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5.0), 'true_value': 2.0, 'new_initial_value': 13, 'new_true_value': 2.6 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 7.0, 'true_value': None, 'new_initial_value': 17.5, 'new_true_value': 2.5 }) def testTrueDiv(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.truediv, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'other_value': 2, 'true_value': 1, 'new_initial_value': 7, 'new_true_value': 3 }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5), 'true_value': 2, 'new_initial_value': 28, 'new_true_value': 5 }, { 'initial_value': 
config_dict.placeholder(int), 'other_value': 7, 'true_value': None, 'new_initial_value': 25, 'new_true_value': 3 }) def testFloorDiv(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.floordiv, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'other_value': 2, 'true_value': 9, 'new_initial_value': 10, 'new_true_value': 100 }, { 'initial_value': 2.7, 'other_value': 3.2, 'true_value': 24.0084457245, 'new_initial_value': 6.5, 'new_true_value': 399.321543621 }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5), 'true_value': 1e5, 'new_initial_value': 2, 'new_true_value': 32 }, { 'initial_value': config_dict.placeholder(float), 'other_value': 3.0, 'true_value': None, 'new_initial_value': 7.0, 'new_true_value': 343.0 }) def testPow(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator( initial_value, other_value, operator.pow, true_value, new_initial_value, new_true_value, assert_fn=self.assertAlmostEqual) @parameterized.parameters( { 'initial_value': 3, 'other_value': 2, 'true_value': 1, 'new_initial_value': 10, 'new_true_value': 0 }, { 'initial_value': 5.3, 'other_value': 3.2, 'true_value': 2.0999999999999996, 'new_initial_value': 77, 'new_true_value': 0.2 }, { 'initial_value': ml_collections.FieldReference(10), 'other_value': ml_collections.FieldReference(5), 'true_value': 0, 'new_initial_value': 32, 'new_true_value': 2 }, { 'initial_value': config_dict.placeholder(int), 'other_value': 7, 'true_value': None, 'new_initial_value': 25, 'new_true_value': 4 }) def testMod(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator( initial_value, other_value, operator.mod, true_value, new_initial_value, new_true_value, assert_fn=self.assertAlmostEqual) @parameterized.parameters( { 
'initial_value': True, 'other_value': True, 'true_value': True, 'new_initial_value': False, 'new_true_value': False }, { 'initial_value': ml_collections.FieldReference(False), 'other_value': ml_collections.FieldReference(False), 'true_value': False, 'new_initial_value': True, 'new_true_value': False }, { 'initial_value': config_dict.placeholder(bool), 'other_value': True, 'true_value': None, 'new_initial_value': False, 'new_true_value': False }) def testAnd(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.and_, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': False, 'other_value': False, 'true_value': False, 'new_initial_value': True, 'new_true_value': True }, { 'initial_value': ml_collections.FieldReference(True), 'other_value': ml_collections.FieldReference(True), 'true_value': True, 'new_initial_value': False, 'new_true_value': True }, { 'initial_value': config_dict.placeholder(bool), 'other_value': False, 'true_value': None, 'new_initial_value': True, 'new_true_value': True }) def testOr(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.or_, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': False, 'other_value': True, 'true_value': True, 'new_initial_value': True, 'new_true_value': False }, { 'initial_value': ml_collections.FieldReference(True), 'other_value': ml_collections.FieldReference(True), 'true_value': False, 'new_initial_value': False, 'new_true_value': True }, { 'initial_value': config_dict.placeholder(bool), 'other_value': True, 'true_value': None, 'new_initial_value': True, 'new_true_value': False }) def testXor(self, initial_value, other_value, true_value, new_initial_value, new_true_value): self._test_binary_operator(initial_value, other_value, operator.xor, true_value, 
new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'true_value': -3, 'new_initial_value': -22, 'new_true_value': 22 }, { 'initial_value': 15.3, 'true_value': -15.3, 'new_initial_value': -0.2, 'new_true_value': 0.2 }, { 'initial_value': ml_collections.FieldReference(7), 'true_value': ml_collections.FieldReference(-7), 'new_initial_value': 123, 'new_true_value': -123 }, { 'initial_value': config_dict.placeholder(int), 'true_value': None, 'new_initial_value': -6, 'new_true_value': 6 }) def testNeg(self, initial_value, true_value, new_initial_value, new_true_value): self._test_unary_operator(initial_value, operator.neg, true_value, new_initial_value, new_true_value) @parameterized.parameters( { 'initial_value': 3, 'true_value': 3, 'new_initial_value': -101, 'new_true_value': 101 }, { 'initial_value': -15.3, 'true_value': 15.3, 'new_initial_value': 7.3, 'new_true_value': 7.3 }, { 'initial_value': ml_collections.FieldReference(-7), 'true_value': ml_collections.FieldReference(7), 'new_initial_value': 3, 'new_true_value': 3 }, { 'initial_value': config_dict.placeholder(float), 'true_value': None, 'new_initial_value': -6.25, 'new_true_value': 6.25 }) def testAbs(self, initial_value, true_value, new_initial_value, new_true_value): self._test_unary_operator(initial_value, operator.abs, true_value, new_initial_value, new_true_value) def testToInt(self): self._test_unary_operator(25.3, lambda ref: ref.to_int(), 25, 27.9, 27) ref = ml_collections.FieldReference(64.7) ref = ref.to_int() self.assertEqual(ref.get(), 64) self.assertEqual(ref._field_type, int) def testToFloat(self): self._test_unary_operator(12, lambda ref: ref.to_float(), 12.0, 0, 0.0) ref = ml_collections.FieldReference(647) ref = ref.to_float() self.assertEqual(ref.get(), 647.0) self.assertEqual(ref._field_type, float) def testToString(self): self._test_unary_operator(12, lambda ref: ref.to_str(), '12', 0, '0') ref = ml_collections.FieldReference(647) ref = ref.to_str() 
self.assertEqual(ref.get(), '647') self.assertEqual(ref._field_type, str) def testSetValue(self): ref = ml_collections.FieldReference(1.0) other = ml_collections.FieldReference(3) ref_plus_other = ref + other self.assertEqual(ref_plus_other.get(), 4.0) ref.set(2.5) self.assertEqual(ref_plus_other.get(), 5.5) other.set(110) self.assertEqual(ref_plus_other.get(), 112.5) # Type checking with self.assertRaises(TypeError): other.set('this is a string') with self.assertRaises(TypeError): other.set(ml_collections.FieldReference('this is a string')) with self.assertRaises(TypeError): other.set(ml_collections.FieldReference(None, field_type=str)) def testSetResult(self): ref = ml_collections.FieldReference(1.0) result = ref + 1.0 second_result = result + 1.0 self.assertEqual(ref.get(), 1.0) self.assertEqual(result.get(), 2.0) self.assertEqual(second_result.get(), 3.0) ref.set(2.0) self.assertEqual(ref.get(), 2.0) self.assertEqual(result.get(), 3.0) self.assertEqual(second_result.get(), 4.0) result.set(4.0) self.assertEqual(ref.get(), 2.0) self.assertEqual(result.get(), 4.0) self.assertEqual(second_result.get(), 5.0) # All references are broken at this point. 
    ref.set(1.0)
    self.assertEqual(ref.get(), 1.0)
    self.assertEqual(result.get(), 4.0)
    self.assertEqual(second_result.get(), 5.0)

  def testTypeChecking(self):
    ref = ml_collections.FieldReference(1)
    string_ref = ml_collections.FieldReference('a')

    x = ref + string_ref
    with self.assertRaises(TypeError):
      x.get()

  def testNoType(self):
    self.assertRaisesRegex(TypeError, 'field_type should be a type.*',
                           ml_collections.FieldReference, None, 0)

  def testEqual(self):
    # Simple case
    ref1 = ml_collections.FieldReference(1)
    ref2 = ml_collections.FieldReference(1)
    ref3 = ml_collections.FieldReference(2)
    self.assertEqual(ref1, 1)
    self.assertEqual(ref1, ref1)
    self.assertEqual(ref1, ref2)
    self.assertNotEqual(ref1, 2)
    self.assertNotEqual(ref1, ref3)

    # ConfigDict inside FieldReference
    ref1 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 1}))
    ref2 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 1}))
    ref3 = ml_collections.FieldReference(ml_collections.ConfigDict({'a': 2}))
    self.assertEqual(ref1, ml_collections.ConfigDict({'a': 1}))
    self.assertEqual(ref1, ref1)
    self.assertEqual(ref1, ref2)
    self.assertNotEqual(ref1, ml_collections.ConfigDict({'a': 2}))
    self.assertNotEqual(ref1, ref3)

  def testLessEqual(self):
    # Simple case
    ref1 = ml_collections.FieldReference(1)
    ref2 = ml_collections.FieldReference(1)
    ref3 = ml_collections.FieldReference(2)
    self.assertLessEqual(ref1, 1)
    self.assertLessEqual(ref1, 2)
    self.assertLessEqual(0, ref1)
    self.assertLessEqual(1, ref1)
    self.assertGreater(ref1, 0)

    self.assertLessEqual(ref1, ref1)
    self.assertLessEqual(ref1, ref2)
    self.assertLessEqual(ref1, ref3)
    self.assertGreater(ref3, ref1)

  def testControlFlowError(self):
    ref1 = ml_collections.FieldReference(True)
    ref2 = ml_collections.FieldReference(False)

    with self.assertRaises(NotImplementedError):
      if ref1:
        pass
    with self.assertRaises(NotImplementedError):
      _ = ref1 and ref2
    with self.assertRaises(NotImplementedError):
      _ = ref1 or ref2
    with self.assertRaises(NotImplementedError):
      _ = not ref1


if __name__ == '__main__':
  absltest.main()

# File: ml_collections-0.1.0/ml_collections/config_dict/tests/frozen_config_dict_test.py
# Copyright 2020 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Tests for ml_collections.FrozenConfigDict."""

import collections as python_collections
import copy
import pickle

from absl.testing import absltest
import ml_collections

_TEST_DICT = {
    'int': 2,
    'list': [1, 2],
    'nested_list': [[1, [2]]],
    'set': {1, 2},
    'tuple': (1, 2),
    'frozenset': frozenset({1, 2}),
    'dict': {
        'float': -1.23,
        'list': [1, 2],
        'dict': {},
        'tuple_containing_list': (1, 2, (3, [4, 5], (6, 7))),
        'list_containing_tuple': [1, 2, [3, 4], (5, 6)],
    },
    'ref': ml_collections.FieldReference({'int': 0})
}


def _test_dict_deepcopy():
  return copy.deepcopy(_TEST_DICT)


def _test_configdict():
  return ml_collections.ConfigDict(_TEST_DICT)


def _test_frozenconfigdict():
  return ml_collections.FrozenConfigDict(_TEST_DICT)


class FrozenConfigDictTest(absltest.TestCase):
  """Tests FrozenConfigDict in config flags library."""

  def assertFrozenRaisesValueError(self, input_list):
    """Assert initialization on all elements of input_list raise ValueError."""
    for initial_dictionary in input_list:
      with self.assertRaises(ValueError):
        _ = ml_collections.FrozenConfigDict(initial_dictionary)

  def testBasicEquality(self):
    """Tests basic equality with different types of initialization."""
    fcd = _test_frozenconfigdict()
    fcd_cd = ml_collections.FrozenConfigDict(_test_configdict())
    fcd_fcd = ml_collections.FrozenConfigDict(fcd)
    self.assertEqual(fcd, fcd_cd)
    self.assertEqual(fcd, fcd_fcd)

  def testImmutability(self):
    """Tests immutability of frozen config."""
    fcd = _test_frozenconfigdict()
    self.assertEqual(fcd.list, tuple(_TEST_DICT['list']))
    self.assertEqual(fcd.tuple, _TEST_DICT['tuple'])
    self.assertEqual(fcd.set, frozenset(_TEST_DICT['set']))
    self.assertEqual(fcd.frozenset, _TEST_DICT['frozenset'])
    # Must manually check set to frozenset conversion, since Python == does not
    self.assertIsInstance(fcd.set, frozenset)

    self.assertEqual(fcd.dict.list, tuple(_TEST_DICT['dict']['list']))
    self.assertNotEqual(fcd.dict.tuple_containing_list,
                        _TEST_DICT['dict']['tuple_containing_list'])
    self.assertEqual(fcd.dict.tuple_containing_list[2][1],
                     tuple(_TEST_DICT['dict']['tuple_containing_list'][2][1]))
    self.assertIsInstance(fcd.dict, ml_collections.FrozenConfigDict)

    with self.assertRaises(AttributeError):
      fcd.newitem = 0
    with self.assertRaises(AttributeError):
      fcd.dict.int = 0
    with self.assertRaises(AttributeError):
      fcd['newitem'] = 0
    with self.assertRaises(AttributeError):
      del fcd.int
    with self.assertRaises(AttributeError):
      del fcd['int']

  def testLockAndFreeze(self):
    """Ensures .lock() and .freeze() raise errors."""
    fcd = _test_frozenconfigdict()

    self.assertFalse(fcd.is_locked)
    self.assertFalse(fcd.as_configdict().is_locked)

    with self.assertRaises(AttributeError):
      fcd.lock()
    with self.assertRaises(AttributeError):
      fcd.unlock()
    with self.assertRaises(AttributeError):
      fcd.freeze()
    with self.assertRaises(AttributeError):
      fcd.unfreeze()

  def testInitConfigDict(self):
    """Tests that ConfigDict initialization handles FrozenConfigDict.
Initializing a ConfigDict on a dictionary with FrozenConfigDict values should unfreeze these values. """ dict_without_fcd_node = _test_dict_deepcopy() dict_without_fcd_node.pop('ref') dict_with_fcd_node = copy.deepcopy(dict_without_fcd_node) dict_with_fcd_node['dict'] = ml_collections.FrozenConfigDict( dict_with_fcd_node['dict']) cd_without_fcd_node = ml_collections.ConfigDict(dict_without_fcd_node) cd_with_fcd_node = ml_collections.ConfigDict(dict_with_fcd_node) fcd_without_fcd_node = ml_collections.FrozenConfigDict( dict_without_fcd_node) fcd_with_fcd_node = ml_collections.FrozenConfigDict(dict_with_fcd_node) self.assertEqual(cd_without_fcd_node, cd_with_fcd_node) self.assertEqual(fcd_without_fcd_node, fcd_with_fcd_node) def testInitCopying(self): """Tests that initialization copies when and only when necessary. Ensures copying only occurs when converting mutable type to immutable type, regardless of whether the FrozenConfigDict is initialized by a dict or a FrozenConfigDict. Also ensures no copying occurs when converting from FrozenConfigDict back to ConfigDict. 
""" fcd = _test_frozenconfigdict() # These should be uncopied when creating fcd fcd_unchanged_from_test_dict = [ (_TEST_DICT['tuple'], fcd.tuple), (_TEST_DICT['frozenset'], fcd.frozenset), (_TEST_DICT['dict']['tuple_containing_list'][2][2], fcd.dict.tuple_containing_list[2][2]), (_TEST_DICT['dict']['list_containing_tuple'][3], fcd.dict.list_containing_tuple[3]) ] # These should be copied when creating fcd fcd_different_from_test_dict = [ (_TEST_DICT['list'], fcd.list), (_TEST_DICT['dict']['tuple_containing_list'][2][1], fcd.dict.tuple_containing_list[2][1]) ] for (x, y) in fcd_unchanged_from_test_dict: self.assertEqual(id(x), id(y)) for (x, y) in fcd_different_from_test_dict: self.assertNotEqual(id(x), id(y)) # Also make sure that converting back to ConfigDict makes no copies self.assertEqual( id(_TEST_DICT['dict']['tuple_containing_list']), id(ml_collections.ConfigDict(fcd).dict.tuple_containing_list)) def testAsConfigDict(self): """Tests that converting FrozenConfigDict to ConfigDict works correctly. In particular, ensures that FrozenConfigDict does the inverse of ConfigDict regarding type_safe, lock, and attribute mutability. 
""" # First ensure conversion to ConfigDict works on empty FrozenConfigDict self.assertEqual( ml_collections.ConfigDict(ml_collections.FrozenConfigDict()), ml_collections.ConfigDict()) cd = _test_configdict() cd_fcd_cd = ml_collections.ConfigDict(ml_collections.FrozenConfigDict(cd)) self.assertEqual(cd, cd_fcd_cd) # Make sure locking is respected cd.lock() self.assertEqual( cd, ml_collections.ConfigDict(ml_collections.FrozenConfigDict(cd))) # Make sure type_safe is respected cd = ml_collections.ConfigDict(_TEST_DICT, type_safe=False) self.assertEqual( cd, ml_collections.ConfigDict(ml_collections.FrozenConfigDict(cd))) def testInitSelfReferencing(self): """Ensure initialization fails on self-referencing dicts.""" self_ref = {} self_ref['self'] = self_ref parent_ref = {'dict': {}} parent_ref['dict']['parent'] = parent_ref tuple_parent_ref = {'dict': {}} tuple_parent_ref['dict']['tuple'] = (1, 2, tuple_parent_ref) attribute_cycle = {'dict': copy.deepcopy(self_ref)} self.assertFrozenRaisesValueError( [self_ref, parent_ref, tuple_parent_ref, attribute_cycle]) def testInitCycles(self): """Ensure initialization fails if an attribute of input is cyclic.""" inner_cyclic_list = [1, 2] cyclic_list = [3, inner_cyclic_list] inner_cyclic_list.append(cyclic_list) cyclic_tuple = tuple(cyclic_list) test_dict_cyclic_list = _test_dict_deepcopy() test_dict_cyclic_tuple = _test_dict_deepcopy() test_dict_cyclic_list['cyclic_list'] = cyclic_list test_dict_cyclic_tuple['dict']['cyclic_tuple'] = cyclic_tuple self.assertFrozenRaisesValueError( [test_dict_cyclic_list, test_dict_cyclic_tuple]) def testInitDictInList(self): """Ensure initialization fails on dict and ConfigDict in lists/tuples.""" list_containing_dict = {'list': [1, 2, 3, {'a': 4, 'b': 5}]} tuple_containing_dict = {'tuple': (1, 2, 3, {'a': 4, 'b': 5})} list_containing_cd = {'list': [1, 2, 3, _test_configdict()]} tuple_containing_cd = {'tuple': (1, 2, 3, _test_configdict())} fr_containing_list_containing_dict = { 'fr': 
ml_collections.FieldReference([1, { 'a': 2 }]) } self.assertFrozenRaisesValueError([ list_containing_dict, tuple_containing_dict, list_containing_cd, tuple_containing_cd, fr_containing_list_containing_dict ]) def testInitFieldReferenceInList(self): """Ensure initialization fails on FieldReferences in lists/tuples.""" list_containing_fr = {'list': [1, 2, 3, ml_collections.FieldReference(4)]} tuple_containing_fr = { 'tuple': (1, 2, 3, ml_collections.FieldReference('a')) } self.assertFrozenRaisesValueError([list_containing_fr, tuple_containing_fr]) def testInitInvalidAttributeName(self): """Ensure initialization fails on attributes with invalid names.""" dot_name = {'dot.name': None} immutable_name = {'__hash__': None} with self.assertRaises(ValueError): ml_collections.FrozenConfigDict(dot_name) with self.assertRaises(AttributeError): ml_collections.FrozenConfigDict(immutable_name) def testFieldReferenceResolved(self): """Tests that FieldReferences are resolved.""" cfg = ml_collections.ConfigDict({'fr': ml_collections.FieldReference(1)}) frozen_cfg = ml_collections.FrozenConfigDict(cfg) self.assertNotIsInstance(frozen_cfg._fields['fr'], ml_collections.FieldReference) hash(frozen_cfg) # with FieldReference resolved, frozen_cfg is hashable def testFieldReferenceCycle(self): """Tests that FieldReferences may not contain reference cycles.""" frozenset_fr = {'frozenset': frozenset({1, 2})} frozenset_fr['fr'] = ml_collections.FieldReference( frozenset_fr['frozenset']) list_fr = {'list': [1, 2]} list_fr['fr'] = ml_collections.FieldReference(list_fr['list']) cyclic_fr = {'a': 1} cyclic_fr['fr'] = ml_collections.FieldReference(cyclic_fr) cyclic_fr_parent = {'dict': {}} cyclic_fr_parent['dict']['fr'] = ml_collections.FieldReference( cyclic_fr_parent) # FieldReference is allowed to point to non-cyclic objects: _ = ml_collections.FrozenConfigDict(frozenset_fr) _ = ml_collections.FrozenConfigDict(list_fr) # But not cycles: self.assertFrozenRaisesValueError([cyclic_fr, 
cyclic_fr_parent]) def testDeepCopy(self): """Ensure deepcopy works and does not affect equality.""" fcd = _test_frozenconfigdict() fcd_deepcopy = copy.deepcopy(fcd) self.assertEqual(fcd, fcd_deepcopy) def testEquals(self): """Tests that __eq__() respects hidden mutability.""" fcd = _test_frozenconfigdict() # First, ensure __eq__() returns False when comparing to other types self.assertNotEqual(fcd, (1, 2)) self.assertNotEqual(fcd, fcd.as_configdict()) list_to_tuple = _test_dict_deepcopy() list_to_tuple['list'] = tuple(list_to_tuple['list']) fcd_list_to_tuple = ml_collections.FrozenConfigDict(list_to_tuple) set_to_frozenset = _test_dict_deepcopy() set_to_frozenset['set'] = frozenset(set_to_frozenset['set']) fcd_set_to_frozenset = ml_collections.FrozenConfigDict(set_to_frozenset) self.assertNotEqual(fcd, fcd_list_to_tuple) # Because set == frozenset in Python: self.assertEqual(fcd, fcd_set_to_frozenset) # Items are not affected by hidden mutability self.assertCountEqual(fcd.items(), fcd_list_to_tuple.items()) self.assertCountEqual(fcd.items(), fcd_set_to_frozenset.items()) def testEqualsAsConfigDict(self): """Tests that eq_as_configdict respects hidden mutability but not type.""" fcd = _test_frozenconfigdict() # First, ensure eq_as_configdict() returns True with an equal ConfigDict but # False for other types. self.assertFalse(fcd.eq_as_configdict([1, 2])) self.assertTrue(fcd.eq_as_configdict(fcd.as_configdict())) empty_fcd = ml_collections.FrozenConfigDict() self.assertTrue(empty_fcd.eq_as_configdict(ml_collections.ConfigDict())) # Now, ensure it has the same immutability detection as __eq__(). 
list_to_tuple = _test_dict_deepcopy() list_to_tuple['list'] = tuple(list_to_tuple['list']) fcd_list_to_tuple = ml_collections.FrozenConfigDict(list_to_tuple) set_to_frozenset = _test_dict_deepcopy() set_to_frozenset['set'] = frozenset(set_to_frozenset['set']) fcd_set_to_frozenset = ml_collections.FrozenConfigDict(set_to_frozenset) self.assertFalse(fcd.eq_as_configdict(fcd_list_to_tuple)) # Because set == frozenset in Python: self.assertTrue(fcd.eq_as_configdict(fcd_set_to_frozenset)) def testHash(self): """Ensures __hash__() respects hidden mutability.""" list_to_tuple = _test_dict_deepcopy() list_to_tuple['list'] = tuple(list_to_tuple['list']) self.assertEqual( hash(_test_frozenconfigdict()), hash(ml_collections.FrozenConfigDict(_test_dict_deepcopy()))) self.assertNotEqual( hash(_test_frozenconfigdict()), hash(ml_collections.FrozenConfigDict(list_to_tuple))) # Ensure Python realizes FrozenConfigDict is hashable self.assertIsInstance(_test_frozenconfigdict(), python_collections.Hashable) def testUnhashableType(self): """Ensures __hash__() fails if FrozenConfigDict has unhashable value.""" unhashable_fcd = ml_collections.FrozenConfigDict( {'unhashable': bytearray()}) with self.assertRaises(TypeError): hash(unhashable_fcd) def testToDict(self): """Ensure to_dict() does not care about hidden mutability.""" list_to_tuple = _test_dict_deepcopy() list_to_tuple['list'] = tuple(list_to_tuple['list']) self.assertEqual(_test_frozenconfigdict().to_dict(), ml_collections.FrozenConfigDict(list_to_tuple).to_dict()) def testPickle(self): """Make sure FrozenConfigDict can be dumped and loaded with pickle.""" fcd = _test_frozenconfigdict() locked_fcd = ml_collections.FrozenConfigDict(_test_configdict().lock()) unpickled_fcd = pickle.loads(pickle.dumps(fcd)) unpickled_locked_fcd = pickle.loads(pickle.dumps(locked_fcd)) self.assertEqual(fcd, unpickled_fcd) self.assertEqual(locked_fcd, unpickled_locked_fcd) if __name__ == '__main__': absltest.main() 
././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1598243892.235948 ml_collections-0.1.0/ml_collections/config_flags/0000755162426002575230000000000000000000000025123 5ustar00mohitreddyprimarygroup00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/__init__.py0000644162426002575230000000142200000000000027233 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Config flags module.""" from .config_flags import DEFINE_config_dict from .config_flags import DEFINE_config_file __all__ = ("DEFINE_config_dict", "DEFINE_config_file") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/config_flags.py0000644162426002575230000006320400000000000030123 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Complex configs command line parser.""" import errno import imp import os import re import sys import traceback from absl import flags from absl import logging import ml_collections from ml_collections.config_flags import tuple_parser import six from six import string_types FLAGS = flags.FLAGS # We need this to work both for Python 2 and 3. # The initial code in Python 2 was: # _FIELD_TYPE_TO_PARSER = { # types.IntType: flags.IntegerParser(), # types.FloatType: flags.FloatParser(), # OK For Python 3 # types.BooleanType: flags.BooleanParser(), # OK For Python 3 # types.StringType: flags.ArgumentParser(), # types.TupleType: tuple_parser.TupleParser(), # OK For Python 3 # } # The possible breaking changes are: # - A Python 3 int could be a Python 2 long, which was not previously supported. # We then add support for long. # - Only Python 2 str were supported (not unicode). Python 3 will behave the # same with the str semantic change. _FIELD_TYPE_TO_PARSER = { float: flags.FloatParser(), bool: flags.BooleanParser(), # Implementing a custom parser to override `Tuple` arguments. 
tuple: tuple_parser.TupleParser(), } for t in six.integer_types: _FIELD_TYPE_TO_PARSER[t] = flags.IntegerParser() for t in six.string_types: _FIELD_TYPE_TO_PARSER[t] = flags.ArgumentParser() _FIELD_TYPE_TO_PARSER[str] = flags.ArgumentParser() _FIELD_TYPE_TO_SERIALIZER = { t: flags.ArgumentSerializer() for t in _FIELD_TYPE_TO_PARSER } class UnsupportedOperationError(flags.Error): pass class FlagOrderError(flags.Error): pass class UnparsedFlagError(flags.Error): pass def DEFINE_config_file( # pylint: disable=g-bad-name name, default=None, help_string='path to config file.', flag_values=FLAGS, lock_config=True, **kwargs): r"""Defines flag for `ConfigDict` files compatible with absl flags. The flag's value should be a path to a valid python file which contains a function called `get_config()` that returns a python object specifying a configuration. After the flag is parsed, `FLAGS.name` will contain a reference to this object, optionally with some values overridden. During flags parsing, every flag of form `--name.([a-zA-Z0-9]+\.?)+=value` and `-name.([a-zA-Z0-9]+\.?)+ value` will be treated as an override of a specific field in the config object returned by this flag. Field is essentially a dot delimited path inside the object where each path element has to be either an attribute or a key existing in the config object. For example `--my_config.field1.field2=val` means "assign value val to the attribute (or key) `field2` inside value of the attribute (or key) `field1` inside the value of `my_config` object". If there are both attribute and key-based access with the same name, attribute is preferred. Typical usage example: `script.py`: ```python ... from absl import flags from ml_collections.config_flags import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file('my_config') ... 
print(FLAGS.my_config) ``` `config.py`: ```python def get_config(): return { 'field1': 1, 'field2': 'tom', 'nested': { 'field': 2.23, }, } ``` The following command: ``` python script.py -- --my_config=config.py --my_config.field1 8 --my_config.nested.field=2.1 ``` will print: ``` {'field1': 8, 'field2': 'tom', 'nested': {'field': 2.1}} ``` It is possible to parameterise the get_config function, allowing it to return a differently structured result for different occasions. This is particularly useful when setting up hyperparameter sweeps across various network architectures. `parameterised_config.py`: ```python def get_config(config_string): possible_configs = { 'mlp': { 'constructor': 'snt.nets.MLP', 'config': { 'output_sizes': (128, 128, 1), } }, 'lstm': { 'constructor': 'snt.LSTM', 'config': { 'hidden_size': 128, 'forget_bias': 1.0, } } } return possible_configs[config_string] ``` If a colon is present in the command line override for the config file, everything to the right of the colon is passed into the get_config function. The following command lines will both function correctly: ``` python script.py -- --my_config=parameterised_config.py:mlp --my_config.config.output_sizes="(256,256,1)" ``` ``` python script.py -- --my_config=parameterised_config.py:lstm --my_config.config.hidden_size=256 ``` The following will produce an error, as the hidden_size flag does not exist when the "mlp" config_string is provided. ``` python script.py -- --my_config=parameterised_config.py:mlp --my_config.config.hidden_size=256 ``` Args: name: Flag name, optionally including extra config after a colon. default: Default value of the flag (default: None). help_string: Help string to display when --helpfull is called. (default: "path to config file.") flag_values: FlagValues instance used for parsing. (default: absl.flags.FLAGS) lock_config: If set to True, loaded config will be locked through calling .lock() method on its instance (if it exists). 
(default: True) **kwargs: Optional keyword arguments passed to Flag constructor. """ parser = _ConfigFileParser(name=name, lock_config=lock_config) flag = _ConfigFlag( parser=parser, serializer=flags.ArgumentSerializer(), name=name, default=default, help_string=help_string, flag_values=flag_values, **kwargs) # Get the module name for the frame at depth 1 in the call stack. module_name = sys._getframe(1).f_globals.get('__name__', None) # pylint: disable=protected-access module_name = sys.argv[0] if module_name == '__main__' else module_name flags.DEFINE_flag(flag, flag_values, module_name=module_name) def DEFINE_config_dict( # pylint: disable=g-bad-name name, config, help_string='ConfigDict instance.', flag_values=FLAGS, lock_config=True, **kwargs): """Defines flag for inline `ConfigDict`s compatible with absl flags. Similar to `DEFINE_config_file` except the flag's value should be a `ConfigDict` instead of a path to a file containing a `ConfigDict`. After the flag is parsed, `FLAGS.name` will contain a reference to the `ConfigDict`, optionally with some values overridden. Typical usage example: `script.py`: ```python ... from absl import flags import ml_collections from ml_collections.config_flags import config_flags config = ml_collections.ConfigDict({ 'field1': 1, 'field2': 'tom', 'nested': { 'field': 2.23, } }) FLAGS = flags.FLAGS config_flags.DEFINE_config_dict('my_config', config) ... print(FLAGS.my_config) ``` The following command: ``` python script.py -- --my_config.field1 8 --my_config.nested.field=2.1 ``` will print: ``` field1: 8 field2: tom nested: {field: 2.1} ``` Args: name: Flag name. config: `ConfigDict` object. help_string: Help string to display when --helpfull is called. (default: "ConfigDict instance.") flag_values: FlagValues instance used for parsing. (default: absl.flags.FLAGS) lock_config: If set to True, loaded config will be locked through calling .lock() method on its instance (if it exists). 
(default: True) **kwargs: Optional keyword arguments passed to Flag constructor. """ if not isinstance(config, ml_collections.ConfigDict): raise TypeError('config should be a ConfigDict') parser = _InlineConfigParser(name=name, lock_config=lock_config) flag = _ConfigFlag( parser=parser, serializer=flags.ArgumentSerializer(), name=name, default=config, help_string=help_string, flag_values=flag_values, **kwargs) # Get the module name for the frame at depth 1 in the call stack. module_name = sys._getframe(1).f_globals.get('__name__', None) # pylint: disable=protected-access module_name = sys.argv[0] if module_name == '__main__' else module_name flags.DEFINE_flag(flag, flag_values, module_name=module_name) class _IgnoreFileNotFoundAndCollectErrors(object): """Helps recording "file not found" exceptions when loading config. Usage: ignore_errors = _IgnoreFileNotFoundAndCollectErrors() with ignore_errors.Attempt('Loading from foo', 'bar.id'): ... return True # successfully loaded from `foo` logging.error('Failed loading: {}'.format(ignore_errors.DescribeAttempts())) """ def __init__(self): self._attempts = [] # type: List[Tuple[Tuple[Text, Text], IOError]] def Attempt(self, description, path): """Creates a context manager that routes exceptions to this class.""" self._current_attempt = (description, path) ignore_errors = self class _ContextManager(object): def __enter__(self): return self def __exit__(self, exc_type, exc_value, unused_traceback): return ignore_errors.ProcessAttemptException(exc_type, exc_value) return _ContextManager() def ProcessAttemptException(self, exc_type, exc_value): expected_type = IOError if six.PY2 else FileNotFoundError # pylint: disable=undefined-variable if exc_type is expected_type and exc_value.errno == errno.ENOENT: self._attempts.append((self._current_attempt, exc_value)) # Returning a true value suppresses exceptions: # https://docs.python.org/2/reference/datamodel.html#object.__exit__ return True def DescribeAttempts(self): return 
'\n'.join( ' Attempted [{}]:\n {}\n {}'.format(attempt[0], attempt[1], e) for attempt, e in self._attempts) def _LoadConfigModule(name, path): """Loads a script from external file specified by path. Unprefixed path is looked for in the current working directory using regular file open operation. This should work with relative config paths. Args: name: Name of the new module. path: Path to the .py file containing the module. Returns: Module loaded from the given path. Raises: IOError: If the config file cannot be found. """ if not path: raise IOError('Path to config file is an empty string.') ignoring_errors = _IgnoreFileNotFoundAndCollectErrors() # Works for relative paths. with ignoring_errors.Attempt('Relative path', path): config_module = imp.load_source(name, path) return config_module # Nothing worked. Log the paths that were attempted. raise IOError('Failed loading config file {}\n{}'.format( name, ignoring_errors.DescribeAttempts())) class _ErrorConfig(object): """Dummy ConfigDict that raises an error on any attribute access.""" def __init__(self, error): super(_ErrorConfig, self).__init__() super(_ErrorConfig, self).__setattr__('_error', error) def __getattr__(self, attr): self._ReportError() def __setattr__(self, attr, value): self._ReportError() def __delattr__(self, attr): self._ReportError() def __getitem__(self, key): self._ReportError() def __setitem__(self, key, value): self._ReportError() def __delitem__(self, key): self._ReportError() def _ReportError(self): raise IOError('Configuration is not available because of an earlier ' 'failure to load: ' + # 'message' is not available in Python 3. getattr(self._error, 'message', str(self._error))) def _LockConfig(config): """Calls config.lock() if config has a lock method.""" if isinstance(config, _ErrorConfig): pass # Attempting to access _ErrorConfig.lock will raise its error. 
elif getattr(config, 'lock', None) and callable(config.lock): config.lock() else: pass # config.lock() does not have desired semantics, do nothing. class _ConfigFileParser(flags.ArgumentParser): """Parser for config files.""" def __init__(self, name, lock_config=True): self.name = name self._lock_config = lock_config def parse(self, path): """Loads a config module from `path` and returns the `get_config()` result. If a colon is present in `path`, everything to the right of the first colon is passed to `get_config` as an argument. This allows the structure of what is returned to be modified, which is useful when performing complex hyperparameter sweeps. Args: path: string, path pointing to the config file to execute. May also contain a config_string argument, e.g. be of the form "config.py:some_configuration". Returns: Result of calling `get_config` in the specified module. """ # This will be a 2 element list iff extra configuration args are present. split_path = path.split(':', 1) try: config_module = _LoadConfigModule('{}_config'.format(self.name), split_path[0]) config = config_module.get_config(*split_path[1:]) if config is None: logging.warning( '%s:get_config() returned None, did you forget a return statement?', path) except IOError as e: # Don't raise the error unless/until the config is actually accessed. 
config = _ErrorConfig(e) # Third party flags library catches TypeError and ValueError and rethrows, # removing useful information unless it is added here (b/63877430): except (TypeError, ValueError) as e: error_trace = traceback.format_exc() raise type(e)('Error whilst parsing config file:\n\n' + error_trace) if self._lock_config: _LockConfig(config) return config def flag_type(self): return 'config object' class _InlineConfigParser(flags.ArgumentParser): """Parser for a config defined inline (not from a file).""" def __init__(self, name, lock_config=True): self.name = name self._lock_config = lock_config def parse(self, config): if not isinstance(config, ml_collections.ConfigDict): raise TypeError('Overriding {} is not allowed.'.format(self.name)) if self._lock_config: _LockConfig(config) return config def flag_type(self): return 'config object' class _ConfigFlag(flags.Flag): """Flag definition for command-line overridable configs.""" def __init__(self, flag_values=FLAGS, **kwargs): # Parent constructor can already call .Parse, thus additional fields # have to be set here. self.flag_values = flag_values super(_ConfigFlag, self).__init__(**kwargs) def _GetOverrides(self, argv): """Parses the command line arguments for the overrides.""" overrides = [] config_index = self._FindConfigSpecified(argv) for i, arg in enumerate(argv): if re.match(r'-{{1,2}}(no)?{}\.'.format(self.name), arg): if config_index > 0 and i < config_index: raise FlagOrderError('Found {} in argv before a value for --{} ' 'was specified'.format(arg, self.name)) arg_name = arg.split('=', 1)[0] overrides.append(arg_name.split('.', 1)[1]) return overrides def _FindConfigSpecified(self, argv): """Finds element in argv specifying the value of the config flag. Args: argv: command line arguments as a list of strings. Returns: Index in argv if found and -1 otherwise. """ for i, arg in enumerate(argv): # '-(-)config' followed by '=' or at the end of the string. 
if re.match(r'^-{{1,2}}{}(=|$)'.format(self.name), arg) is not None: return i return -1 def _IsConfigSpecified(self, argv): """Returns `True` if the config file is specified on the command line.""" return self._FindConfigSpecified(argv) >= 0 def _set_default(self, default): if self._IsConfigSpecified(sys.argv): self.default = default else: super(_ConfigFlag, self)._set_default(default) self.default_as_str = "'{}'".format(default) def _parse(self, argument): # Parse config config = super(_ConfigFlag, self)._parse(argument) # Get list of overrides overrides = self._GetOverrides(sys.argv) # Attach type definitions overrides_types = GetTypes(overrides, config) # Iterate over overridden fields and create valid parsers self._override_values = {} for field_path, field_type in zip(overrides, overrides_types): field_help = 'An override of {}\'s field {}'.format(self.name, field_path) field_name = '{}.{}'.format(self.name, field_path) if field_type in _FIELD_TYPE_TO_PARSER: parser = _ConfigFieldParser(_FIELD_TYPE_TO_PARSER[field_type], field_path, config, self._override_values) flags.DEFINE( parser, field_name, GetValue(field_path, config), field_help, flag_values=self.flag_values, serializer=_FIELD_TYPE_TO_SERIALIZER[field_type]) flag = self.flag_values._flags().get(field_name) # pylint: disable=protected-access flag.boolean = field_type is bool else: raise UnsupportedOperationError( "Type {} of field {} is not supported for overriding. " "Currently supported types are: {}. (Note that tuples should " "be passed as a string on the command line: flag='(a, b, c)', " "rather than flag=(a, b, c).)".format( field_type, field_name, _FIELD_TYPE_TO_PARSER.keys())) self._config_filename = argument return config @property def config_filename(self): """Returns a path to a config file. Typical usage example: `script.py`: ```python ... 
from absl import flags from ml_collections.config_flags import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file( name='my_config', default='ml_collections/config_flags/tests/configdict_config.py', help_string='config file') ... FLAGS['my_config'].config_filename will output 'ml_collections/config_flags/tests/configdict_config.py' ``` Returns: A path to a config file. For a parameterised get_config, the config filename with the provided parameterisation is returned. Raises: UnparsedFlagError: if the flag has not been parsed. """ if not hasattr(self, '_config_filename'): raise UnparsedFlagError('The flag has not been parsed yet') return self._config_filename @property def override_values(self): """Returns a flat dictionary containing overridden values. Keys in the dictionary are dot-separated paths navigating to child items in the original configuration. For example, suppose that a `config` flag is defined and initialized to the following configuration: ```python { 'a': 1, 'nested': { 'b': 2 } } ``` and the user overrides both values using command-line flags: ``` --config.a=10 --config.nested.b=20 ``` Then `FLAGS['config'].override_values` will return: ```python { 'a': 10, 'nested.b': 20 } ``` The result can be passed to `ConfigDict.update_from_flattened_dict` to update the values in a configuration. Continuing with the example above: ```python import ml_collections config = ml_collections.ConfigDict({ 'a': 123, 'nested': { 'b': 456 } }) config.update_from_flattened_dict(FLAGS['config'].override_values) print(config.a) # Prints `10`. print(config.nested.b) # Prints `20`. ``` Returns: Flat dictionary with overridden values. Raises: UnparsedFlagError: if the flag has not been parsed. """ if not hasattr(self, '_override_values'): raise UnparsedFlagError('The flag has not been parsed yet') return self._override_values def is_config_flag(flag): # pylint: disable=g-bad-name """Returns True iff `flag` is an instance of `_ConfigFlag`. 
External users of the library may need to check if a flag is of this type or not, particularly because ConfigFlags should be parsed before any other flags. This function allows that test to be done without making the whole class public. Args: flag: Flag object. Returns: True iff `isinstance(flag, _ConfigFlag)` is true. """ return isinstance(flag, _ConfigFlag) class _ConfigFieldParser(object): """Parser with config update after parsing. This class-based wrapper creates a new object, which uses existing parser to do actual parsing and attaches a single callback to SetValue afterwards, which is used to update a predefined path in the config object. """ def __init__(self, parser, path, config, override_values): """Creates new parser with callback, using existing one to perform parsing. Args: parser: ArgumentParser instance to wrap. path: Dot separated path in config to update with the result of parser.parse(...) config: Reference to the config object. override_values: Dictionary with override values. The 'parse' method will add the parsed value to this dictionary with key `path`. """ self._parser = parser self._path = path self._config = config self._override_values = override_values def __getattr__(self, attr): return getattr(self._parser, attr) def parse(self, argument): # pylint: disable=invalid-name value = self._parser.parse(argument) SetValue(self._path, self._config, value) self._override_values[self._path] = value return value def _ExtractIndicesFromStep(step): """Separates identifier from indexes. Example: id[1][10] -> 'id', [1, 10].""" # Map returns an iterable in Python 3, thus the additional `list`. return step.split('[', 1)[0], list(map(int, re.findall(r'\[(\d+)\]', step))) def _AccessConfig(current, field, indices): """Returns member of the current config. The field in `current` specified by the `field` parameter is indexed using the values in the `indices` indices. 
Example: ```python current = {'first': [0], 'second': [[1]]} _AccessConfig(current, 'first', [0]) # Returns 0 _AccessConfig(current, 'second', [0, 0]) # Returns 1 ``` Args: current: Current config. field: Field to access (string). indices: List of indices. Returns: Member of the config. Raises: IndexError: if indices are invalid. KeyError: if the field is not found. """ if isinstance(field, string_types) and hasattr(current, field): current = getattr(current, field) elif hasattr(current, '__getitem__'): current = current[field] else: raise KeyError(field) for i, index in enumerate(indices): try: current = current[index] except TypeError: msg = 'Could not index config {}'.format(field) if i > 0: msg += '[{}]'.format(']['.join(map(str, indices[:i]))) raise IndexError(msg) except IndexError: msg = 'Index [{}] not found in config {}'.format( ']['.join(map(str, indices[i:])), field) if i > 0: msg += '[{}]'.format(']['.join(map(str, indices[:i]))) raise IndexError(msg) return current def _TakeStep(current, step): field, indices = _ExtractIndicesFromStep(step) return _AccessConfig(current, field, indices) def GetValue(path, config): """Gets value of a single field.""" current = config for step in path.split('.'): current = _TakeStep(current, step) return current def GetType(path, config): """Gets type of field in config described by a dotted delimited path.""" steps = path.split('.') current = config for step in steps[:-1]: current = _TakeStep(current, step) # We need to check if there is list-type indexing in the last step. 
field, indices = _ExtractIndicesFromStep(steps[-1]) if indices: return type(_AccessConfig(current, field, indices)) else: # Check if current is a DM collection and hence has attribute get_type() is_dm_collection = isinstance(current, ml_collections.ConfigDict) or isinstance( current, ml_collections.FieldReference) if is_dm_collection: return current.get_type(steps[-1]) else: return type(_TakeStep(current, steps[-1])) def GetTypes(paths, config): """Gets types of fields in config described by dotted delimited paths.""" return [GetType(path, config) for path in paths] def SetValue(path, config, value): """Sets value of a single field.""" current = config steps = path.split('.') if not steps or not path: raise ValueError('Path cannot be empty') for step in steps[:-1]: current = _TakeStep(current, step) field, indices = _ExtractIndicesFromStep(steps[-1]) if indices: current = _AccessConfig(current, field, indices[:-1]) try: current[indices[-1]] = value except: raise UnsupportedOperationError( 'Config does not have a setter for {}'.format(path)) else: if hasattr(current, '__setitem__') and field in current: current[field] = value elif hasattr(current, field): setattr(current, field, value) else: raise UnsupportedOperationError( 'Config does not have a setter for {}'.format(path)) ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1598243892.235948 ml_collections-0.1.0/ml_collections/config_flags/examples/0000755162426002575230000000000000000000000026741 5ustar00mohitreddyprimarygroup00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/examples/config.py0000644162426002575230000000163000000000000030560 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Defines a method which returns an instance of ConfigDict.""" import ml_collections def get_config(): config = ml_collections.ConfigDict() config.field1 = 1 config.field2 = 'tom' config.nested = ml_collections.ConfigDict() config.nested.field = 2.23 config.tuple = (1, 2, 3) return config ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/examples/define_config_dict_basic.py0000644162426002575230000000241100000000000034234 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 r"""Example of basic DEFINE_config_dict usage. 
To run this example: python define_config_dict_basic.py -- --my_config_dict.field1=8 \ --my_config_dict.nested.field=2.1 --my_config_dict.tuple='(1, 2, (1, 2))' """ from absl import app from absl import flags import ml_collections from ml_collections.config_flags import config_flags config = ml_collections.ConfigDict() config.field1 = 1 config.field2 = 'tom' config.nested = ml_collections.ConfigDict() config.nested.field = 2.23 config.tuple = (1, 2, 3) FLAGS = flags.FLAGS config_flags.DEFINE_config_dict('my_config_dict', config) def main(_): print(FLAGS.my_config_dict) if __name__ == '__main__': app.run(main) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/examples/define_config_file_basic.py0000644162426002575230000000263600000000000034241 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 # pylint: disable=line-too-long r"""Example of basic DEFINE_config_file usage. 
To run this example with basic config file: python define_config_file_basic.py -- \ --my_config=ml_collections/config_flags/examples/config.py \ --my_config.field1=8 --my_config.nested.field=2.1 \ --my_config.tuple='(1, 2, (1, 2))' To run this example with parameterised config file: python define_config_file_basic.py -- \ --my_config=ml_collections/config_flags/examples/parameterised_config.py:linear \ --my_config.model_config.output_size=256 """ # pylint: enable=line-too-long from absl import app from absl import flags from ml_collections.config_flags import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file('my_config') def main(_): print(FLAGS.my_config) if __name__ == '__main__': app.run(main) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/examples/examples_test.py0000644162426002575230000000303200000000000032166 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """Tests for config_flags examples. Ensures that the define_config_dict_basic and define_config_file_basic examples run successfully. 
"""

from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver

from ml_collections.config_flags.examples import define_config_dict_basic
from ml_collections.config_flags.examples import define_config_file_basic

FLAGS = flags.FLAGS


class ConfigDictExamplesTest(absltest.TestCase):

  def test_define_config_dict_basic(self):
    define_config_dict_basic.main([])

  @flagsaver.flagsaver
  def test_define_config_file_basic(self):
    FLAGS.my_config = 'ml_collections/config_flags/examples/config.py'
    define_config_file_basic.main([])

  @flagsaver.flagsaver
  def test_define_config_file_parameterised(self):
    FLAGS.my_config = 'ml_collections/config_flags/examples/parameterised_config.py:linear'
    define_config_file_basic.main([])


if __name__ == '__main__':
  absltest.main()

# File: ml_collections-0.1.0/ml_collections/config_flags/examples/parameterised_config.py
# Copyright 2020 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python 3
"""Defines a parameterized method which returns a config depending on input."""

import ml_collections


def get_config(config_string):
  """Return an instance of ConfigDict depending on `config_string`."""
  possible_structures = {
      'linear': ml_collections.ConfigDict({
          'model_constructor': 'snt.Linear',
          'model_config': ml_collections.ConfigDict({
              'output_size': 42,
          })
      }),
      'lstm': ml_collections.ConfigDict({
          'model_constructor': 'snt.LSTM',
          'model_config': ml_collections.ConfigDict({
              'hidden_size': 108,
          })
      })
  }

  return possible_structures[config_string]

# File: ml_collections-0.1.0/ml_collections/config_flags/tests/config_overriding_test.py
# Copyright 2020 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python 3
"""Tests for ml_collections.config_flags."""

import copy
import shlex
import sys

from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
import ml_collections
from ml_collections.config_flags import config_flags
from ml_collections.config_flags.tests import fieldreference_config
from ml_collections.config_flags.tests import mock_config
import six

_CHECK_TYPES = tuple(
    list(six.integer_types) + list(six.string_types) + [float, bool])

_TEST_DIRECTORY = 'ml_collections/config_flags/tests'
_TEST_CONFIG_FILE = '{}/mock_config.py'.format(_TEST_DIRECTORY)

# Parameters to test that config loading and overriding works with both
# one and two dashes.
_DASH_PARAMETERS = (
    ('WithTwoDashesAndEqual', '--test_config={}'.format(_TEST_CONFIG_FILE)),
    ('WithTwoDashes', '--test_config {}'.format(_TEST_CONFIG_FILE)),
    ('WithOneDashAndEqual', '-test_config={}'.format(_TEST_CONFIG_FILE)),
    ('WithOneDash', '-test_config {}'.format(_TEST_CONFIG_FILE)))

_CONFIGDICT_CONFIG_FILE = '{}/configdict_config.py'.format(_TEST_DIRECTORY)
_IOERROR_CONFIG_FILE = '{}/ioerror_config.py'.format(_TEST_DIRECTORY)
_VALUEERROR_CONFIG_FILE = '{}/valueerror_config.py'.format(_TEST_DIRECTORY)
_TYPEERROR_CONFIG_FILE = '{}/typeerror_config.py'.format(_TEST_DIRECTORY)
_FIELDREFERENCE_CONFIG_FILE = '{}/fieldreference_config.py'.format(
    _TEST_DIRECTORY)
_PARAMETERISED_CONFIG_FILE = '{}/parameterised_config.py'.format(
    _TEST_DIRECTORY)


def _parse_flags(command,
                 default=None,
                 config=None,
                 lock_config=True,
                 required=False):
  """Parses arguments simulating sys.argv."""

  if config is not None and default is not None:
    raise ValueError('If config is supplied a default should not be.')

  # Storing copy of the old sys.argv.
  old_argv = list(sys.argv)
  # Overwriting sys.argv, as sys has a global state it gets propagated.
  # The module shlex is useful here because it splits the input similar to
  # sys.argv.
  # For instance, string arguments are not split by space.
  sys.argv = shlex.split(command)
  # Actual parsing.
  values = flags.FlagValues()
  if config is None:
    config_flags.DEFINE_config_file(
        'test_config',
        default=default,
        flag_values=values,
        lock_config=lock_config)
  else:
    config_flags.DEFINE_config_dict(
        'test_config',
        config=config,
        flag_values=values,
        lock_config=lock_config)

  if required:
    flags.mark_flag_as_required('test_config', flag_values=values)
  values(sys.argv)
  # Going back to original values.
  sys.argv = old_argv
  return values


def _get_override_flags(overrides, override_format):
  return ' '.join([override_format.format(path, value)
                   for path, value in six.iteritems(overrides)])


class _ConfigFlagTestCase(object):
  """Base class for tests with additional asserts for comparing configs."""

  def assert_subset_configs(self, config1, config2):
    """Checks if all attributes/values in config1 are present in config2."""

    if config1 is None:
      return

    if hasattr(config1, '__dict__'):
      keys = [key for key in config1.__dict__ if not key.startswith('_')]
      get_val = getattr
    elif hasattr(config1, 'keys') and callable(config1.keys):
      keys = [key for key in config1.keys() if not key.startswith('_')]
      get_val = lambda haystack, needle: haystack[needle]
    else:
      # This should not fail as it simply means we cannot iterate deeper.
return for attribute in keys: if isinstance(get_val(config1, attribute), _CHECK_TYPES): self.assertEqual(get_val(config1, attribute), get_val(config2, attribute)) else: # Try to go deeper with comparison self.assert_subset_configs(get_val(config1, attribute), get_val(config2, attribute)) def assert_equal_configs(self, config1, config2): """Checks if two configs are identical.""" self.assert_subset_configs(config1, config2) self.assert_subset_configs(config2, config1) class ConfigFileFlagTest(_ConfigFlagTestCase, parameterized.TestCase): """Tests config flags library.""" @parameterized.named_parameters(*_DASH_PARAMETERS) def testLoading(self, config_flag): """Tests loading config from file.""" values = _parse_flags('./program {}'.format(config_flag), default='nonexisting.py') self.assertIn('test_config', values) self.assert_equal_configs(values.test_config, mock_config.get_config()) @parameterized.named_parameters(*_DASH_PARAMETERS) def testRequired(self, config_flag): """Tests making a config_file flag required.""" with self.assertRaises(flags.IllegalFlagValueError): _parse_flags('./program ', required=True) values = _parse_flags('./program {}'.format(config_flag), required=True) self.assertIn('test_config', values) def testDefaultLoading(self): """Tests loading config from file using default path.""" for required in [True, False]: values = _parse_flags( './program', default=_TEST_CONFIG_FILE, required=required) self.assertIn('test_config', values) self.assert_equal_configs(values.test_config, mock_config.get_config()) def testLoadingNonExistingConfigLoading(self): """Tests whether loading non existing file raises an error.""" nonexisting_file = 'nonexisting.py' # Test whether loading non existing files raises an Error with both # file loading formats i.e. with '--' and '-'. # The Error is not expected to be raised until the config dict actually # has one of its attributes accessed. 
values = _parse_flags( './program --test_config={}'.format(nonexisting_file), default=_TEST_CONFIG_FILE) with self.assertRaisesRegex(IOError, '.*{}.*'.format(nonexisting_file)): _ = values.test_config.a values = _parse_flags( './program -test_config {}'.format(nonexisting_file), default=_TEST_CONFIG_FILE) with self.assertRaisesRegex(IOError, '.*{}.*'.format(nonexisting_file)): _ = values.test_config.a values = _parse_flags( './program -test_config ""', default=_TEST_CONFIG_FILE) with self.assertRaisesRegex(IOError, 'empty string'): _ = values.test_config.a def testIOError(self): """Tests that IOErrors raised inside config files are reported correctly.""" values = _parse_flags('./program --test_config={}' .format(_IOERROR_CONFIG_FILE)) with self.assertRaisesRegex(IOError, 'This is an IOError'): _ = values.test_config.a def testValueError(self): """Tests that ValueErrors raised when parsing config files are passed up.""" allchars_regexp = r'[\s\S]*' error_regexp = allchars_regexp.join(['Error whilst parsing config file', 'in get_config', 'in value_error_function', 'This is a ValueError']) with self.assertRaisesRegex(flags.IllegalFlagValueError, error_regexp): _ = _parse_flags('./program --test_config={}' .format(_VALUEERROR_CONFIG_FILE)) def testTypeError(self): """Tests that TypeErrors raised when parsing config files are passed up.""" allchars_regexp = r'[\s\S]*' error_regexp = allchars_regexp.join(['Error whilst parsing config file', 'in get_config', 'in type_error_function', 'This is a TypeError']) with self.assertRaisesRegex(flags.IllegalFlagValueError, error_regexp): _ = _parse_flags('./program --test_config={}' .format(_TYPEERROR_CONFIG_FILE)) # Note: While testing the overriding of parameters, we explicitly set # '!r' in the format string for the value. This ensures the 'repr()' is # called on the argument which basically means that for string arguments, # the quotes (' ') are left intact when we format the string. 
@parameterized.named_parameters( ('TwoDashConfigAndOverride', '--test_config={}'.format(_TEST_CONFIG_FILE), '--test_config.{}={!r}'), ('TwoDashSpaceConfigAndOverride', '--test_config {}'.format(_TEST_CONFIG_FILE), '--test_config.{} {!r}'), ('OneDashConfigAndOverride', '-test_config {}'.format(_TEST_CONFIG_FILE), '-test_config.{} {!r}'), ('OneDashEqualConfigAndOverride', '-test_config={}'.format(_TEST_CONFIG_FILE), '-test_config.{}={!r}'), ('OneDashConfigAndTwoDashOverride', '-test_config {}'.format(_TEST_CONFIG_FILE), '--test_config.{}={!r}'), ('TwoDashConfigAndOneDashOverride', '--test_config={}'.format(_TEST_CONFIG_FILE), '-test_config.{} {!r}')) def testOverride(self, config_flag, override_format): """Tests overriding config values from command line.""" overrides = { 'integer': 1, 'float': -3, 'dict.float': 3, 'object.integer': 12, 'object.float': 123, 'object.string': 'tom', 'object.dict.integer': -2, 'object.dict.float': 3.15, 'object.dict.list[0]': 101, 'object.dict.list[2][0]': 103, 'object.list[0]': 101, 'object.list[2][0]': 103, 'object.tuple': '(1,2,(1,2))', 'object.tuple_with_spaces': '(1, 2, (1, 2))', 'object_reference.dict.string': 'marry', 'object.dict.dict.float': 123, 'object_copy.float': 111.111 } override_flags = _get_override_flags(overrides, override_format) values = _parse_flags('./program {} {}'.format(config_flag, override_flags)) test_config = mock_config.get_config() test_config.integer = overrides['integer'] test_config.float = overrides['float'] test_config.dict['float'] = overrides['dict.float'] test_config.object.integer = overrides['object.integer'] test_config.object.float = overrides['object.float'] test_config.object.string = overrides['object.string'] test_config.object.dict['integer'] = overrides['object.dict.integer'] test_config.object.dict['float'] = overrides['object.dict.float'] test_config.object.dict['list'][0] = overrides['object.dict.list[0]'] test_config.object.dict['list'][2][0] = overrides['object.dict.list[2][0]'] 
    test_config.object.list[0] = overrides['object.list[0]']
    test_config.object.list[2][0] = overrides['object.list[2][0]']
    test_config.object.tuple = (1, 2, (1, 2))
    test_config.object.tuple_with_spaces = (1, 2, (1, 2))
    test_config.object_reference.dict['string'] = overrides[
        'object_reference.dict.string']
    test_config.object.dict['dict']['float'] = overrides[
        'object.dict.dict.float']
    test_config.object_copy.float = overrides['object_copy.float']
    self.assert_equal_configs(values.test_config, test_config)

  def testOverrideBoolean(self):
    """Tests overriding boolean config values from command line."""
    prefix = './program --test_config={}'.format(_TEST_CONFIG_FILE)

    # The default for dict.bool is False.
    values = _parse_flags('{} --test_config.dict.bool'.format(prefix))
    self.assertTrue(values.test_config.dict['bool'])

    values = _parse_flags('{} --test_config.dict.bool=true'.format(prefix))
    self.assertTrue(values.test_config.dict['bool'])

    # The default for object.bool is True.
    values = _parse_flags('{} --test_config.object.bool=false'.format(prefix))
    self.assertFalse(values.test_config.object.bool)

    values = _parse_flags('{} --notest_config.object.bool'.format(prefix))
    self.assertFalse(values.test_config.object.bool)

  def testFieldReferenceOverride(self):
    """Tests overriding FieldReference fields from the command line."""
    overrides = {'ref_nodefault': 1, 'ref': 2}
    override_flags = _get_override_flags(overrides, '--test_config.{}={!r}')
    config_flag = '--test_config={}'.format(_FIELDREFERENCE_CONFIG_FILE)
    values = _parse_flags('./program {} {}'.format(config_flag, override_flags))
    cfg = values.test_config
    self.assertEqual(cfg.ref_nodefault, overrides['ref_nodefault'])
    self.assertEqual(cfg.ref, overrides['ref'])

  @parameterized.named_parameters(*_DASH_PARAMETERS)
  def testSetNotExistingKey(self, config_flag):
    """Tests setting the value of a nonexistent key."""
    with self.assertRaises(KeyError):
      _parse_flags('./program {} '
                   '--test_config.not_existing_key=1 '.format(config_flag))

@parameterized.named_parameters(*_DASH_PARAMETERS) def testSetReadOnlyField(self, config_flag): """Tests setting value of key which is read only.""" with self.assertRaises(AttributeError): _parse_flags('./program {} ' '--test_config.readonly_field=1 '.format(config_flag)) @parameterized.named_parameters(*_DASH_PARAMETERS) def testNotSupportedOperation(self, config_flag): """Tests setting value to not supported type.""" with self.assertRaises(config_flags.UnsupportedOperationError): _parse_flags('./program {} ' '--test_config.list=[1]'.format(config_flag)) def testReadingNonExistingKey(self): """Tests reading non existing key from config.""" test_config = mock_config.get_config() with self.assertRaises(config_flags.UnsupportedOperationError): config_flags.SetValue('dict.not_existing_key', test_config, 1) def testReadingSettingExistingKeyInDict(self): """Tests setting non existing key from dict inside config.""" test_config = mock_config.get_config() with self.assertRaises(KeyError): config_flags.SetValue('dict.not_existing_key.key', test_config, 1) def testEmptyKey(self): """Tests calling an empty key update.""" test_config = mock_config.get_config() with self.assertRaises(ValueError): config_flags.SetValue('', test_config, None) def testListExtraIndex(self): """Tries to index a non-indexable list element.""" test_config = mock_config.get_config() with self.assertRaises(IndexError): config_flags.GetValue('dict.list[0][0]', test_config) def testListOutOfRangeGet(self): """Tries to access out-of-range value in list.""" test_config = mock_config.get_config() with self.assertRaises(IndexError): config_flags.GetValue('dict.list[2][1]', test_config) def testListOutOfRangeSet(self): """Tries to override out-of-range value in list.""" test_config = mock_config.get_config() with self.assertRaises(config_flags.UnsupportedOperationError): config_flags.SetValue('dict.list[2][1]', test_config, -1) def testParserWrapping(self): """Tests callback based Parser wrapping.""" parser = 
flags.IntegerParser() test_config = mock_config.get_config() overrides = {} wrapped_parser = config_flags._ConfigFieldParser(parser, 'integer', test_config, overrides) wrapped_parser.parse('12321') self.assertEqual(test_config.integer, 12321) self.assertEqual(overrides, {'integer': 12321}) self.assertEqual(wrapped_parser.flag_type(), parser.flag_type()) self.assertEqual(wrapped_parser.syntactic_help, parser.syntactic_help) self.assertEqual(wrapped_parser.convert('3'), parser.convert('3')) def testTypes(self): """Tests whether various types of objects are valid.""" parser = config_flags._ConfigFileParser('test_config') self.assertEqual(parser.flag_type(), 'config object') test_config = mock_config.get_config() paths = ( 'float', 'integer', 'string', 'bool', 'dict', 'dict.float', 'dict.list', 'list', 'list[0]', 'object.float', 'object.integer', 'object.string', 'object.bool', 'object.dict', 'object.dict.float', 'object.dict.list', 'object.list', 'object.list[0]', 'object.tuple', 'object_reference.float', 'object_reference.integer', 'object_reference.string', 'object_reference.bool', 'object_reference.dict', 'object_reference.dict.float', 'object_copy.float', 'object_copy.integer', 'object_copy.string', 'object_copy.bool', 'object_copy.dict', 'object_copy.dict.float' ) paths_types = [ float, int, str, bool, dict, float, list, list, int, float, int, str, bool, dict, float, list, list, int, tuple, float, int, str, bool, dict, float, float, int, str, bool, dict, float, ] config_types = config_flags.GetTypes(paths, test_config) self.assertEqual(paths_types, config_types) def testFieldReferenceTypes(self): """Tests whether types of FieldReference fields are valid.""" test_config = fieldreference_config.get_config() paths = ['ref_nodefault', 'ref'] paths_types = [int, int] config_types = config_flags.GetTypes(paths, test_config) self.assertEqual(paths_types, config_types) @parameterized.named_parameters( ('WithTwoDashesAndEqual', '--test_config=config.py'), 
('WithTwoDashes', '--test_config'), ('WithOneDashAndEqual', '-test_config=config.py'), ('WithOneDash', '-test_config')) def testConfigSpecified(self, config_argument): """Tests whether config is specified on the command line.""" config_flag = config_flags._ConfigFlag( parser=flags.ArgumentParser(), serializer=None, name='test_config', default='defaultconfig.py', help_string='' ) self.assertTrue(config_flag._IsConfigSpecified([config_argument])) self.assertFalse(config_flag._IsConfigSpecified([''])) def testFindConfigSpecified(self): """Tests whether config is specified on the command line.""" config_flag = config_flags._ConfigFlag( parser=flags.ArgumentParser(), serializer=None, name='test_config', default='defaultconfig.py', help_string='' ) self.assertEqual(config_flag._FindConfigSpecified(['']), -1) argv_length = 20 for i in range(argv_length): # Generate list of '--test_config.i=0' args. argv = ['--test_config.{}=0'.format(arg) for arg in range(argv_length)] self.assertEqual(config_flag._FindConfigSpecified(argv), -1) # Override i-th arg with something specifying the value of 'test_config'. # After doing this, _FindConfigSpecified should return the value of i. 
argv[i] = '--test_config' self.assertEqual(config_flag._FindConfigSpecified(argv), i) argv[i] = '--test_config=config.py' self.assertEqual(config_flag._FindConfigSpecified(argv), i) argv[i] = '-test_config' self.assertEqual(config_flag._FindConfigSpecified(argv), i) argv[i] = '-test_config=config.py' self.assertEqual(config_flag._FindConfigSpecified(argv), i) def testLoadingLockedConfigDict(self): """Tests loading ConfigDict instance and that it is locked.""" config_flag = '--test_config={}'.format(_CONFIGDICT_CONFIG_FILE) values = _parse_flags('./program {}'.format(config_flag), lock_config=True) self.assertTrue(values.test_config.is_locked) self.assertTrue(values.test_config.nested_configdict.is_locked) values = _parse_flags('./program {}'.format(config_flag), lock_config=False) self.assertFalse(values.test_config.is_locked) self.assertFalse(values.test_config.nested_configdict.is_locked) @parameterized.named_parameters( ('WithTwoDashesAndEqual', '--test_config={}'.format(_TEST_DIRECTORY)), ('WithTwoDashes', '--test_config {}'.format(_TEST_DIRECTORY)), ('WithOneDashAndEqual', '-test_config={}'.format(_TEST_DIRECTORY)), ('WithOneDash', '-test_config {}'.format(_TEST_DIRECTORY))) def testPriorityOfFieldLookup(self, config_flag): """Tests whether attributes have higher priority than key-based lookup.""" values = _parse_flags('./program {}/mini_config.py'.format(config_flag), lock_config=False) self.assertTrue( config_flags.GetValue('entry_with_collision', values.test_config)) @parameterized.named_parameters( ('TypeAEnabled', ':type_a', {'thing_a': 23, 'thing_b': 42}, ['thing_c']), ('TypeASpecify', ':type_a --test_config.thing_a=24', {'thing_a': 24}, []), ('TypeBEnabled', ':type_b', {'thing_a': 19, 'thing_c': 65}, ['thing_b'])) def testExtraConfigString(self, flag_override, should_exist, should_error): """Tests with the config_string argument is used properly.""" values = _parse_flags( './program --test_config={}/parameterised_config.py{}'.format( _TEST_DIRECTORY, 
flag_override)) # Ensure that the values exist in the ConfigDict, with desired values. for subfield_name, expected_value in six.iteritems(should_exist): self.assertEqual(values.test_config[subfield_name], expected_value) # Ensure the values which should not be part of the ConfigDict are really # not there. for should_error_name in should_error: with self.assertRaisesRegex(KeyError, 'Did you mean'): _ = values.test_config[should_error_name] def testExtraConfigInvalidFlag(self): with self.assertRaisesRegex(AttributeError, 'not_valid_item'): _parse_flags( ('./program --test_config={}/parameterised_config.py:type_a ' '--test_config.not_valid_item=42').format(_TEST_DIRECTORY)) def testOverridingConfigDict(self): """Tests overriding of ConfigDict fields.""" config_flag = '--test_config={}'.format(_CONFIGDICT_CONFIG_FILE) overrides = { 'integer': 2, 'reference': 2, 'list[0]': 5, 'nested_list[0][0]': 5, 'nested_configdict.integer': 5, 'unusable_config.dummy_attribute': 5 } override_flags = _get_override_flags(overrides, '--test_config.{}={}') values = _parse_flags('./program {} {}'.format(config_flag, override_flags)) self.assertEqual(values.test_config.integer, overrides['integer']) self.assertEqual(values.test_config.reference, overrides['reference']) self.assertEqual(values.test_config.list[0], overrides['list[0]']) self.assertEqual(values.test_config.nested_list[0][0], overrides['nested_list[0][0]']) self.assertEqual(values.test_config.nested_configdict.integer, overrides['nested_configdict.integer']) self.assertEqual(values.test_config.unusable_config.dummy_attribute, overrides['unusable_config.dummy_attribute']) # Attribute error. overrides = {'nonexistent': 'value'} with self.assertRaises(AttributeError): override_flags = _get_override_flags(overrides, '--test_config.{}={}') _parse_flags('./program {} {}'.format(config_flag, override_flags)) # "Did you mean" messages. 
overrides = {'integre': 2} with self.assertRaisesRegex(AttributeError, 'Did you.*integer.*'): override_flags = _get_override_flags(overrides, '--test_config.{}={}') _parse_flags('./program {} {}'.format(config_flag, override_flags)) overrides = {'referecne': 2} with self.assertRaisesRegex(AttributeError, 'Did you.*reference.*'): override_flags = _get_override_flags(overrides, '--test_config.{}={}') _parse_flags('./program {} {}'.format(config_flag, override_flags)) # This test adds new flags, so use FlagSaver to make it hermetic. @flagsaver.flagsaver def testIsConfigFile(self): config_flags.DEFINE_config_file('is_a_config_flag') flags.DEFINE_integer('not_a_config_flag', -1, '') self.assertTrue( config_flags.is_config_flag(flags.FLAGS['is_a_config_flag'])) self.assertFalse( config_flags.is_config_flag(flags.FLAGS['not_a_config_flag'])) # This test adds new flags, so use FlagSaver to make it hermetic. @flagsaver.flagsaver def testModuleName(self): config_flags.DEFINE_config_file('flag') argv_0 = './program' _parse_flags(argv_0) self.assertIn(flags.FLAGS['flag'], flags.FLAGS.flags_by_module_dict()[argv_0]) def testFlagOrder(self): with self.assertRaisesWithLiteralMatch( config_flags.FlagOrderError, ('Found --test_config.int=1 in argv before a value for --test_config ' 'was specified')): _parse_flags('./program --test_config.int=1 ' '--test_config={}'.format(_TEST_CONFIG_FILE)) @flagsaver.flagsaver def testOverrideValues(self): config_flags.DEFINE_config_file('config') with self.assertRaisesWithLiteralMatch(config_flags.UnparsedFlagError, 'The flag has not been parsed yet'): flags.FLAGS['config'].override_values # pylint: disable=pointless-statement original_float = -1.0 original_dictfloat = -2.0 config = ml_collections.ConfigDict({ 'integer': -1, 'float': original_float, 'dict': { 'float': original_dictfloat } }) integer_override = 0 dictfloat_override = 1.1 values = _parse_flags('./program --test_config={} --test_config.integer={} ' 
'--test_config.dict.float={}'.format( _TEST_CONFIG_FILE, integer_override, dictfloat_override)) config.update_from_flattened_dict(values['test_config'].override_values) self.assertEqual(config['integer'], integer_override) self.assertEqual(config['float'], original_float) self.assertEqual(config['dict']['float'], dictfloat_override) @parameterized.named_parameters( ('ConfigFile1', _TEST_CONFIG_FILE), ('ConfigFile2', _CONFIGDICT_CONFIG_FILE), ('ParameterisedConfigFile', _PARAMETERISED_CONFIG_FILE + ':type_a'), ) def testConfigPath(self, config_file): """Test access to saved config file path.""" values = _parse_flags('./program --test_config={}'.format(config_file)) self.assertEqual(values['test_config'].config_filename, config_file) def _simple_config(): config = ml_collections.ConfigDict() config.foo = 3 return config class ConfigDictFlagTest(_ConfigFlagTestCase, absltest.TestCase): """Tests DEFINE_config_dict. DEFINE_config_dict reuses a lot of code in DEFINE_config_file so the tests here are mostly sanity checks. 
""" def testBasicUsage(self): values = _parse_flags('./program', config=_simple_config()) self.assertIn('test_config', values) self.assert_equal_configs(_simple_config(), values.test_config) def testChangingLockedConfigRaisesAnError(self): values = _parse_flags( './program', config=_simple_config(), lock_config=True) with self.assertRaisesRegex(AttributeError, 'config is locked'): values.test_config.new_foo = 20 def testChangingUnlockedConfig(self): values = _parse_flags( './program', config=_simple_config(), lock_config=False) values.test_config.new_foo = 20 def testNonConfigDictAsConfig(self): non_config_dict = dict(a=1, b=2) with self.assertRaisesRegex(TypeError, 'should be a ConfigDict'): _parse_flags('./program', config=non_config_dict) def testOverridingAttribute(self): new_foo = 10 values = _parse_flags( './program --test_config.foo={}'.format(new_foo), config=_simple_config()) self.assertNotEqual(new_foo, _simple_config().foo) self.assertEqual(new_foo, values.test_config.foo) def testOverridingMainConfigFlagRaisesAnError(self): with self.assertRaisesRegex(flags.IllegalFlagValueError, 'Overriding test_config is not allowed'): _parse_flags('./program --test_config=bad_input', config=_simple_config()) def testOverridesSerialize(self): all_types_config = ml_collections.ConfigDict() all_types_config.type_bool = False all_types_config.type_bytes = b'bytes' all_types_config.type_float = 1.0 all_types_config.type_int = 1 all_types_config.type_str = 'str' all_types_config.type_ustr = u'ustr' all_types_config.type_tuple = (False, b'bbytes', 1.0, 1, 'str', u'ustr',) # Change values via the command line: command_line = ( './program' ' --test_config.type_bool=True' ' --test_config.type_float=10' ' --test_config.type_int=10' ' --test_config.type_str=str_commandline' ' --test_config.type_tuple="(\'tuple_str\', 10)"' ) if six.PY3: # Not supported in py2; type_bytes is never supported. 
command_line += ' --test_config.type_ustr=ustr_commandline' values = _parse_flags(command_line, config=copy.copy(all_types_config)) # Check we get the expected values (ie the ones defined above not the ones # defined in the config itself) get_parser = lambda name: values._flags()[name].parser.parse get_serializer = lambda name: values._flags()[name].serializer.serialize serialize_parse = ( lambda name, value: get_parser(name)(get_serializer(name)(value))) with self.subTest('bool'): self.assertNotEqual(values.test_config.type_bool, all_types_config['type_bool']) self.assertEqual(values.test_config.type_bool, True) self.assertEqual(values.test_config.type_bool, serialize_parse('test_config.type_bool', values.test_config.type_bool)) with self.subTest('float'): self.assertNotEqual(values.test_config.type_float, all_types_config['type_float']) self.assertEqual(values.test_config.type_float, 10.) self.assertEqual(values.test_config.type_float, serialize_parse('test_config.type_float', values.test_config.type_float)) with self.subTest('int'): self.assertNotEqual(values.test_config.type_int, all_types_config['type_int']) self.assertEqual(values.test_config.type_int, 10) self.assertEqual(values.test_config.type_int, serialize_parse('test_config.type_int', values.test_config.type_int)) with self.subTest('str'): self.assertNotEqual(values.test_config.type_str, all_types_config['type_str']) self.assertEqual(values.test_config.type_str, 'str_commandline') self.assertEqual(values.test_config.type_str, serialize_parse('test_config.type_str', values.test_config.type_str)) with self.subTest('ustr'): if six.PY3: self.assertNotEqual(values.test_config.type_ustr, all_types_config['type_ustr']) self.assertEqual(values.test_config.type_ustr, u'ustr_commandline') self.assertEqual(values.test_config.type_ustr, serialize_parse('test_config.type_ustr', values.test_config.type_ustr)) with self.subTest('tuple'): self.assertNotEqual(values.test_config.type_tuple, 
                          all_types_config['type_tuple'])
      self.assertEqual(values.test_config.type_tuple, ('tuple_str', 10))
      self.assertEqual(values.test_config.type_tuple,
                       serialize_parse('test_config.type_tuple',
                                       values.test_config.type_tuple))


def main():
  absltest.main()


if __name__ == '__main__':
  main()

# File: ml_collections-0.1.0/ml_collections/config_flags/tests/configdict_config.py
# Copyright 2020 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python 3
"""ConfigDict config file."""

import ml_collections


class UnusableConfig(object):
  """Test against code assuming the semantics of attributes (such as `lock`).

  This class is to test that the flags implementation does not assume the
  semantics of attributes. This is to avoid code such as:

  ```python
  if hasattr(obj, 'lock'):
    obj.lock()
  ```

  which will fail if `obj` has an attribute `lock` that does not behave in the
  way we expect.

  This class only has unusable attributes. There are two exceptions for which
  this class behaves normally:
  * Python's special functions which start and end with a double underscore.
  * `dummy_attribute`, an attribute used to test the class.

  For other attributes, both `hasattr(obj, attr)` and `callable(obj.attr)` will
  return True.
Calling `obj.attr` will return a function which takes no arguments and raises an AttributeError when called. For example, the `lock` example above will raise an AttributeError. The only valid action on attributes is assignment, e.g. ```python obj = UnusableConfig() obj.attr = 1 ``` In which case the attribute will keep its assigned value and become usable. """ def __init__(self): self._dummy_attribute = 1 def __getattr__(self, attribute): """Get an arbitrary attribute. Returns a function which takes no arguments and raises an AttributeError, except for Python special functions in which case an AttributeError is directly raised. Args: attribute: A string representing the attribute's name. Returns: A function which raises an AttributeError when called. Raises: AttributeError: when the attribute is a Python special function starting and ending with a double underscore. """ if attribute.startswith("__") and attribute.endswith("__"): raise AttributeError("UnusableConfig does not contain entry {}.". format(attribute)) def raise_attribute_error_fun(): raise AttributeError( "{} is not a usable attribute of UnusableConfig".format( attribute)) return raise_attribute_error_fun @property def dummy_attribute(self): return self._dummy_attribute @dummy_attribute.setter def dummy_attribute(self, value): self._dummy_attribute = value def get_config(): """Returns a ConfigDict. 
Used for tests.""" cfg = ml_collections.ConfigDict() cfg.integer = 1 cfg.reference = ml_collections.FieldReference(1) cfg.list = [1, 2, 3] cfg.nested_list = [[1, 2, 3]] cfg.nested_configdict = ml_collections.ConfigDict() cfg.nested_configdict.integer = 1 cfg.unusable_config = UnusableConfig() return cfg ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/tests/fieldreference_config.py0000644162426002575230000000157600000000000033137 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Config file with field references.""" import ml_collections from ml_collections.config_dict import config_dict def get_config(): cfg = ml_collections.ConfigDict() cfg.ref = ml_collections.FieldReference(123) cfg.ref_nodefault = config_dict.placeholder(int) return cfg ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/tests/ioerror_config.py0000644162426002575230000000212200000000000031642 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Config file that raises IOError on import. The flags library tries to load configuration files in a few different ways. For this it relies on catching IOError exceptions of the type "File not found" and ignoring them to continue trying with a different loading method. But we need to ensure that other types of IOError exceptions are propagated correctly (b/63165566). This is tested in `ConfigFlagTest.testIOError` in `config_overriding_test.py`. """ raise IOError('This is an IOError.') ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/tests/mini_config.py0000644162426002575230000000205300000000000031120 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
# Lint as: python 3 """Dummy Config file.""" class MiniConfig(object): """Just a dummy config.""" def __init__(self): self.dict = {} self.field = False def __getitem__(self, key): return self.dict[key] def __contains__(self, key): return key in self.dict def __setitem__(self, key, value): self.dict[key] = value def get_config(): cfg = MiniConfig() cfg['entry_with_collision'] = False cfg.entry_with_collision = True return cfg ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/tests/mock_config.py0000644162426002575230000000260700000000000031122 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Dummy Config file.""" import copy class TestConfig(object): """Just a dummy config.""" def __init__(self): self.integer = 23 self.float = 2.34 self.string = 'james' self.bool = True self.dict = { 'integer': 1, 'float': 3.14, 'string': 'mark', 'bool': False, 'dict': { 'float': 5. 
}, 'list': [1, 2, [3]] } self.list = [1, 2, [3]] self.tuple = (1, 2, (3,)) self.tuple_with_spaces = (1, 2, (3,)) @property def readonly_field(self): return 42 def __repr__(self): return str(self.__dict__) def get_config(): config = TestConfig() config.object = TestConfig() config.object_reference = config.object config.object_copy = copy.deepcopy(config.object) return config ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/tests/parameterised_config.py0000644162426002575230000000204500000000000033012 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Config file where `get_config` takes a string argument.""" import ml_collections def get_config(config_string): """A config which takes an extra string argument.""" possible_configs = { 'type_a': ml_collections.ConfigDict({ 'thing_a': 23, 'thing_b': 42, }), 'type_b': ml_collections.ConfigDict({ 'thing_a': 19, 'thing_c': 65, }), } return possible_configs[config_string] ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/tests/typeerror_config.py0000644162426002575230000000223600000000000032222 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Config file that raises TypeError on import. When trying to load the configuration file as a flag, the flags library catches TypeError exceptions, recasts them as an IllegalFlagTypeError, and rethrows (b/63877430). The rethrow does not include the stack trace from the original exception, so we manually add the stack trace in configflags.parse(). This is tested in `ConfigFlagTest.testTypeError` in `config_overriding_test.py`. """ def type_error_function(): raise TypeError('This is a TypeError.') def get_config(): return {'item': type_error_function()} ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/tests/valueerror_config.py0000644162426002575230000000224600000000000032356 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
# Lint as: python 3 """Config file that raises ValueError on import. When trying to load the configuration file as a flag, the flags library catches ValueError exceptions, recasts them as an IllegalFlagValueError, and rethrows (b/63877430). The rethrow does not include the stack trace from the original exception, so we manually add the stack trace in configflags.parse(). This is tested in `ConfigFlagTest.testValueError` in `config_overriding_test.py`. """ def value_error_function(): raise ValueError('This is a ValueError.') def get_config(): return {'item': value_error_function()} ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/ml_collections/config_flags/tuple_parser.py0000644162426002575230000000662500000000000030213 0ustar00mohitreddyprimarygroup00000000000000# Copyright 2020 The ML Collections Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python 3 """Custom parser to override tuples in the config dict.""" import ast from absl import flags class TupleParser(flags.ArgumentParser): """Parser for tuple arguments. Custom flag parser for Tuple objects that is based on the existing parsers in `absl.flags`. This parser can be used to read in and override tuple arguments. It outputs a `tuple` object from an existing `tuple` or `str`. The only requirement is that the overriding parameter should be a `tuple` as well.
The overriding parameter can have a different number of elements of different type than the original. For a detailed list of what `str` arguments are supported for overriding, look at `ast.literal_eval` from the Python Standard Library. """ def parse(self, argument): """Returns a `tuple` representing the input `argument`. Args: argument: The argument to be parsed. Valid types are `tuple` and `str`. An empty `tuple` is returned for arguments `NoneType`. Returns: A `TupleType` representing the input argument as a `tuple`. Raises: `TypeError`: If the argument is not of type `tuple`, `str`, or `NoneType`. `ValueError`: If the string is not a well formed `tuple`. """ if isinstance(argument, tuple): return argument elif isinstance(argument, str): return _convert_str_to_tuple(argument) elif argument is None: return () else: msg = ('Could not parse argument {} of type ' '{} for element of type `tuple`.' ).format(argument, type(argument)) raise TypeError(msg) def flag_type(self): return 'tuple' def _convert_str_to_tuple(string): """Function to convert a Python `str` object to a `tuple`. Args: string: The `str` to be converted. Returns: A `tuple` version of the string. Raises: ValueError: If the string is not a well formed `tuple`. """ # literal_eval converts strings to int, tuple, list, float and dict, # booleans and None. It can also handle nested tuples. # It does not, however, handle elements of type set. try: value = ast.literal_eval(string) except ValueError: # A ValueError is raised by literal_eval if the string is not well # formed. Catch it and print out a more readable statement. msg = 'Argument {} does not evaluate to a `tuple` object.'.format(string) raise ValueError(msg) except SyntaxError: # The only other error that may be raised is a `SyntaxError` because # `literal_eval` calls the Python in-built `compile`. This error is # caused by parsing issues. msg = 'Error while parsing string: {}'.format(string) raise ValueError(msg) # Make sure we got a tuple. 
If not, it's an error. if isinstance(value, tuple): return value else: raise ValueError('Expected a tuple argument, got {}'.format(type(value))) ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1598243892.235948 ml_collections-0.1.0/ml_collections.egg-info/0000755162426002575230000000000000000000000024174 5ustar00mohitreddyprimarygroup00000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243892.0 ml_collections-0.1.0/ml_collections.egg-info/PKG-INFO0000644162426002575230000005523100000000000025277 0ustar00mohitreddyprimarygroup00000000000000Metadata-Version: 2.1 Name: ml-collections Version: 0.1.0 Summary: ML Collections is a library of Python collections designed for ML usecases. Home-page: https://github.com/google/ml_collections Author: ML Collections Authors Author-email: ml-collections@google.com License: Apache 2.0 Description: # ML Collections ML Collections is a library of Python Collections designed for ML use cases. ## ConfigDict The two classes called `ConfigDict` and `FrozenConfigDict` are "dict-like" data structures with dot access to nested elements. Together, they are supposed to be used as a main way of expressing configurations of experiments and models. This document describes example usage of `ConfigDict`, `FrozenConfigDict`, `FieldReference`. ### Features * Dot-based access to fields. * Locking mechanism to prevent spelling mistakes. * Lazy computation. * FrozenConfigDict() class which is immutable and hashable. * Type safety. * "Did you mean" functionality. * Human readable printing (with valid references and cycles), using valid YAML format. * Fields can be passed as keyword arguments using the `**` operator. * There are two exceptions to the strong type-safety of the ConfigDict. `int` values can be passed in to fields of type `float`. In such a case, the value is type-converted to a `float` before being stored.
Similarly, all string types (including Unicode strings) can be stored in fields of type `str` or `unicode`. ### Basic Usage ```python import ml_collections cfg = ml_collections.ConfigDict() cfg.float_field = 12.6 cfg.integer_field = 123 cfg.another_integer_field = 234 cfg.nested = ml_collections.ConfigDict() cfg.nested.string_field = 'tom' print(cfg.integer_field) # Prints 123. print(cfg['integer_field']) # Prints 123 as well. try: cfg.integer_field = 'tom' # Raises TypeError as this field is an integer. except TypeError as e: print(e) cfg.float_field = 12 # Works: `int` values can be assigned to `float` fields. cfg.nested.string_field = u'bob' # `str` fields can store Unicode strings. print(cfg) ``` ### FrozenConfigDict A `FrozenConfigDict` is an immutable, hashable type of `ConfigDict`: ```python import ml_collections initial_dictionary = { 'int': 1, 'list': [1, 2], 'tuple': (1, 2, 3), 'set': {1, 2, 3, 4}, 'dict_tuple_list': {'tuple_list': ([1, 2], 3)} } cfg = ml_collections.ConfigDict(initial_dictionary) frozen_dict = ml_collections.FrozenConfigDict(initial_dictionary) print(frozen_dict.tuple) # Prints tuple (1, 2, 3) print(frozen_dict.list) # Prints tuple (1, 2) print(frozen_dict.set) # Prints frozenset {1, 2, 3, 4} print(frozen_dict.dict_tuple_list.tuple_list[0]) # Prints tuple (1, 2) frozen_cfg = ml_collections.FrozenConfigDict(cfg) print(frozen_cfg == frozen_dict) # True print(hash(frozen_cfg) == hash(frozen_dict)) # True try: frozen_dict.int = 2 # Raises AttributeError as FrozenConfigDict is immutable. except AttributeError as e: print(e) # Converting between `FrozenConfigDict` and `ConfigDict`: thawed_frozen_cfg = ml_collections.ConfigDict(frozen_dict) print(thawed_frozen_cfg == cfg) # True frozen_cfg_to_cfg = frozen_dict.as_configdict() print(frozen_cfg_to_cfg == cfg) # True ``` ### FieldReferences and placeholders A `FieldReference` is useful for having multiple fields use the same value. It can also be used for [lazy computation](#lazy-computation).
You can use `placeholder()` as a shortcut to create a `FieldReference` (field) with a `None` default value. This is useful if a program uses optional configuration fields. ```python import ml_collections from ml_collections.config_dict import config_dict placeholder = ml_collections.FieldReference(0) cfg = ml_collections.ConfigDict() cfg.placeholder = placeholder cfg.optional = config_dict.placeholder(int) cfg.nested = ml_collections.ConfigDict() cfg.nested.placeholder = placeholder try: cfg.optional = 'tom' # Raises TypeError as this field is an integer. except TypeError as e: print(e) cfg.optional = 1555 # Works fine. cfg.placeholder = 1 # Changes the value of both placeholder and # nested.placeholder fields. print(cfg) ``` Note that the indirection provided by `FieldReference`s will be lost if accessed through a `ConfigDict`. ```python import ml_collections placeholder = ml_collections.FieldReference(0) cfg = ml_collections.ConfigDict() cfg.field1 = placeholder cfg.field2 = placeholder # This field will be tied to cfg.field1. cfg.field3 = cfg.field1 # This will just be an int field initialized to 0. ``` ### Lazy computation Using a `FieldReference` in a standard operation (addition, subtraction, multiplication, etc.) will return another `FieldReference` that points to the original's value. You can use `FieldReference.get()` to execute the operations and get the reference's computed value, and `FieldReference.set()` to change the original reference's value. ```python import ml_collections ref = ml_collections.FieldReference(1) print(ref.get()) # Prints 1 add_ten = ref.get() + 10 # ref.get() is an integer and so is add_ten add_ten_lazy = ref + 10 # add_ten_lazy is a FieldReference - NOT an integer print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 11 because ref's value is 1 # Addition is lazily computed for FieldReferences so changing ref will change # the value that is used to compute add_ten_lazy.
ref.set(5) print(add_ten) # Prints 11 print(add_ten_lazy.get()) # Prints 15 because ref's value is 5 ``` If a `FieldReference` has `None` as its original value, or any operation has an argument of `None`, then the lazy computation will evaluate to `None`. We can also use fields in a `ConfigDict` in lazy computation. In this case a field will only be lazily evaluated if `ConfigDict.get_ref()` is used to get it. ```python import ml_collections config = ml_collections.ConfigDict() config.reference_field = ml_collections.FieldReference(1) config.integer_field = 2 config.float_field = 2.5 # No lazy evaluation because we didn't use get_ref() config.no_lazy = config.integer_field * config.float_field # This will lazily evaluate ONLY config.integer_field config.lazy_integer = config.get_ref('integer_field') * config.float_field # This will lazily evaluate ONLY config.float_field config.lazy_float = config.integer_field * config.get_ref('float_field') # This will lazily evaluate BOTH config.integer_field and config.float_field config.lazy_both = (config.get_ref('integer_field') * config.get_ref('float_field')) config.integer_field = 3 print(config.no_lazy) # Prints 5.0 - It uses integer_field's original value print(config.lazy_integer) # Prints 7.5 config.float_field = 3.5 print(config.lazy_float) # Prints 7.0 print(config.lazy_both) # Prints 10.5 ``` #### Changing lazily computed values Lazily computed values in a ConfigDict can be overridden in the same way as regular values. The reference to the `FieldReference` used for the lazy computation will be lost and all computations downstream in the reference graph will use the new value. ```python import ml_collections config = ml_collections.ConfigDict() config.reference = 1 config.reference_0 = config.get_ref('reference') + 10 config.reference_1 = config.get_ref('reference') + 20 config.reference_1_0 = config.get_ref('reference_1') + 100 print(config.reference) # Prints 1. print(config.reference_0) # Prints 11.
print(config.reference_1) # Prints 21. print(config.reference_1_0) # Prints 121. config.reference_1 = 30 print(config.reference) # Prints 1 (unchanged). print(config.reference_0) # Prints 11 (unchanged). print(config.reference_1) # Prints 30. print(config.reference_1_0) # Prints 130. ``` #### Cycles You cannot create cycles using references. Fortunately [the only way](#changing-lazily-computed-values) to create a cycle is by assigning a computed field to one that *is not* the result of computation. This is forbidden: ```python import ml_collections from ml_collections.config_dict import config_dict config = ml_collections.ConfigDict() config.integer_field = 1 config.bigger_integer_field = config.get_ref('integer_field') + 10 try: # Raises a MutabilityError because setting config.integer_field would # cause a cycle. config.integer_field = config.get_ref('bigger_integer_field') + 2 except config_dict.MutabilityError as e: print(e) ``` ### Advanced usage Here are some more advanced examples showing lazy computation with different operators and data types. ```python import ml_collections config = ml_collections.ConfigDict() config.float_field = 12.6 config.integer_field = 123 config.list_field = [0, 1, 2] config.float_multiply_field = config.get_ref('float_field') * 3 print(config.float_multiply_field) # Prints 37.8 config.float_field = 10.0 print(config.float_multiply_field) # Prints 30.0 config.longer_list_field = config.get_ref('list_field') + [3, 4, 5] print(config.longer_list_field) # Prints [0, 1, 2, 3, 4, 5] config.list_field = [-1] print(config.longer_list_field) # Prints [-1, 3, 4, 5] # Both operands can be references config.ref_subtraction = ( config.get_ref('float_field') - config.get_ref('integer_field')) print(config.ref_subtraction) # Prints -113.0 config.integer_field = 10 print(config.ref_subtraction) # Prints 0.0 ``` ### Equality checking You can use `==` and `.eq_as_configdict()` to check equality among `ConfigDict` and `FrozenConfigDict` objects. 
```python import ml_collections dict_1 = {'list': [1, 2]} dict_2 = {'list': (1, 2)} cfg_1 = ml_collections.ConfigDict(dict_1) frozen_cfg_1 = ml_collections.FrozenConfigDict(dict_1) frozen_cfg_2 = ml_collections.FrozenConfigDict(dict_2) # True because FrozenConfigDict converts lists to tuples print(frozen_cfg_1.items() == frozen_cfg_2.items()) # False because == distinguishes the underlying difference print(frozen_cfg_1 == frozen_cfg_2) # False because == distinguishes these types print(frozen_cfg_1 == cfg_1) # But eq_as_configdict() treats both as ConfigDict, so these are True: print(frozen_cfg_1.eq_as_configdict(cfg_1)) print(cfg_1.eq_as_configdict(frozen_cfg_1)) ``` ### Equality checking with lazy computation Equality checks compare the computed values. Two configs are equal even if their computations differ, as long as those computations produce the same values. ```python import ml_collections cfg_1 = ml_collections.ConfigDict() cfg_1.a = 1 cfg_1.b = cfg_1.get_ref('a') + 2 cfg_2 = ml_collections.ConfigDict() cfg_2.a = 1 cfg_2.b = cfg_2.get_ref('a') * 3 # True because all computed values are the same print(cfg_1 == cfg_2) ``` ### Locking and copying Here is an example with `lock()` and `deepcopy()`: ```python import copy import ml_collections cfg = ml_collections.ConfigDict() cfg.integer_field = 123 # Locking prohibits the addition and deletion of fields but allows # modification of existing values. cfg.lock() try: cfg.intager_field = 124 # Misspelled name: raises AttributeError and suggests a valid field. except AttributeError as e: print(e) with cfg.unlocked(): cfg.intager_field = 1555 # Works fine: new fields can be added while unlocked. # Get a copy of the config dict. new_cfg = copy.deepcopy(cfg) new_cfg.integer_field = -123 # Works fine.
print(cfg) ``` ### Dictionary attributes and initialization ```python import ml_collections referenced_dict = {'inner_float': 3.14} d = { 'referenced_dict_1': referenced_dict, 'referenced_dict_2': referenced_dict, 'list_containing_dict': [{'key': 'value'}], } # We can initialize on a dictionary cfg = ml_collections.ConfigDict(d) # Reference structure is preserved print(id(cfg.referenced_dict_1) == id(cfg.referenced_dict_2)) # True # And the dict attributes have been converted to ConfigDict print(type(cfg.referenced_dict_1)) # ConfigDict # However, the initialization does not look inside of lists, so dicts inside # lists are not converted to ConfigDict print(type(cfg.list_containing_dict[0])) # dict ``` ### More Examples TODO(mohitreddy): Add links for examples. For more examples, take a look at these `ml_collections/config_dict/examples/`. For examples and gotchas specifically about initializing a ConfigDict, see `ml_collections/config_dict/examples/config_dict_initialization.py`. ## Config Flags This library adds flag definitions to `absl.flags` to handle config files. It does not wrap `absl.flags` so if using any standard flag definitions alongside config file flags, users must also import `absl.flags`. Currently, this module adds two new flag types, namely `DEFINE_config_file` which accepts a path to a Python file that generates a configuration, and `DEFINE_config_dict` which accepts a configuration directly. Configurations are dict-like structures (see [ConfigDict](#configdict)) whose nested elements can be overridden using special command-line flags. See the examples below for more details. ### Usage Use `ml_collections.config_flags` alongside `absl.flags`. 
For example: `script.py`: ```python from absl import app from absl import flags from ml_collections.config_flags import config_flags FLAGS = flags.FLAGS config_flags.DEFINE_config_file('my_config') def main(_): print(FLAGS.my_config) if __name__ == '__main__': app.run(main) ``` `config.py`: ```python # Note that this is a valid Python script. # get_config() can return an arbitrary dict-like object. However, it is advised # to use ml_collections.ConfigDict. # See ml_collections/config_dict/examples/config_dict_basic.py import ml_collections def get_config(): config = ml_collections.ConfigDict() config.field1 = 1 config.field2 = 'tom' config.nested = ml_collections.ConfigDict() config.nested.field = 2.23 config.tuple = (1, 2, 3) return config ``` Now, after running: ```bash python script.py -- --my_config=config.py \ --my_config.field1=8 \ --my_config.nested.field=2.1 \ --my_config.tuple='(1, 2, (1, 2))' ``` we get: ``` field1: 8 field2: tom nested: field: 2.1 tuple: !!python/tuple - 1 - 2 - !!python/tuple - 1 - 2 ``` Usage of `DEFINE_config_dict` is similar to `DEFINE_config_file`; the main difference is that the configuration is defined in `script.py` instead of in a separate file. `script.py`: ```python from absl import app from absl import flags import ml_collections from ml_collections.config_flags import config_flags config = ml_collections.ConfigDict() config.field1 = 1 config.field2 = 'tom' config.nested = ml_collections.ConfigDict() config.nested.field = 2.23 config.tuple = (1, 2, 3) FLAGS = flags.FLAGS config_flags.DEFINE_config_dict('my_config', config) def main(_): print(FLAGS.my_config) if __name__ == '__main__': app.run(main) ``` `config_file` flags are compatible with the command-line flag syntax. All the following options are supported for non-boolean values in configurations: * `-(-)config.field=value` * `-(-)config.field value` Options for boolean values are slightly different: * `-(-)config.boolean_field`: set boolean value to True.
* `-(-)noconfig.boolean_field`: set boolean value to False. * `-(-)config.boolean_field=value`: `value` is `true`, `false`, `True` or `False`. Note that `-(-)config.boolean_field value` is not supported. ### Parameterising the get_config() function It's sometimes useful to be able to pass parameters into `get_config`, and change what is returned based on this configuration. One example is if you are grid searching over parameters which have a different hierarchical structure - the flag needs to be present in the resulting ConfigDict. It would be possible to include the union of all possible leaf values in your ConfigDict, but this produces a confusing config result as you have to remember which parameters will actually have an effect and which won't. A better system is to pass some configuration, indicating which structure of ConfigDict should be returned. An example is the following config file: ```python import ml_collections def get_config(config_string): possible_structures = { 'linear': ml_collections.ConfigDict({ 'model_constructor': 'snt.Linear', 'model_config': ml_collections.ConfigDict({ 'output_size': 42, }), }), 'lstm': ml_collections.ConfigDict({ 'model_constructor': 'snt.LSTM', 'model_config': ml_collections.ConfigDict({ 'hidden_size': 108, }) }) } return possible_structures[config_string] ``` The value of `config_string` will be anything that is to the right of the first colon in the config file path, if one exists. If no colon exists, no value is passed to `get_config` (producing a TypeError if `get_config` expects a value). The above example can be run like: ```bash python script.py -- --config=path_to_config.py:linear \ --config.model_config.output_size=256 ``` or like: ```bash python script.py -- --config=path_to_config.py:lstm \ --config.model_config.hidden_size=512 ``` ### Additional features * Loads any valid python script which defines `get_config()` function returning any python object.
* Automatic locking of the loaded object, if the loaded object defines a callable `.lock()` method. * Supports command-line overriding of arbitrarily nested values in dict-like objects (with key/attribute based getters/setters) of the following types: * `types.IntType` (integer) * `types.FloatType` (float) * `types.BooleanType` (bool) * `types.StringType` (string) * `types.TupleType` (tuple) * Overriding is type safe. * Overriding of `TupleType` can be done by passing in the `tuple` as a string (see the example in the [Usage](#usage) section). * The overriding `tuple` object can be of a different size and have different types than the original. Nested tuples are also supported. Platform: UNKNOWN Classifier: Development Status :: 4 - Beta Classifier: Intended Audience :: Developers Classifier: Intended Audience :: Science/Research Classifier: License :: OSI Approved :: Apache Software License Classifier: Programming Language :: Python Classifier: Topic :: Scientific/Engineering Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence Classifier: Topic :: Software Development :: Libraries Classifier: Topic :: Software Development :: Libraries :: Python Modules Requires-Python: >=2.6 Description-Content-Type: text/markdown ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243892.0 ml_collections-0.1.0/ml_collections.egg-info/SOURCES.txt0000644162426002575230000000363300000000000026065 0ustar00mohitreddyprimarygroup00000000000000README.md setup.py ml_collections/__init__.py ml_collections.egg-info/PKG-INFO ml_collections.egg-info/SOURCES.txt ml_collections.egg-info/dependency_links.txt ml_collections.egg-info/not-zip-safe ml_collections.egg-info/requires.txt ml_collections.egg-info/top_level.txt ml_collections/config_dict/__init__.py ml_collections/config_dict/config_dict.py ml_collections/config_dict/examples/config.py ml_collections/config_dict/examples/config_dict_advanced.py 
ml_collections/config_dict/examples/config_dict_basic.py
ml_collections/config_dict/examples/config_dict_initialization.py
ml_collections/config_dict/examples/config_dict_lock.py
ml_collections/config_dict/examples/config_dict_placeholder.py
ml_collections/config_dict/examples/examples_test.py
ml_collections/config_dict/examples/field_reference.py
ml_collections/config_dict/examples/frozen_config_dict.py
ml_collections/config_dict/tests/config_dict_test.py
ml_collections/config_dict/tests/field_reference_test.py
ml_collections/config_dict/tests/frozen_config_dict_test.py
ml_collections/config_flags/__init__.py
ml_collections/config_flags/config_flags.py
ml_collections/config_flags/tuple_parser.py
ml_collections/config_flags/examples/config.py
ml_collections/config_flags/examples/define_config_dict_basic.py
ml_collections/config_flags/examples/define_config_file_basic.py
ml_collections/config_flags/examples/examples_test.py
ml_collections/config_flags/examples/parameterised_config.py
ml_collections/config_flags/tests/config_overriding_test.py
ml_collections/config_flags/tests/configdict_config.py
ml_collections/config_flags/tests/fieldreference_config.py
ml_collections/config_flags/tests/ioerror_config.py
ml_collections/config_flags/tests/mini_config.py
ml_collections/config_flags/tests/mock_config.py
ml_collections/config_flags/tests/parameterised_config.py
ml_collections/config_flags/tests/typeerror_config.py
ml_collections/config_flags/tests/valueerror_config.py
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243892.0 ml_collections-0.1.0/ml_collections.egg-info/dependency_links.txt0000644162426002575230000000000100000000000030242 0ustar00mohitreddyprimarygroup00000000000000

././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243892.0 ml_collections-0.1.0/ml_collections.egg-info/not-zip-safe0000644162426002575230000000000100000000000026422
0ustar00mohitreddyprimarygroup00000000000000

././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243892.0 ml_collections-0.1.0/ml_collections.egg-info/requires.txt0000644162426002575230000000010100000000000026564 0ustar00mohitreddyprimarygroup00000000000000
absl-py
PyYAML
six
contextlib2

[:python_version < "3.5"]
typing
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243892.0 ml_collections-0.1.0/ml_collections.egg-info/top_level.txt0000644162426002575230000000001700000000000026724 0ustar00mohitreddyprimarygroup00000000000000
ml_collections
././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1598243892.235948 ml_collections-0.1.0/setup.cfg0000644162426002575230000000004600000000000021315 0ustar00mohitreddyprimarygroup00000000000000
[egg_info]
tag_build =
tag_date = 0

././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1598243773.0 ml_collections-0.1.0/setup.py0000644162426002575230000000410500000000000021206 0ustar00mohitreddyprimarygroup00000000000000
# Copyright 2020 The ML Collections Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python 3
"""Setup for pip package."""

from setuptools import find_namespace_packages
from setuptools import setup


def _parse_requirements(requirements_txt_path):
  with open(requirements_txt_path) as fp:
    return fp.read().splitlines()


_VERSION = '0.1.0'

setup(
    name='ml_collections',
    version=_VERSION,
    author='ML Collections Authors',
    author_email='ml-collections@google.com',
    description='ML Collections is a library of Python collections designed for ML usecases.',
    long_description=open('README.md').read(),
    long_description_content_type='text/markdown',
    url='https://github.com/google/ml_collections',
    license='Apache 2.0',
    # Contained modules and scripts.
    packages=find_namespace_packages(exclude=['*_test.py']),
    install_requires=_parse_requirements('requirements.txt'),
    tests_require=_parse_requirements('requirements-test.txt'),
    python_requires='>=2.6',
    include_package_data=True,
    zip_safe=False,
    # PyPI package information.
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Programming Language :: Python',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
)