From 3347a3f8a678fda1ee0c7b9ff332c82873a41682 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Fri, 2 Feb 2024 20:26:44 -0600 Subject: [PATCH] More thorough checks in ESPHome voice assistant UDP server (#109394) * More thorough checks in UDP server * Simplify and change to stop_requested * Check transport --- homeassistant/components/esphome/manager.py | 1 - .../components/esphome/voice_assistant.py | 31 ++-- .../esphome/test_voice_assistant.py | 152 ++++++++++++------ 3 files changed, 122 insertions(+), 62 deletions(-) diff --git a/homeassistant/components/esphome/manager.py b/homeassistant/components/esphome/manager.py index f197574c30a..59f37d3a078 100644 --- a/homeassistant/components/esphome/manager.py +++ b/homeassistant/components/esphome/manager.py @@ -352,7 +352,6 @@ class ESPHomeManager: if self.voice_assistant_udp_server is not None: _LOGGER.warning("Voice assistant UDP server was not stopped") self.voice_assistant_udp_server.stop() - self.voice_assistant_udp_server.close() self.voice_assistant_udp_server = None hass = self.hass diff --git a/homeassistant/components/esphome/voice_assistant.py b/homeassistant/components/esphome/voice_assistant.py index de6b521d980..7c5c74d58ee 100644 --- a/homeassistant/components/esphome/voice_assistant.py +++ b/homeassistant/components/esphome/voice_assistant.py @@ -1,4 +1,5 @@ """ESPHome voice assistant support.""" + from __future__ import annotations import asyncio @@ -67,7 +68,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): """Receive UDP packets and forward them to the voice assistant.""" started = False - stopped = False + stop_requested = False transport: asyncio.DatagramTransport | None = None remote_addr: tuple[str, int] | None = None @@ -92,6 +93,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): self._tts_done = asyncio.Event() self._tts_task: asyncio.Task | None = None + @property + def is_running(self) -> bool: + """True if the the UDP server is started and hasn't been asked to stop.""" + return self.started and (not self.stop_requested) + async def start_server(self) -> int: """Start accepting connections.""" @@ -99,7 +105,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): """Accept connection.""" if self.started: raise RuntimeError("Can only start once") - if self.stopped: + if self.stop_requested: raise RuntimeError("No longer accepting connections") self.started = True @@ -124,7 +130,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): @callback def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: """Handle incoming UDP packet.""" - if not self.started or self.stopped: + if not self.is_running: return if self.remote_addr is None: self.remote_addr = addr @@ -142,19 +148,19 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): def stop(self) -> None: """Stop the receiver.""" self.queue.put_nowait(b"") - self.started = False - self.stopped = True + self.close() def close(self) -> None: """Close the receiver.""" self.started = False - self.stopped = True + self.stop_requested = True + if self.transport is not None: self.transport.close() async def _iterate_packets(self) -> AsyncIterable[bytes]: """Iterate over incoming packets.""" - if not self.started or self.stopped: + if not self.is_running: raise RuntimeError("Not running") while data := await self.queue.get(): @@ -303,8 +309,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): async def _send_tts(self, media_id: str) -> None: """Send TTS audio to device via UDP.""" + # Always send stream start/end events + self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}) + try: - if self.transport is None: + if (not self.is_running) or (self.transport is None): return extension, data = await tts.async_get_media_source_audio( @@ -337,15 +346,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): _LOGGER.debug("Sending %d bytes of audio", audio_bytes_size) - self.handle_event( - VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {} - ) - bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8 sample_offset = 0 samples_left = audio_bytes_size // bytes_per_sample - while samples_left > 0: + while (samples_left > 0) and self.is_running: bytes_offset = sample_offset * bytes_per_sample chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024] samples_in_chunk = len(chunk) // bytes_per_sample diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py index 38a33bfdec2..f6665c4ad91 100644 --- a/tests/components/esphome/test_voice_assistant.py +++ b/tests/components/esphome/test_voice_assistant.py @@ -70,6 +70,19 @@ def voice_assistant_udp_server_v2( return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry) +@pytest.fixture +def test_wav() -> bytes: + """Return one second of empty WAV audio.""" + with io.BytesIO() as wav_io: + with wave.open(wav_io, "wb") as wav_file: + wav_file.setframerate(16000) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + wav_file.writeframes(bytes(_ONE_SECOND)) + + return wav_io.getvalue() + + async def test_pipeline_events( hass: HomeAssistant, voice_assistant_udp_server_v1: VoiceAssistantUDPServer, @@ -241,11 +254,13 @@ async def test_udp_server_multiple( ): await voice_assistant_udp_server_v1.start_server() - with patch( - "homeassistant.components.esphome.voice_assistant.UDP_PORT", - new=unused_udp_port_factory(), - ), pytest.raises(RuntimeError): - pass + with ( + patch( + "homeassistant.components.esphome.voice_assistant.UDP_PORT", + new=unused_udp_port_factory(), + ), + pytest.raises(RuntimeError), + ): await voice_assistant_udp_server_v1.start_server() @@ -257,10 +272,13 @@ async def test_udp_server_after_stopped( ) -> None: """Test that the UDP server raises an error if started after stopped.""" voice_assistant_udp_server_v1.close() - with patch( - "homeassistant.components.esphome.voice_assistant.UDP_PORT", - new=unused_udp_port_factory(), - ), pytest.raises(RuntimeError): + with ( + patch( + "homeassistant.components.esphome.voice_assistant.UDP_PORT", + new=unused_udp_port_factory(), + ), + pytest.raises(RuntimeError), + ): await voice_assistant_udp_server_v1.start_server() @@ -362,35 +380,33 @@ async def test_send_tts_not_called_when_empty( async def test_send_tts( hass: HomeAssistant, voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + test_wav, ) -> None: """Test the UDP server calls sendto to transmit audio data to device.""" - with io.BytesIO() as wav_io: - with wave.open(wav_io, "wb") as wav_file: - wav_file.setframerate(16000) - wav_file.setsampwidth(2) - wav_file.setnchannels(1) - wav_file.writeframes(bytes(_ONE_SECOND)) - - wav_bytes = wav_io.getvalue() - with patch( "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("wav", wav_bytes), + return_value=("wav", test_wav), ): + voice_assistant_udp_server_v2.started = True voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) - - voice_assistant_udp_server_v2._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} - }, + with patch.object( + voice_assistant_udp_server_v2.transport, "is_closing", return_value=False + ): + voice_assistant_udp_server_v2._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": { + "media_id": _TEST_MEDIA_ID, + "url": _TEST_OUTPUT_URL, + } + }, + ) ) - ) - await voice_assistant_udp_server_v2._tts_done.wait() + await voice_assistant_udp_server_v2._tts_done.wait() - voice_assistant_udp_server_v2.transport.sendto.assert_called() + voice_assistant_udp_server_v2.transport.sendto.assert_called() async def test_send_tts_wrong_sample_rate( @@ -400,17 +416,20 @@ async def test_send_tts_wrong_sample_rate( """Test the UDP server calls sendto to transmit audio data to device.""" with io.BytesIO() as wav_io: with wave.open(wav_io, "wb") as wav_file: - wav_file.setframerate(22050) # should be 16000 + wav_file.setframerate(22050) wav_file.setsampwidth(2) wav_file.setnchannels(1) wav_file.writeframes(bytes(_ONE_SECOND)) wav_bytes = wav_io.getvalue() - - with patch( - "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("wav", wav_bytes), - ), pytest.raises(ValueError): + with ( + patch( + "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", + return_value=("wav", wav_bytes), + ), + pytest.raises(ValueError), + ): + voice_assistant_udp_server_v2.started = True voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_udp_server_v2._event_callback( @@ -431,10 +450,14 @@ async def test_send_tts_wrong_format( voice_assistant_udp_server_v2: VoiceAssistantUDPServer, ) -> None: """Test that only WAV audio will be streamed.""" - with patch( - "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("raw", bytes(1024)), - ), pytest.raises(ValueError): + with ( + patch( + "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", + return_value=("raw", bytes(1024)), + ), + pytest.raises(ValueError), + ): + voice_assistant_udp_server_v2.started = True voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_udp_server_v2._event_callback( @@ -450,6 +473,33 @@ async def test_send_tts_wrong_format( await voice_assistant_udp_server_v2._tts_task # raises ValueError +async def test_send_tts_not_started( + hass: HomeAssistant, + voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + test_wav, +) -> None: + """Test the UDP server does not call sendto when not started.""" + with patch( + "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", + return_value=("wav", test_wav), + ): + voice_assistant_udp_server_v2.started = False + voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) + + voice_assistant_udp_server_v2._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} + }, + ) + ) + + await voice_assistant_udp_server_v2._tts_done.wait() + + voice_assistant_udp_server_v2.transport.sendto.assert_not_called() + + async def test_wake_word( hass: HomeAssistant, voice_assistant_udp_server_v2: VoiceAssistantUDPServer, @@ -459,11 +509,12 @@ async def test_wake_word( async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs): assert start_stage == PipelineStage.WAKE_WORD - with patch( - "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ), patch( - "asyncio.Event.wait" # TTS wait event + with ( + patch( + "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + patch("asyncio.Event.wait"), # TTS wait event ): voice_assistant_udp_server_v2.transport = Mock() @@ -515,10 +566,15 @@ async def test_wake_word_abort_exception( async def async_pipeline_from_audio_stream(*args, **kwargs): raise WakeWordDetectionAborted - with patch( - "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ), patch.object(voice_assistant_udp_server_v2, "handle_event") as mock_handle_event: + with ( + patch( + "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + patch.object( + voice_assistant_udp_server_v2, "handle_event" + ) as mock_handle_event, + ): voice_assistant_udp_server_v2.transport = Mock() await voice_assistant_udp_server_v2.run_pipeline(