diff --git a/.coveragerc b/.coveragerc index 944b2a5f838..17af3760817 100644 --- a/.coveragerc +++ b/.coveragerc @@ -319,7 +319,6 @@ omit = homeassistant/components/esphome/lock.py homeassistant/components/esphome/media_player.py homeassistant/components/esphome/number.py - homeassistant/components/esphome/select.py homeassistant/components/esphome/sensor.py homeassistant/components/esphome/switch.py homeassistant/components/etherscan/sensor.py diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index bfd023a9980..c91f63787f7 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -845,6 +845,7 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): self._on_static_info_update, ) ) + self._update_state_from_entry_data() @callback def _on_static_info_update(self, static_info: EntityInfo) -> None: @@ -868,11 +869,9 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): self._attr_icon = None @callback - def _on_state_update(self) -> None: - """Call when state changed. + def _update_state_from_entry_data(self) -> None: + """Update state from entry data.""" - Behavior can be changed in child classes - """ state = self._entry_data.state key = self._key state_type = self._state_type @@ -880,6 +879,14 @@ class EsphomeEntity(Entity, Generic[_InfoT, _StateT]): if has_state: self._state = cast(_StateT, state[state_type][key]) self._has_state = has_state + + @callback + def _on_state_update(self) -> None: + """Call when state changed. + + Behavior can be changed in child classes + """ + self._update_state_from_entry_data() self.async_write_ha_state() @callback diff --git a/homeassistant/components/esphome/select.py b/homeassistant/components/esphome/select.py index e4cac21dbc8..d7cecf07d9e 100644 --- a/homeassistant/components/esphome/select.py +++ b/homeassistant/components/esphome/select.py @@ -53,9 +53,8 @@ class EsphomeSelect(EsphomeEntity[SelectInfo, SelectState], SelectEntity): @esphome_state_property def current_option(self) -> str | None: """Return the state of the entity.""" - if self._state.missing_state: - return None - return self._state.state + state = self._state + return None if state.missing_state else state.state async def async_select_option(self, option: str) -> None: """Change the selected option.""" diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index 37ab3123919..e5e78ca3bf1 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -2,9 +2,19 @@ from __future__ import annotations from asyncio import Event +from collections.abc import Callable +from typing import Any from unittest.mock import AsyncMock, Mock, patch -from aioesphomeapi import APIClient, APIVersion, DeviceInfo, ReconnectLogic +from aioesphomeapi import ( + APIClient, + APIVersion, + DeviceInfo, + EntityInfo, + EntityState, + ReconnectLogic, + UserService, +) import pytest from zeroconf import Zeroconf @@ -82,7 +92,7 @@ async def init_integration( @pytest.fixture -def mock_client(mock_device_info): +def mock_client(mock_device_info) -> APIClient: """Mock APIClient.""" mock_client = Mock(spec=APIClient) @@ -132,49 +142,72 @@ async def mock_dashboard(hass): yield data +async def _mock_generic_device_entry( + hass: HomeAssistant, + mock_client: APIClient, + mock_device_info: dict[str, Any], + mock_list_entities_services: tuple[list[EntityInfo], list[UserService]], + states: list[EntityState], +) -> MockConfigEntry: + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "test.local", + CONF_PORT: 6053, + CONF_PASSWORD: "", + }, + ) + entry.add_to_hass(hass) + + device_info = DeviceInfo( + name="test", + friendly_name="Test", + mac_address="11:22:33:44:55:aa", + esphome_version="1.0.0", + **mock_device_info, + ) + + async def _subscribe_states(callback: Callable[[EntityState], None]) -> None: + """Subscribe to state.""" + for state in states: + callback(state) + + mock_client.device_info = AsyncMock(return_value=device_info) + mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock()) + mock_client.list_entities_services = AsyncMock( + return_value=mock_list_entities_services + ) + mock_client.subscribe_states = _subscribe_states + + try_connect_done = Event() + real_try_connect = ReconnectLogic._try_connect + + async def mock_try_connect(self): + """Set an event when ReconnectLogic._try_connect has been awaited.""" + result = await real_try_connect(self) + try_connect_done.set() + return result + + with patch.object(ReconnectLogic, "_try_connect", mock_try_connect): + await hass.config_entries.async_setup(entry.entry_id) + await try_connect_done.wait() + + await hass.async_block_till_done() + + return entry + + @pytest.fixture async def mock_voice_assistant_entry( hass: HomeAssistant, - mock_client, -) -> MockConfigEntry: + mock_client: APIClient, +): """Set up an ESPHome entry with voice assistant.""" - async def _mock_voice_assistant_entry(version: int): - entry = MockConfigEntry( - domain=DOMAIN, - data={ - CONF_HOST: "test.local", - CONF_PORT: 6053, - CONF_PASSWORD: "", - }, + async def _mock_voice_assistant_entry(version: int) -> MockConfigEntry: + return await _mock_generic_device_entry( + hass, mock_client, {"voice_assistant_version": version}, ([], []), [] ) - entry.add_to_hass(hass) - - device_info = DeviceInfo( - name="test", - friendly_name="Test", - voice_assistant_version=version, - mac_address="11:22:33:44:55:aa", - esphome_version="1.0.0", - ) - - mock_client.device_info = AsyncMock(return_value=device_info) - mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock()) - - try_connect_done = Event() - real_try_connect = ReconnectLogic._try_connect - - async def mock_try_connect(self): - """Set an event when ReconnectLogic._try_connect has been awaited.""" - result = await real_try_connect(self) - try_connect_done.set() - return result - - with patch.object(ReconnectLogic, "_try_connect", mock_try_connect): - await hass.config_entries.async_setup(entry.entry_id) - await try_connect_done.wait() - - return entry return _mock_voice_assistant_entry @@ -189,3 +222,22 @@ async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfi async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry: """Set up an ESPHome entry with voice assistant.""" return await mock_voice_assistant_entry(version=2) + + +@pytest.fixture +async def mock_generic_device_entry( + hass: HomeAssistant, +) -> MockConfigEntry: + """Set up an ESPHome entry.""" + + async def _mock_device_entry( + mock_client: APIClient, + entity_info: list[EntityInfo], + user_service: list[UserService], + states: list[EntityState], + ) -> MockConfigEntry: + return await _mock_generic_device_entry( + hass, mock_client, {}, (entity_info, user_service), states + ) + + return _mock_device_entry diff --git a/tests/components/esphome/test_select.py b/tests/components/esphome/test_select.py index dec321ced86..5f6974ec035 100644 --- a/tests/components/esphome/test_select.py +++ b/tests/components/esphome/test_select.py @@ -1,6 +1,16 @@ """Test ESPHome selects.""" +from unittest.mock import call + +from aioesphomeapi import APIClient, SelectInfo, SelectState + +from homeassistant.components.select import ( + ATTR_OPTION, + DOMAIN as SELECT_DOMAIN, + SERVICE_SELECT_OPTION, +) +from homeassistant.const import ATTR_ENTITY_ID from homeassistant.core import HomeAssistant @@ -13,3 +23,37 @@ async def test_pipeline_selector( state = hass.states.get("select.test_assist_pipeline") assert state is not None assert state.state == "preferred" + + +async def test_select_generic_entity( + hass: HomeAssistant, mock_client: APIClient, mock_generic_device_entry +) -> None: + """Test a generic select entity.""" + entity_info = [ + SelectInfo( + object_id="myselect", + key=1, + name="my select", + unique_id="my_select", + options=["a", "b"], + ) + ] + states = [SelectState(key=1, state="a")] + user_service = [] + await mock_generic_device_entry( + mock_client=mock_client, + entity_info=entity_info, + user_service=user_service, + states=states, + ) + state = hass.states.get("select.test_my_select") + assert state is not None + assert state.state == "a" + + await hass.services.async_call( + SELECT_DOMAIN, + SERVICE_SELECT_OPTION, + {ATTR_ENTITY_ID: "select.test_my_select", ATTR_OPTION: "b"}, + blocking=True, + ) + mock_client.select_command.assert_has_calls([call(1, "b")])