From adc2df6b8ebb6056a4353577ed065941f47dfab4 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Thu, 22 Jun 2023 09:39:48 +0200 Subject: [PATCH] Callback esphome EntityInfo by platform instead of all platforms (#95021) --- homeassistant/components/esphome/__init__.py | 8 +---- .../components/esphome/entry_data.py | 35 ++++++++++++++++++- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index c91f63787f7..0c962d82074 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -703,10 +703,6 @@ async def platform_async_setup_entry( new_infos: dict[int, EntityInfo] = {} add_entities: list[_EntityT] = [] for info in infos: - if not isinstance(info, info_type): - # Filter out infos that don't belong to this platform. - continue - if info.key in old_infos: # Update existing entity old_infos.pop(info.key) @@ -737,9 +733,7 @@ async def platform_async_setup_entry( async_add_entities(add_entities) entry_data.cleanup_callbacks.append( - async_dispatcher_connect( - hass, entry_data.signal_static_info_updated, async_list_entities - ) + entry_data.async_register_static_info_callback(info_type, async_list_entities) ) diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index 4b4b359e15b..41c5687e661 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -35,7 +35,7 @@ from aioesphomeapi.model import ButtonInfo from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.storage import Store @@ -106,6 +106,9 @@ class RuntimeEntryData: default_factory=list ) assist_pipeline_state: bool = False + entity_info_callbacks: dict[ + type[EntityInfo], list[Callable[[list[EntityInfo]], None]] + ] = field(default_factory=dict) @property def name(self) -> str: @@ -135,6 +138,21 @@ class RuntimeEntryData: """Return the signal to listen to for updates on static info for a specific component_key and key.""" return f"esphome_{self.entry_id}_static_info_updated_{component_key}_{key}" + @callback + def async_register_static_info_callback( + self, + entity_info_type: type[EntityInfo], + callback_: Callable[[list[EntityInfo]], None], + ) -> CALLBACK_TYPE: + """Register to receive callbacks when static info changes for an EntityInfo type.""" + callbacks = self.entity_info_callbacks.setdefault(entity_info_type, []) + callbacks.append(callback_) + + def _unsub() -> None: + callbacks.remove(callback_) + + return _unsub + @callback def async_update_ble_connection_limits(self, free: int, limit: int) -> None: """Update the BLE connection limits.""" @@ -222,6 +240,21 @@ class RuntimeEntryData: break await self._ensure_platforms_loaded(hass, entry, needed_platforms) + # Make a dict of the EntityInfo by type and send + # them to the listeners for each specific EntityInfo type + infos_by_type: dict[type[EntityInfo], list[EntityInfo]] = {} + for info in infos: + info_type = type(info) + if info_type not in infos_by_type: + infos_by_type[info_type] = [] + infos_by_type[info_type].append(info) + + callbacks_by_type = self.entity_info_callbacks + for type_, entity_infos in infos_by_type.items(): + if callbacks_ := callbacks_by_type.get(type_): + for callback_ in callbacks_: + callback_(entity_infos) + # Then send dispatcher event async_dispatcher_send(hass, self.signal_static_info_updated, infos)