diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index 0c1eac3aa45..ddedaf11ceb 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -558,7 +558,7 @@ async def platform_async_setup_entry( entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry) entry_data.info[component_key] = {} entry_data.old_info[component_key] = {} - entry_data.state[component_key] = {} + entry_data.state.setdefault(state_type, {}) @callback def async_list_entities(infos: list[EntityInfo]) -> None: @@ -578,7 +578,7 @@ async def platform_async_setup_entry( old_infos.pop(info.key) else: # Create new entity - entity = entity_type(entry_data, component_key, info.key) + entity = entity_type(entry_data, component_key, info.key, state_type) add_entities.append(entity) new_infos[info.key] = info @@ -677,12 +677,17 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): """Define a base esphome entity.""" def __init__( - self, entry_data: RuntimeEntryData, component_key: str, key: int + self, + entry_data: RuntimeEntryData, + component_key: str, + key: int, + state_type: type[_StateT], ) -> None: """Initialize.""" self._entry_data = entry_data self._component_key = component_key self._key = key + self._state_type = state_type async def async_added_to_hass(self) -> None: """Register callbacks.""" @@ -707,7 +712,7 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): self.async_on_remove( self._entry_data.async_subscribe_state_update( - self._component_key, self._key, self._on_state_update + self._state_type, self._key, self._on_state_update ) ) @@ -755,11 +760,11 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): @property def _state(self) -> _StateT: - return cast(_StateT, self._entry_data.state[self._component_key][self._key]) + return cast(_StateT, self._entry_data.state[self._state_type][self._key]) @property def _has_state(self) -> bool: - return self._key in self._entry_data.state[self._component_key] + return self._key in self._entry_data.state[self._state_type] @property def available(self) -> bool: diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index 8eb56e6fdb6..41a0e89245e 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -12,34 +12,21 @@ from aioesphomeapi import ( APIClient, APIVersion, BinarySensorInfo, - BinarySensorState, CameraInfo, - CameraState, ClimateInfo, - ClimateState, CoverInfo, - CoverState, DeviceInfo, EntityInfo, EntityState, FanInfo, - FanState, LightInfo, - LightState, LockInfo, - LockState, MediaPlayerInfo, - MediaPlayerState, NumberInfo, - NumberState, SelectInfo, - SelectState, SensorInfo, - SensorState, SwitchInfo, - SwitchState, TextSensorInfo, - TextSensorState, UserService, ) from aioesphomeapi.model import ButtonInfo @@ -56,8 +43,8 @@ _LOGGER = logging.getLogger(__name__) # Mapping from ESPHome info type to HA platform INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], str] = { BinarySensorInfo: Platform.BINARY_SENSOR, - ButtonInfo: Platform.BINARY_SENSOR, - CameraInfo: Platform.BINARY_SENSOR, + ButtonInfo: Platform.BUTTON, + CameraInfo: Platform.CAMERA, ClimateInfo: Platform.CLIMATE, CoverInfo: Platform.COVER, FanInfo: Platform.FAN, @@ -71,23 +58,6 @@ INFO_TYPE_TO_PLATFORM: dict[type[EntityInfo], str] = { TextSensorInfo: Platform.SENSOR, } -STATE_TYPE_TO_COMPONENT_KEY = { - BinarySensorState: Platform.BINARY_SENSOR, - EntityState: Platform.BINARY_SENSOR, - CameraState: Platform.BINARY_SENSOR, - ClimateState: Platform.CLIMATE, - CoverState: Platform.COVER, - FanState: Platform.FAN, - LightState: Platform.LIGHT, - LockState: Platform.LOCK, - MediaPlayerState: Platform.MEDIA_PLAYER, - NumberState: Platform.NUMBER, - SelectState: Platform.SELECT, - SensorState: Platform.SENSOR, - SwitchState: Platform.SWITCH, - TextSensorState: Platform.SENSOR, -} - @dataclass class RuntimeEntryData: @@ -96,7 +66,7 @@ class RuntimeEntryData: entry_id: str client: APIClient store: Store - state: dict[str, dict[int, EntityState]] = field(default_factory=dict) + state: dict[type[EntityState], dict[int, EntityState]] = field(default_factory=dict) info: dict[str, dict[int, EntityInfo]] = field(default_factory=dict) # A second list of EntityInfo objects @@ -111,9 +81,9 @@ class RuntimeEntryData: api_version: APIVersion = field(default_factory=APIVersion) cleanup_callbacks: list[Callable[[], None]] = field(default_factory=list) disconnect_callbacks: list[Callable[[], None]] = field(default_factory=list) - state_subscriptions: dict[tuple[str, int], Callable[[], None]] = field( - default_factory=dict - ) + state_subscriptions: dict[ + tuple[type[EntityState], int], Callable[[], None] + ] = field(default_factory=dict) loaded_platforms: set[str] = field(default_factory=set) platform_load_lock: asyncio.Lock = field(default_factory=asyncio.Lock) _storage_contents: dict[str, Any] | None = None @@ -160,24 +130,23 @@ class RuntimeEntryData: @callback def async_subscribe_state_update( self, - component_key: str, + state_type: type[EntityState], state_key: int, entity_callback: Callable[[], None], ) -> Callable[[], None]: """Subscribe to state updates.""" def _unsubscribe() -> None: - self.state_subscriptions.pop((component_key, state_key)) + self.state_subscriptions.pop((state_type, state_key)) - self.state_subscriptions[(component_key, state_key)] = entity_callback + self.state_subscriptions[(state_type, state_key)] = entity_callback return _unsubscribe @callback def async_update_state(self, state: EntityState) -> None: """Distribute an update of state information to the target.""" - component_key = STATE_TYPE_TO_COMPONENT_KEY[type(state)] - subscription_key = (component_key, state.key) - self.state[component_key][state.key] = state + subscription_key = (type(state), state.key) + self.state[type(state)][state.key] = state _LOGGER.debug( "Dispatching update with key %s: %s", subscription_key,