Fix key collision between platforms in esphome state updates (#74273)

This commit is contained in:
J. Nick Koston 2022-07-01 00:19:40 -05:00 committed by GitHub
parent 43595f7e17
commit 7655b84494
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 73 additions and 38 deletions

View File

@ -150,11 +150,6 @@ async def async_setup_entry( # noqa: C901
hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, on_stop) hass.bus.async_listen(EVENT_HOMEASSISTANT_STOP, on_stop)
) )
@callback
def async_on_state(state: EntityState) -> None:
"""Send dispatcher updates when a new state is received."""
entry_data.async_update_state(hass, state)
@callback @callback
def async_on_service_call(service: HomeassistantServiceCall) -> None: def async_on_service_call(service: HomeassistantServiceCall) -> None:
"""Call service when user automation in ESPHome config is triggered.""" """Call service when user automation in ESPHome config is triggered."""
@ -288,7 +283,7 @@ async def async_setup_entry( # noqa: C901
entity_infos, services = await cli.list_entities_services() entity_infos, services = await cli.list_entities_services()
await entry_data.async_update_static_infos(hass, entry, entity_infos) await entry_data.async_update_static_infos(hass, entry, entity_infos)
await _setup_services(hass, entry_data, services) await _setup_services(hass, entry_data, services)
await cli.subscribe_states(async_on_state) await cli.subscribe_states(entry_data.async_update_state)
await cli.subscribe_service_calls(async_on_service_call) await cli.subscribe_service_calls(async_on_service_call)
await cli.subscribe_home_assistant_states(async_on_state_subscription) await cli.subscribe_home_assistant_states(async_on_state_subscription)
@ -568,7 +563,6 @@ async def platform_async_setup_entry(
@callback @callback
def async_list_entities(infos: list[EntityInfo]) -> None: def async_list_entities(infos: list[EntityInfo]) -> None:
"""Update entities of this platform when entities are listed.""" """Update entities of this platform when entities are listed."""
key_to_component = entry_data.key_to_component
old_infos = entry_data.info[component_key] old_infos = entry_data.info[component_key]
new_infos: dict[int, EntityInfo] = {} new_infos: dict[int, EntityInfo] = {}
add_entities = [] add_entities = []
@ -587,12 +581,10 @@ async def platform_async_setup_entry(
entity = entity_type(entry_data, component_key, info.key) entity = entity_type(entry_data, component_key, info.key)
add_entities.append(entity) add_entities.append(entity)
new_infos[info.key] = info new_infos[info.key] = info
key_to_component[info.key] = component_key
# Remove old entities # Remove old entities
for info in old_infos.values(): for info in old_infos.values():
entry_data.async_remove_entity(hass, component_key, info.key) entry_data.async_remove_entity(hass, component_key, info.key)
key_to_component.pop(info.key, None)
# First copy the now-old info into the backup object # First copy the now-old info into the backup object
entry_data.old_info[component_key] = entry_data.info[component_key] entry_data.old_info[component_key] = entry_data.info[component_key]
@ -714,13 +706,8 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
) )
self.async_on_remove( self.async_on_remove(
async_dispatcher_connect( self._entry_data.async_subscribe_state_update(
self.hass, self._component_key, self._key, self._on_state_update
(
f"esphome_{self._entry_id}"
f"_update_{self._component_key}_{self._key}"
),
self._on_state_update,
) )
) )

View File

@ -12,26 +12,40 @@ from aioesphomeapi import (
APIClient, APIClient,
APIVersion, APIVersion,
BinarySensorInfo, BinarySensorInfo,
BinarySensorState,
CameraInfo, CameraInfo,
CameraState,
ClimateInfo, ClimateInfo,
ClimateState,
CoverInfo, CoverInfo,
CoverState,
DeviceInfo, DeviceInfo,
EntityInfo, EntityInfo,
EntityState, EntityState,
FanInfo, FanInfo,
FanState,
LightInfo, LightInfo,
LightState,
LockInfo, LockInfo,
LockState,
MediaPlayerInfo, MediaPlayerInfo,
MediaPlayerState,
NumberInfo, NumberInfo,
NumberState,
SelectInfo, SelectInfo,
SelectState,
SensorInfo, SensorInfo,
SensorState,
SwitchInfo, SwitchInfo,
SwitchState,
TextSensorInfo, TextSensorInfo,
TextSensorState,
UserService, UserService,
) )
from aioesphomeapi.model import ButtonInfo from aioesphomeapi.model import ButtonInfo
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
@ -41,20 +55,37 @@ _LOGGER = logging.getLogger(__name__)
# Mapping from ESPHome info type to HA platform # Mapping from ESPHome info type to HA platform
INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], str] = { INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], str] = {
BinarySensorInfo: "binary_sensor", BinarySensorInfo: Platform.BINARY_SENSOR,
ButtonInfo: "button", ButtonInfo: Platform.BINARY_SENSOR,
CameraInfo: "camera", CameraInfo: Platform.BINARY_SENSOR,
ClimateInfo: "climate", ClimateInfo: Platform.CLIMATE,
CoverInfo: "cover", CoverInfo: Platform.COVER,
FanInfo: "fan", FanInfo: Platform.FAN,
LightInfo: "light", LightInfo: Platform.LIGHT,
LockInfo: "lock", LockInfo: Platform.LOCK,
MediaPlayerInfo: "media_player", MediaPlayerInfo: Platform.MEDIA_PLAYER,
NumberInfo: "number", NumberInfo: Platform.NUMBER,
SelectInfo: "select", SelectInfo: Platform.SELECT,
SensorInfo: "sensor", SensorInfo: Platform.SENSOR,
SwitchInfo: "switch", SwitchInfo: Platform.SWITCH,
TextSensorInfo: "sensor", TextSensorInfo: Platform.SENSOR,
}
STATE_TYPE_TO_COMPONENT_KEY = {
BinarySensorState: Platform.BINARY_SENSOR,
EntityState: Platform.BINARY_SENSOR,
CameraState: Platform.BINARY_SENSOR,
ClimateState: Platform.CLIMATE,
CoverState: Platform.COVER,
FanState: Platform.FAN,
LightState: Platform.LIGHT,
LockState: Platform.LOCK,
MediaPlayerState: Platform.MEDIA_PLAYER,
NumberState: Platform.NUMBER,
SelectState: Platform.SELECT,
SensorState: Platform.SENSOR,
SwitchState: Platform.SWITCH,
TextSensorState: Platform.SENSOR,
} }
@ -67,7 +98,6 @@ class RuntimeEntryData:
store: Store store: Store
state: dict[str, dict[int, EntityState]] = field(default_factory=dict) state: dict[str, dict[int, EntityState]] = field(default_factory=dict)
info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict) info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict)
key_to_component: dict[int, str] = field(default_factory=dict)
# A second list of EntityInfo objects # A second list of EntityInfo objects
# This is necessary for when an entity is being removed. HA requires # This is necessary for when an entity is being removed. HA requires
@ -81,6 +111,9 @@ class RuntimeEntryData:
api_version: APIVersion = field(default_factory=APIVersion) api_version: APIVersion = field(default_factory=APIVersion)
cleanup_callbacks: list[Callable[[], None]] = field(default_factory=list) cleanup_callbacks: list[Callable[[], None]] = field(default_factory=list)
disconnect_callbacks: list[Callable[[], None]] = field(default_factory=list) disconnect_callbacks: list[Callable[[], None]] = field(default_factory=list)
state_subscriptions: dict[tuple[str, int], Callable[[], None]] = field(
default_factory=dict
)
loaded_platforms: set[str] = field(default_factory=set) loaded_platforms: set[str] = field(default_factory=set)
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock) platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_storage_contents: dict[str, Any] | None = None _storage_contents: dict[str, Any] | None = None
@ -125,18 +158,33 @@ class RuntimeEntryData:
async_dispatcher_send(hass, signal, infos) async_dispatcher_send(hass, signal, infos)
@callback @callback
def async_update_state(self, hass: HomeAssistant, state: EntityState) -> None: def async_subscribe_state_update(
self,
component_key: str,
state_key: int,
entity_callback: Callable[[], None],
) -> Callable[[], None]:
"""Subscribe to state updates."""
def _unsubscribe() -> None:
self.state_subscriptions.pop((component_key, state_key))
self.state_subscriptions[(component_key, state_key)] = entity_callback
return _unsubscribe
@callback
def async_update_state(self, state: EntityState) -> None:
"""Distribute an update of state information to the target.""" """Distribute an update of state information to the target."""
component_key = self.key_to_component[state.key] component_key = STATE_TYPE_TO_COMPONENT_KEY[type(state)]
subscription_key = (component_key, state.key)
self.state[component_key][state.key] = state self.state[component_key][state.key] = state
signal = f"esphome_{self.entry_id}_update_{component_key}_{state.key}"
_LOGGER.debug( _LOGGER.debug(
"Dispatching update for component %s with state key %s: %s", "Dispatching update with key %s: %s",
component_key, subscription_key,
state.key,
state, state,
) )
async_dispatcher_send(hass, signal) if subscription_key in self.state_subscriptions:
self.state_subscriptions[subscription_key]()
@callback @callback
def async_update_device_state(self, hass: HomeAssistant) -> None: def async_update_device_state(self, hass: HomeAssistant) -> None: