diff --git a/homeassistant/util/frozen_dataclass_compat.py b/homeassistant/util/frozen_dataclass_compat.py index e62e0a34cf1..58faedeea6f 100644 --- a/homeassistant/util/frozen_dataclass_compat.py +++ b/homeassistant/util/frozen_dataclass_compat.py @@ -59,7 +59,7 @@ class FrozenOrThawed(type): for base in bases: dataclass_bases.append(getattr(base, "_dataclass", base)) cls._dataclass = dataclasses.make_dataclass( - f"{name}_dataclass", class_fields, bases=tuple(dataclass_bases), frozen=True + name, class_fields, bases=tuple(dataclass_bases), frozen=True ) def __new__( @@ -87,15 +87,17 @@ class FrozenOrThawed(type): class will be a real dataclass, i.e. it's decorated with @dataclass. """ if not namespace["_FrozenOrThawed__frozen_or_thawed"]: - parent = cls.__mro__[1] # This class is a real dataclass, optionally inject the parent's annotations - if dataclasses.is_dataclass(parent) or not hasattr(parent, "_dataclass"): - # Rely on dataclass inheritance + if all(dataclasses.is_dataclass(base) for base in bases): + # All direct parents are dataclasses, rely on dataclass inheritance return - # Parent is not a dataclass, inject its annotations - cls.__annotations__ = ( - parent._dataclass.__annotations__ | cls.__annotations__ - ) + # Parent is not a dataclass, inject all parents' annotations + annotations: dict = {} + for parent in cls.__mro__[::-1]: + if parent is object: + continue + annotations |= parent.__annotations__ + cls.__annotations__ = annotations return # First try without setting the kw_only flag, and if that fails, try setting it @@ -104,30 +106,15 @@ class FrozenOrThawed(type): except TypeError: cls._make_dataclass(name, bases, True) - def __delattr__(self: object, name: str) -> None: - """Delete an attribute. + def __new__(*args: Any, **kwargs: Any) -> object: + """Create a new instance. - If self is a real dataclass, this is called if the dataclass is not frozen. - If self is not a real dataclass, forward to cls._dataclass.__delattr. + The function has no named arguments to avoid name collisions with dataclass + field names. """ - if dataclasses.is_dataclass(self): - return object.__delattr__(self, name) - return self._dataclass.__delattr__(self, name) # type: ignore[attr-defined, no-any-return] + cls, *_args = args + if dataclasses.is_dataclass(cls): + return object.__new__(cls) + return cls._dataclass(*_args, **kwargs) - def __setattr__(self: object, name: str, value: Any) -> None: - """Set an attribute. - - If self is a real dataclass, this is called if the dataclass is not frozen. - If self is not a real dataclass, forward to cls._dataclass.__setattr__. - """ - if dataclasses.is_dataclass(self): - return object.__setattr__(self, name, value) - return self._dataclass.__setattr__(self, name, value) # type: ignore[attr-defined, no-any-return] - - # Set generated dunder methods from the dataclass - # MyPy doesn't understand what's happening, so we ignore it - cls.__delattr__ = __delattr__ # type: ignore[assignment, method-assign] - cls.__eq__ = cls._dataclass.__eq__ # type: ignore[method-assign] - cls.__init__ = cls._dataclass.__init__ # type: ignore[misc] - cls.__repr__ = cls._dataclass.__repr__ # type: ignore[method-assign] - cls.__setattr__ = __setattr__ # type: ignore[assignment, method-assign] + cls.__new__ = __new__ # type: ignore[method-assign] diff --git a/tests/helpers/snapshots/test_entity.ambr b/tests/helpers/snapshots/test_entity.ambr index 3b04286b62f..7f146fa0494 100644 --- a/tests/helpers/snapshots/test_entity.ambr +++ b/tests/helpers/snapshots/test_entity.ambr @@ -1,6 +1,18 @@ # serializer version: 1 # name: test_entity_description_as_dataclass - EntityDescription(key='blah', device_class='test', entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name=, translation_key=None, unit_of_measurement=None) + dict({ + 'device_class': 'test', + 'entity_category': None, + 'entity_registry_enabled_default': True, + 'entity_registry_visible_default': True, + 'force_update': False, + 'has_entity_name': False, + 'icon': None, + 'key': 'blah', + 'name': , + 'translation_key': None, + 'unit_of_measurement': None, + }) # --- # name: test_entity_description_as_dataclass.1 "EntityDescription(key='blah', device_class='test', entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name=, translation_key=None, unit_of_measurement=None)" @@ -43,3 +55,63 @@ # name: test_extending_entity_description.3 "test_extending_entity_description..ThawedEntityDescription(key='blah', device_class=None, entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name='name', translation_key=None, unit_of_measurement=None, extra='foo')" # --- +# name: test_extending_entity_description.4 + dict({ + 'device_class': None, + 'entity_category': None, + 'entity_registry_enabled_default': True, + 'entity_registry_visible_default': True, + 'extension': 'ext', + 'extra': 'foo', + 'force_update': False, + 'has_entity_name': False, + 'icon': None, + 'key': 'blah', + 'name': 'name', + 'translation_key': None, + 'unit_of_measurement': None, + }) +# --- +# name: test_extending_entity_description.5 + "test_extending_entity_description..MyExtendedEntityDescription(key='blah', device_class=None, entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name='name', translation_key=None, unit_of_measurement=None, extension='ext', extra='foo')" +# --- +# name: test_extending_entity_description.6 + dict({ + 'device_class': None, + 'entity_category': None, + 'entity_registry_enabled_default': True, + 'entity_registry_visible_default': True, + 'extra': 'foo', + 'force_update': False, + 'has_entity_name': False, + 'icon': None, + 'key': 'blah', + 'mixin': 'mixin', + 'name': 'name', + 'translation_key': None, + 'unit_of_measurement': None, + }) +# --- +# name: test_extending_entity_description.7 + "test_extending_entity_description..ComplexEntityDescription1(mixin='mixin', key='blah', device_class=None, entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name='name', translation_key=None, unit_of_measurement=None, extra='foo')" +# --- +# name: test_extending_entity_description.8 + dict({ + 'device_class': None, + 'entity_category': None, + 'entity_registry_enabled_default': True, + 'entity_registry_visible_default': True, + 'extra': 'foo', + 'force_update': False, + 'has_entity_name': False, + 'icon': None, + 'key': 'blah', + 'mixin': 'mixin', + 'name': 'name', + 'translation_key': None, + 'unit_of_measurement': None, + }) +# --- +# name: test_extending_entity_description.9 + "test_extending_entity_description..ComplexEntityDescription2(mixin='mixin', key='blah', device_class=None, entity_category=None, entity_registry_enabled_default=True, entity_registry_visible_default=True, force_update=False, icon=None, has_entity_name=False, name='name', translation_key=None, unit_of_measurement=None, extra='foo')" +# --- diff --git a/tests/helpers/test_entity.py b/tests/helpers/test_entity.py index 66ba9f947c9..5a706b73b49 100644 --- a/tests/helpers/test_entity.py +++ b/tests/helpers/test_entity.py @@ -1669,6 +1669,7 @@ def test_entity_description_as_dataclass(snapshot: SnapshotAssertion): with pytest.raises(dataclasses.FrozenInstanceError): delattr(obj, "name") + assert dataclasses.is_dataclass(obj) assert obj == snapshot assert obj == entity.EntityDescription("blah", device_class="test") assert repr(obj) == snapshot @@ -1706,3 +1707,45 @@ def test_extending_entity_description(snapshot: SnapshotAssertion): assert obj.name == "mutate" delattr(obj, "key") assert not hasattr(obj, "key") + + # Try multiple levels of FrozenOrThawed + class ExtendedEntityDescription(entity.EntityDescription, frozen_or_thawed=True): + extension: str = None + + @dataclasses.dataclass(frozen=True) + class MyExtendedEntityDescription(ExtendedEntityDescription): + extra: str = None + + obj = MyExtendedEntityDescription("blah", extension="ext", extra="foo", name="name") + assert obj == snapshot + assert obj == MyExtendedEntityDescription( + "blah", extension="ext", extra="foo", name="name" + ) + assert repr(obj) == snapshot + + # Try multiple direct parents + @dataclasses.dataclass(frozen=True) + class MyMixin: + mixin: str = None + + @dataclasses.dataclass(frozen=True, kw_only=True) + class ComplexEntityDescription1(MyMixin, entity.EntityDescription): + extra: str = None + + obj = ComplexEntityDescription1(key="blah", extra="foo", mixin="mixin", name="name") + assert obj == snapshot + assert obj == ComplexEntityDescription1( + key="blah", extra="foo", mixin="mixin", name="name" + ) + assert repr(obj) == snapshot + + @dataclasses.dataclass(frozen=True, kw_only=True) + class ComplexEntityDescription2(entity.EntityDescription, MyMixin): + extra: str = None + + obj = ComplexEntityDescription2(key="blah", extra="foo", mixin="mixin", name="name") + assert obj == snapshot + assert obj == ComplexEntityDescription2( + key="blah", extra="foo", mixin="mixin", name="name" + ) + assert repr(obj) == snapshot