From 4a417c7dcc89a7e1102703afb0a3e60bb43fbba1 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Wed, 23 Aug 2023 11:11:14 -0500 Subject: [PATCH] Wake word entity state/category fix (#98886) * Only change wake word entity state on detection * Wake word entity is diagnostic --- .../components/wake_word/__init__.py | 21 +++++++++++------- .../wake_word/snapshots/test_init.ambr | 3 +++ tests/components/wake_word/test_init.py | 22 ++++++++++++++++++- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/homeassistant/components/wake_word/__init__.py b/homeassistant/components/wake_word/__init__.py index 895dababd54..0a751b7eea2 100644 --- a/homeassistant/components/wake_word/__init__.py +++ b/homeassistant/components/wake_word/__init__.py @@ -7,7 +7,7 @@ import logging from typing import final from homeassistant.config_entries import ConfigEntry -from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN +from homeassistant.const import STATE_UNAVAILABLE, STATE_UNKNOWN, EntityCategory from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv from homeassistant.helpers.entity_component import EntityComponent @@ -71,16 +71,17 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: class WakeWordDetectionEntity(RestoreEntity): """Represent a single wake word provider.""" + _attr_entity_category = EntityCategory.DIAGNOSTIC _attr_should_poll = False - __last_processed: str | None = None + __last_detected: str | None = None @property @final def state(self) -> str | None: """Return the state of the entity.""" - if self.__last_processed is None: + if self.__last_detected is None: return None - return self.__last_processed + return self.__last_detected @property @abstractmethod @@ -103,9 +104,13 @@ class WakeWordDetectionEntity(RestoreEntity): Audio must be 16Khz sample rate with 16-bit mono PCM samples. """ - self.__last_processed = dt_util.utcnow().isoformat() - self.async_write_ha_state() - return await self._async_process_audio_stream(stream) + result = await self._async_process_audio_stream(stream) + if result is not None: + # Update last detected only when there is a detection + self.__last_detected = dt_util.utcnow().isoformat() + self.async_write_ha_state() + + return result async def async_internal_added_to_hass(self) -> None: """Call when the entity is added to hass.""" @@ -116,4 +121,4 @@ class WakeWordDetectionEntity(RestoreEntity): and state.state is not None and state.state not in (STATE_UNAVAILABLE, STATE_UNKNOWN) ): - self.__last_processed = state.state + self.__last_detected = state.state diff --git a/tests/components/wake_word/snapshots/test_init.ambr b/tests/components/wake_word/snapshots/test_init.ambr index ca6d5d950f0..cf7c09cd730 100644 --- a/tests/components/wake_word/snapshots/test_init.ambr +++ b/tests/components/wake_word/snapshots/test_init.ambr @@ -1,4 +1,7 @@ # serializer version: 1 +# name: test_detected_entity + None +# --- # name: test_ws_detect dict({ 'event': dict({ diff --git a/tests/components/wake_word/test_init.py b/tests/components/wake_word/test_init.py index 954cbe6dc8c..d37cb3aa540 100644 --- a/tests/components/wake_word/test_init.py +++ b/tests/components/wake_word/test_init.py @@ -3,9 +3,11 @@ from collections.abc import AsyncIterable, Generator from pathlib import Path import pytest +from syrupy.assertion import SnapshotAssertion from homeassistant.components import wake_word from homeassistant.config_entries import ConfigEntry, ConfigEntryState, ConfigFlow +from homeassistant.const import EntityCategory from homeassistant.core import HomeAssistant, State from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.setup import async_setup_component @@ -147,7 +149,10 @@ async def test_config_entry_unload( async def test_detected_entity( - hass: HomeAssistant, tmp_path: Path, setup: MockProviderEntity + hass: HomeAssistant, + tmp_path: Path, + setup: MockProviderEntity, + snapshot: SnapshotAssertion, ) -> None: """Test successful detection through entity.""" @@ -158,9 +163,13 @@ async def test_detected_entity( timestamp += _MS_PER_CHUNK # Need 2 seconds to trigger + state = setup.state result = await setup.async_process_audio_stream(three_second_stream()) assert result == wake_word.DetectionResult("test_ww", 2048) + assert state != setup.state + assert state == snapshot + async def test_not_detected_entity( hass: HomeAssistant, setup: MockProviderEntity @@ -174,9 +183,13 @@ async def test_not_detected_entity( timestamp += _MS_PER_CHUNK # Need 2 seconds to trigger + state = setup.state result = await setup.async_process_audio_stream(one_second_stream()) assert result is None + # State should only change when there's a detection + assert state == setup.state + async def test_default_engine_none(hass: HomeAssistant, tmp_path: Path) -> None: """Test async_default_engine.""" @@ -224,3 +237,10 @@ async def test_restore_state( state = hass.states.get(entity_id) assert state assert state.state == timestamp + + +async def test_entity_attributes( + hass: HomeAssistant, mock_provider_entity: MockProviderEntity +) -> None: + """Test that the provider entity attributes match expectations.""" + assert mock_provider_entity.entity_category == EntityCategory.DIAGNOSTIC