diff --git a/gin/config.py b/gin/config.py index 73763bc..2a116a5 100644 --- a/gin/config.py +++ b/gin/config.py @@ -2796,6 +2796,44 @@ def decorator(cls, module=module): return decorator(cls) +def register_enum(cls=None, module=None): + """Decorator for register an enum class. + + This essentially bypasses the limitation of enums which forbid inheritance + whenever an attribute is defined and thus prevents decoration with the + main register function. + + Generated constants have format `module.ClassName`. The module + name is optional when using the constant. + + Args: + cls: Class type. + module: The module to associate with the constants, to help handle naming + collisions. If `None`, `cls.__module__` will be used. + + Returns: + Class type (identity function). + + Raises: + TypeError: When applied to a non-enum class. + """ + def decorator(cls, module=module): + if not issubclass(cls, enum.Enum): + raise TypeError("Class '{}' is not subclass of enum.".format( + cls.__name__)) + + if module is None: + module = cls.__module__ + for value in cls: + constant('{}.{}'.format(module, cls.__name__), value.__class__) + break + return cls + + if cls is None: + return decorator + return decorator(cls) + + @register_finalize_hook def validate_macros_hook(config): for ref in iterate_references(config, to=get_configurable(macro)): diff --git a/tests/config_test.py b/tests/config_test.py index 8d5c64c..68b8d82 100644 --- a/tests/config_test.py +++ b/tests/config_test.py @@ -2275,6 +2275,25 @@ def testQueryConstant(self): self.assertEqual(0, config.query_parameter('OLD.ANSWER')) self.assertEqual(10, config.query_parameter('NEW.ANSWER')) + + def testRegisterEnum(self): + + @config.register_enum(module='enum_module') + class SomeEnum(enum.Enum): + FOO = 'foo' + BAR = 'bar' + + @config.configurable + def baz(a): + return a + + config.parse_config("baz.a = %enum_module.SomeEnum") + # pylint: disable=no-value-for-parameter + a = baz() + # pylint: enable=no-value-for-parameter + self.assertEqual(a, SomeEnum) + + def testConstantsFromEnum(self): @config.constants_from_enum(module='enum_module')