Fix esphome state mapping (#74337)

This commit is contained in:
J. Nick Koston 2022-07-03 15:48:34 -05:00 committed by GitHub
parent 30a5df5895
commit 40ed44cbea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 48 deletions

View File

@ -558,7 +558,7 @@ async def platform_async_setup_entry(
entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry)
entry_data.info[component_key] = {}
entry_data.old_info[component_key] = {}
entry_data.state[component_key] = {}
entry_data.state.setdefault(state_type, {})
@callback
def async_list_entities(infos: list[EntityInfo]) -> None:
@ -578,7 +578,7 @@ async def platform_async_setup_entry(
old_infos.pop(info.key)
else:
# Create new entity
entity = entity_type(entry_data, component_key, info.key)
entity = entity_type(entry_data, component_key, info.key, state_type)
add_entities.append(entity)
new_infos[info.key] = info
@ -677,12 +677,17 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
"""Define a base esphome entity."""
def __init__(
self, entry_data: RuntimeEntryData, component_key: str, key: int
self,
entry_data: RuntimeEntryData,
component_key: str,
key: int,
state_type: type[_StateT],
) -> None:
"""Initialize."""
self._entry_data = entry_data
self._component_key = component_key
self._key = key
self._state_type = state_type
async def async_added_to_hass(self) -> None:
"""Register callbacks."""
@ -707,7 +712,7 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
self.async_on_remove(
self._entry_data.async_subscribe_state_update(
self._component_key, self._key, self._on_state_update
self._state_type, self._key, self._on_state_update
)
)
@ -755,11 +760,11 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
@property
def _state(self) -> _StateT:
return cast(_StateT, self._entry_data.state[self._component_key][self._key])
return cast(_StateT, self._entry_data.state[self._state_type][self._key])
@property
def _has_state(self) -> bool:
return self._key in self._entry_data.state[self._component_key]
return self._key in self._entry_data.state[self._state_type]
@property
def available(self) -> bool:

View File

@ -12,34 +12,21 @@ from aioesphomeapi import (
APIClient,
APIVersion,
BinarySensorInfo,
BinarySensorState,
CameraInfo,
CameraState,
ClimateInfo,
ClimateState,
CoverInfo,
CoverState,
DeviceInfo,
EntityInfo,
EntityState,
FanInfo,
FanState,
LightInfo,
LightState,
LockInfo,
LockState,
MediaPlayerInfo,
MediaPlayerState,
NumberInfo,
NumberState,
SelectInfo,
SelectState,
SensorInfo,
SensorState,
SwitchInfo,
SwitchState,
TextSensorInfo,
TextSensorState,
UserService,
)
from aioesphomeapi.model import ButtonInfo
@ -56,8 +43,8 @@ _LOGGER = logging.getLogger(__name__)
# Mapping from ESPHome info type to HA platform
INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], str] = {
BinarySensorInfo: Platform.BINARY_SENSOR,
ButtonInfo: Platform.BINARY_SENSOR,
CameraInfo: Platform.BINARY_SENSOR,
ButtonInfo: Platform.BUTTON,
CameraInfo: Platform.CAMERA,
ClimateInfo: Platform.CLIMATE,
CoverInfo: Platform.COVER,
FanInfo: Platform.FAN,
@ -71,23 +58,6 @@ INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], str] = {
TextSensorInfo: Platform.SENSOR,
}
STATE_TYPE_TO_COMPONENT_KEY = {
BinarySensorState: Platform.BINARY_SENSOR,
EntityState: Platform.BINARY_SENSOR,
CameraState: Platform.BINARY_SENSOR,
ClimateState: Platform.CLIMATE,
CoverState: Platform.COVER,
FanState: Platform.FAN,
LightState: Platform.LIGHT,
LockState: Platform.LOCK,
MediaPlayerState: Platform.MEDIA_PLAYER,
NumberState: Platform.NUMBER,
SelectState: Platform.SELECT,
SensorState: Platform.SENSOR,
SwitchState: Platform.SWITCH,
TextSensorState: Platform.SENSOR,
}
@dataclass
class RuntimeEntryData:
@ -96,7 +66,7 @@ class RuntimeEntryData:
entry_id: str
client: APIClient
store: Store
state: dict[str, dict[int, EntityState]] = field(default_factory=dict)
state: dict[type[EntityState], dict[int, EntityState]] = field(default_factory=dict)
info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict)
# A second list of EntityInfo objects
@ -111,9 +81,9 @@ class RuntimeEntryData:
api_version: APIVersion = field(default_factory=APIVersion)
cleanup_callbacks: list[Callable[[], None]] = field(default_factory=list)
disconnect_callbacks: list[Callable[[], None]] = field(default_factory=list)
state_subscriptions: dict[tuple[str, int], Callable[[], None]] = field(
default_factory=dict
)
state_subscriptions: dict[
tuple[type[EntityState], int], Callable[[], None]
] = field(default_factory=dict)
loaded_platforms: set[str] = field(default_factory=set)
platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
_storage_contents: dict[str, Any] | None = None
@ -160,24 +130,23 @@ class RuntimeEntryData:
@callback
def async_subscribe_state_update(
self,
component_key: str,
state_type: type[EntityState],
state_key: int,
entity_callback: Callable[[], None],
) -> Callable[[], None]:
"""Subscribe to state updates."""
def _unsubscribe() -> None:
self.state_subscriptions.pop((component_key, state_key))
self.state_subscriptions.pop((state_type, state_key))
self.state_subscriptions[(component_key, state_key)] = entity_callback
self.state_subscriptions[(state_type, state_key)] = entity_callback
return _unsubscribe
@callback
def async_update_state(self, state: EntityState) -> None:
"""Distribute an update of state information to the target."""
component_key = STATE_TYPE_TO_COMPONENT_KEY[type(state)]
subscription_key = (component_key, state.key)
self.state[component_key][state.key] = state
subscription_key = (type(state), state.key)
self.state[type(state)][state.key] = state
_LOGGER.debug(
"Dispatching update with key %s: %s",
subscription_key,