diff --git a/homeassistant/helpers/entity_registry.py b/homeassistant/helpers/entity_registry.py index 0cb668a5ffd..8664950c1b8 100644 --- a/homeassistant/helpers/entity_registry.py +++ b/homeassistant/helpers/entity_registry.py @@ -1803,6 +1803,9 @@ def _async_setup_cleanup(hass: HomeAssistant, registry: EntityRegistry) -> None: def _async_setup_entity_restore(hass: HomeAssistant, registry: EntityRegistry) -> None: """Set up the entity restore mechanism.""" + # pylint: disable-next=import-outside-toplevel + from . import entity + @callback def cleanup_restored_states_filter(event_data: Mapping[str, Any]) -> bool: """Clean up restored states filter.""" @@ -1816,6 +1819,7 @@ def _async_setup_entity_restore(hass: HomeAssistant, registry: EntityRegistry) - if state is None or not state.attributes.get(ATTR_RESTORED): return + del entity.entity_sources(hass)[event.data["entity_id"]] hass.states.async_remove(event.data["entity_id"], context=event.context) hass.bus.async_listen( @@ -1832,10 +1836,18 @@ def _async_setup_entity_restore(hass: HomeAssistant, registry: EntityRegistry) - """Make sure state machine contains entry for each registered entity.""" existing = set(hass.states.async_entity_ids()) + entity_sources = entity.entity_sources(hass) for entry in registry.entities.values(): if entry.entity_id in existing or entry.disabled: continue + entity_info: entity.EntityInfo = { + "domain": entry.platform, + } + if entry.config_entry_id: + entity_info["config_entry"] = entry.config_entry_id + + entity_sources[entry.entity_id] = entity_info entry.write_unavailable_state(hass) hass.bus.async_listen(EVENT_HOMEASSISTANT_START, _write_unavailable_states) diff --git a/tests/helpers/test_entity_registry.py b/tests/helpers/test_entity_registry.py index 714dfed32e9..d5e24511f76 100644 --- a/tests/helpers/test_entity_registry.py +++ b/tests/helpers/test_entity_registry.py @@ -18,7 +18,7 @@ from homeassistant.const import ( ) from homeassistant.core import CoreState, HomeAssistant, callback from homeassistant.exceptions import MaxLengthExceeded -from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers import device_registry as dr, entity, entity_registry as er from homeassistant.util.dt import utc_from_timestamp, utcnow from tests.common import ( @@ -1470,6 +1470,7 @@ async def test_restore_states( hass: HomeAssistant, entity_registry: er.EntityRegistry ) -> None: """Test restoring states.""" + entity_sources = entity.entity_sources(hass) hass.set_state(CoreState.not_running) entity_registry.async_get_or_create( @@ -1486,18 +1487,23 @@ async def test_restore_states( suggested_object_id="disabled", disabled_by=er.RegistryEntryDisabler.HASS, ) + config_entry = MockConfigEntry(domain="hue") + config_entry.add_to_hass(hass) entity_registry.async_get_or_create( "light", "hue", "9012", suggested_object_id="all_info_set", capabilities={"max": 100}, + config_entry=config_entry, supported_features=5, original_device_class="mock-device-class", original_name="Mock Original Name", original_icon="hass:original-icon", ) + assert entity_sources == {} + hass.bus.async_fire(EVENT_HOMEASSISTANT_START, {}) await hass.async_block_till_done() @@ -1521,6 +1527,16 @@ async def test_restore_states( "icon": "hass:original-icon", } + assert entity_sources == { + "light.all_info_set": { + "config_entry": config_entry.entry_id, + "domain": "hue", + }, + "light.simple": { + "domain": "hue", + }, + } + entity_registry.async_remove("light.disabled") entity_registry.async_remove("light.simple") entity_registry.async_remove("light.all_info_set") @@ -1531,6 +1547,8 @@ async def test_restore_states( assert hass.states.get("light.disabled") is None assert hass.states.get("light.all_info_set") is None + assert entity_sources == {} + async def test_remove_device_removes_entities( hass: HomeAssistant,