Callback esphome EntityInfo by platform instead of all platforms (#95021)

This commit is contained in:
J. Nick Koston 2023-06-22 09:39:48 +02:00 committed by GitHub
parent 05c25d2349
commit adc2df6b8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 8 deletions

View File

@ -703,10 +703,6 @@ async def platform_async_setup_entry(
new_infos: dict[int, EntityInfo] = {}
add_entities: list[_EntityT] = []
for info in infos:
if not isinstance(info, info_type):
# Filter out infos that don't belong to this platform.
continue
if info.key in old_infos:
# Update existing entity
old_infos.pop(info.key)
@ -737,9 +733,7 @@ async def platform_async_setup_entry(
async_add_entities(add_entities)
entry_data.cleanup_callbacks.append(
async_dispatcher_connect(
hass, entry_data.signal_static_info_updated, async_list_entities
)
entry_data.async_register_static_info_callback(info_type, async_list_entities)
)

View File

@ -35,7 +35,7 @@ from aioesphomeapi.model import ButtonInfo
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.storage import Store
@ -106,6 +106,9 @@ class RuntimeEntryData:
default_factory=list
)
assist_pipeline_state: bool = False
entity_info_callbacks: dict[
type[EntityInfo], list[Callable[[list[EntityInfo]], None]]
] = field(default_factory=dict)
@property
def name(self) -> str:
@ -135,6 +138,21 @@ class RuntimeEntryData:
"""Return the signal to listen to for updates on static info for a specific component_key and key."""
return f"esphome_{self.entry_id}_static_info_updated_{component_key}_{key}"
@callback
def async_register_static_info_callback(
self,
entity_info_type: type[EntityInfo],
callback_: Callable[[list[EntityInfo]], None],
) -> CALLBACK_TYPE:
"""Register to receive callbacks when static info changes for an EntityInfo type."""
callbacks = self.entity_info_callbacks.setdefault(entity_info_type, [])
callbacks.append(callback_)
def _unsub() -> None:
callbacks.remove(callback_)
return _unsub
@callback
def async_update_ble_connection_limits(self, free: int, limit: int) -> None:
"""Update the BLE connection limits."""
@ -222,6 +240,21 @@ class RuntimeEntryData:
break
await self._ensure_platforms_loaded(hass, entry, needed_platforms)
# Make a dict of the EntityInfo by type and send
# them to the listeners for each specific EntityInfo type
infos_by_type: dict[type[EntityInfo], list[EntityInfo]] = {}
for info in infos:
info_type = type(info)
if info_type not in infos_by_type:
infos_by_type[info_type] = []
infos_by_type[info_type].append(info)
callbacks_by_type = self.entity_info_callbacks
for type_, entity_infos in infos_by_type.items():
if callbacks_ := callbacks_by_type.get(type_):
for callback_ in callbacks_:
callback_(entity_infos)
# Then send dispatcher event
async_dispatcher_send(hass, self.signal_static_info_updated, infos)