From 32ed45084a310107a1d2f49b8fa98a9cdf752690 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Thu, 27 Apr 2023 14:24:29 +1200 Subject: [PATCH] ESPHome voice assistant: Version 2 - Stream raw tts audio back to device for playback (#92052) * Send raw audio back * Update tests * More tests * Fix docstrings and remove unused patches * More tests * MORE * Only set raw for v2 --- homeassistant/components/esphome/__init__.py | 29 ++- .../components/esphome/voice_assistant.py | 163 ++++++++++---- tests/components/esphome/conftest.py | 35 +++ .../esphome/test_voice_assistant.py | 209 ++++++++++++++++-- 4 files changed, 367 insertions(+), 69 deletions(-) diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index 6ce5f656d6e..a68dd562af1 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -288,39 +288,46 @@ async def async_setup_entry( # noqa: C901 voice_assistant_udp_server: VoiceAssistantUDPServer | None = None - def handle_pipeline_event( + def _handle_pipeline_event( event_type: VoiceAssistantEventType, data: dict[str, str] | None ) -> None: - """Handle a voice assistant pipeline event.""" cli.send_voice_assistant_event(event_type, data) - async def handle_pipeline_start() -> int | None: + def _handle_pipeline_finished() -> None: + nonlocal voice_assistant_udp_server + + entry_data.async_set_assist_pipeline_state(False) + + if voice_assistant_udp_server is not None: + voice_assistant_udp_server.close() + voice_assistant_udp_server = None + + async def _handle_pipeline_start() -> int | None: """Start a voice assistant pipeline.""" nonlocal voice_assistant_udp_server if voice_assistant_udp_server is not None: return None - voice_assistant_udp_server = VoiceAssistantUDPServer(hass, entry_data) + voice_assistant_udp_server = VoiceAssistantUDPServer( + hass, entry_data, _handle_pipeline_event, _handle_pipeline_finished + ) port = await voice_assistant_udp_server.start_server() hass.async_create_background_task( - voice_assistant_udp_server.run_pipeline(handle_pipeline_event), + voice_assistant_udp_server.run_pipeline(), "esphome.voice_assistant_udp_server.run_pipeline", ) entry_data.async_set_assist_pipeline_state(True) return port - async def handle_pipeline_stop() -> None: + async def _handle_pipeline_stop() -> None: """Stop a voice assistant pipeline.""" nonlocal voice_assistant_udp_server - entry_data.async_set_assist_pipeline_state(False) - if voice_assistant_udp_server is not None: voice_assistant_udp_server.stop() - voice_assistant_udp_server = None async def on_connect() -> None: """Subscribe to states and list entities on successful API login.""" @@ -369,8 +376,8 @@ async def async_setup_entry( # noqa: C901 if device_info.voice_assistant_version: entry_data.disconnect_callbacks.append( await cli.subscribe_voice_assistant( - handle_pipeline_start, - handle_pipeline_stop, + _handle_pipeline_start, + _handle_pipeline_stop, ) ) diff --git a/homeassistant/components/esphome/voice_assistant.py b/homeassistant/components/esphome/voice_assistant.py index b6c76e00f4c..aaa2dc80a78 100644 --- a/homeassistant/components/esphome/voice_assistant.py +++ b/homeassistant/components/esphome/voice_assistant.py @@ -8,8 +8,9 @@ import socket from typing import cast from aioesphomeapi import VoiceAssistantEventType +import async_timeout -from homeassistant.components import stt +from homeassistant.components import stt, tts from homeassistant.components.assist_pipeline import ( PipelineEvent, PipelineEventType, @@ -26,6 +27,7 @@ from .enum_mapper import EsphomeEnumMapper _LOGGER = logging.getLogger(__name__) UDP_PORT = 0 # Set to 0 to let the OS pick a free random port +UDP_MAX_PACKET_SIZE = 1024 _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[ VoiceAssistantEventType, PipelineEventType @@ -50,11 +52,14 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): started = False queue: asyncio.Queue[bytes] | None = None transport: asyncio.DatagramTransport | None = None + remote_addr: tuple[str, int] | None = None def __init__( self, hass: HomeAssistant, entry_data: RuntimeEntryData, + handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], + handle_finished: Callable[[], None], ) -> None: """Initialize UDP receiver.""" self.context = Context() @@ -64,6 +69,9 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): self.device_info = entry_data.device_info self.queue = asyncio.Queue() + self.handle_event = handle_event + self.handle_finished = handle_finished + self._tts_done = asyncio.Event() async def start_server(self) -> int: """Start accepting connections.""" @@ -97,6 +105,10 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): @callback def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: """Handle incoming UDP packet.""" + if not self.started: + return + if self.remote_addr is None: + self.remote_addr = addr if self.queue is not None: self.queue.put_nowait(data) @@ -106,12 +118,18 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): (Other than BlockingIOError or InterruptedError.) """ _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc) + self.handle_finished() @callback def stop(self) -> None: """Stop the receiver.""" if self.queue is not None: self.queue.put_nowait(b"") + self.started = False + + def close(self) -> None: + """Close the receiver.""" + if self.queue is not None: self.queue = None if self.transport is not None: self.transport.close() @@ -124,57 +142,112 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): while data := await self.queue.get(): yield data + def _event_callback(self, event: PipelineEvent) -> None: + """Handle pipeline events.""" + + try: + event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type) + except KeyError: + _LOGGER.warning("Received unknown pipeline event type: %s", event.type) + return + + data_to_send = None + if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: + assert event.data is not None + data_to_send = {"text": event.data["stt_output"]["text"]} + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: + assert event.data is not None + data_to_send = {"text": event.data["tts_input"]} + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: + assert event.data is not None + path = event.data["tts_output"]["url"] + url = async_process_play_media_url(self.hass, path) + data_to_send = {"url": url} + + if self.device_info.voice_assistant_version >= 2: + media_id = event.data["tts_output"]["media_id"] + self.hass.async_create_background_task( + self._send_tts(media_id), "esphome_voice_assistant_tts" + ) + else: + self._tts_done.set() + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: + assert event.data is not None + data_to_send = { + "code": event.data["code"], + "message": event.data["message"], + } + self.handle_finished() + + self.handle_event(event_type, data_to_send) + async def run_pipeline( self, - handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], + pipeline_timeout: float = 30.0, ) -> None: """Run the Voice Assistant pipeline.""" + try: + tts_audio_output = ( + "raw" if self.device_info.voice_assistant_version >= 2 else "mp3" + ) + async with async_timeout.timeout(pipeline_timeout): + await async_pipeline_from_audio_stream( + self.hass, + context=self.context, + event_callback=self._event_callback, + stt_metadata=stt.SpeechMetadata( + language="", # set in async_pipeline_from_audio_stream + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=self._iterate_packets(), + pipeline_id=pipeline_select.get_chosen_pipeline( + self.hass, DOMAIN, self.device_info.mac_address + ), + tts_audio_output=tts_audio_output, + ) - @callback - def handle_pipeline_event(event: PipelineEvent) -> None: - """Handle pipeline events.""" + # Block until TTS is done sending + await self._tts_done.wait() - try: - event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type) - except KeyError: - _LOGGER.warning("Received unknown pipeline event type: %s", event.type) + _LOGGER.debug("Pipeline finished") + except asyncio.TimeoutError: + _LOGGER.warning("Pipeline timeout") + finally: + self.handle_finished() + + async def _send_tts(self, media_id: str) -> None: + """Send TTS audio to device via UDP.""" + try: + if self.transport is None: return - data_to_send = None - if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: - assert event.data is not None - data_to_send = {"text": event.data["stt_output"]["text"]} - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: - assert event.data is not None - data_to_send = {"text": event.data["tts_input"]} - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: - assert event.data is not None - path = event.data["tts_output"]["url"] - url = async_process_play_media_url(self.hass, path) - data_to_send = {"url": url} - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: - assert event.data is not None - data_to_send = { - "code": event.data["code"], - "message": event.data["message"], - } + _extension, audio_bytes = await tts.async_get_media_source_audio( + self.hass, + media_id, + ) - handle_event(event_type, data_to_send) + _LOGGER.debug("Sending %d bytes of audio", len(audio_bytes)) - await async_pipeline_from_audio_stream( - self.hass, - context=self.context, - event_callback=handle_pipeline_event, - stt_metadata=stt.SpeechMetadata( - language="", # set in async_pipeline_from_audio_stream - format=stt.AudioFormats.WAV, - codec=stt.AudioCodecs.PCM, - bit_rate=stt.AudioBitRates.BITRATE_16, - sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, - channel=stt.AudioChannels.CHANNEL_MONO, - ), - stt_stream=self._iterate_packets(), - pipeline_id=pipeline_select.get_chosen_pipeline( - self.hass, DOMAIN, self.device_info.mac_address - ), - ) + bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8 + sample_offset = 0 + samples_left = len(audio_bytes) // bytes_per_sample + + while samples_left > 0: + bytes_offset = sample_offset * bytes_per_sample + chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024] + samples_in_chunk = len(chunk) // bytes_per_sample + samples_left -= samples_in_chunk + + self.transport.sendto(chunk, self.remote_addr) + await asyncio.sleep( + samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.99 + ) + + sample_offset += samples_in_chunk + + finally: + self._tts_done.set() diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index f5362b1fb3d..a70686acbf6 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -157,3 +157,38 @@ async def mock_voice_assistant_v1_entry( await hass.async_block_till_done() return entry + + +@pytest.fixture +async def mock_voice_assistant_v2_entry( + hass: HomeAssistant, + mock_client, +) -> MockConfigEntry: + """Set up an ESPHome entry with voice assistant.""" + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "test.local", + CONF_PORT: 6053, + CONF_PASSWORD: "", + }, + ) + entry.add_to_hass(hass) + + device_info = DeviceInfo( + name="test", + friendly_name="Test", + voice_assistant_version=2, + mac_address="11:22:33:44:55:aa", + esphome_version="1.0.0", + ) + + mock_client.device_info = AsyncMock(return_value=device_info) + mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock()) + + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + await hass.async_block_till_done() + await hass.async_block_till_done() + + return entry diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py index e1fe41829c2..fed83f8ab10 100644 --- a/tests/components/esphome/test_voice_assistant.py +++ b/tests/components/esphome/test_voice_assistant.py @@ -4,10 +4,12 @@ import asyncio import socket from unittest.mock import Mock, patch +from aioesphomeapi import VoiceAssistantEventType import async_timeout import pytest -from homeassistant.components import assist_pipeline, esphome +from homeassistant.components import esphome +from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType from homeassistant.components.esphome import DomainData from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer from homeassistant.core import HomeAssistant @@ -15,6 +17,7 @@ from homeassistant.core import HomeAssistant _TEST_INPUT_TEXT = "This is an input test" _TEST_OUTPUT_TEXT = "This is an output test" _TEST_OUTPUT_URL = "output.mp3" +_TEST_MEDIA_ID = "12345" @pytest.fixture @@ -24,11 +27,40 @@ def voice_assistant_udp_server_v1( ) -> VoiceAssistantUDPServer: """Return the UDP server.""" entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_v1_entry) - return VoiceAssistantUDPServer(hass, entry_data) + + server: VoiceAssistantUDPServer = None + + def handle_finished(): + nonlocal server + assert server is not None + server.close() + + server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished) + return server + + +@pytest.fixture +def voice_assistant_udp_server_v2( + hass: HomeAssistant, + mock_voice_assistant_v2_entry, +) -> VoiceAssistantUDPServer: + """Return the UDP server.""" + entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_v2_entry) + + server: VoiceAssistantUDPServer = None + + def handle_finished(): + nonlocal server + assert server is not None + server.close() + + server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished) + return server async def test_pipeline_events( - hass: HomeAssistant, voice_assistant_udp_server_v1: VoiceAssistantUDPServer + hass: HomeAssistant, + voice_assistant_udp_server_v1: VoiceAssistantUDPServer, ) -> None: """Test that the pipeline function is called.""" @@ -37,29 +69,29 @@ async def test_pipeline_events( # Fake events event_callback( - assist_pipeline.PipelineEvent( - type=assist_pipeline.PipelineEventType.STT_START, + PipelineEvent( + type=PipelineEventType.STT_START, data={}, ) ) event_callback( - assist_pipeline.PipelineEvent( - type=assist_pipeline.PipelineEventType.STT_END, + PipelineEvent( + type=PipelineEventType.STT_END, data={"stt_output": {"text": _TEST_INPUT_TEXT}}, ) ) event_callback( - assist_pipeline.PipelineEvent( - type=assist_pipeline.PipelineEventType.TTS_START, + PipelineEvent( + type=PipelineEventType.TTS_START, data={"tts_input": _TEST_OUTPUT_TEXT}, ) ) event_callback( - assist_pipeline.PipelineEvent( - type=assist_pipeline.PipelineEventType.TTS_END, + PipelineEvent( + type=PipelineEventType.TTS_END, data={"tts_output": {"url": _TEST_OUTPUT_URL}}, ) ) @@ -77,13 +109,15 @@ async def test_pipeline_events( assert data is not None assert data["url"] == _TEST_OUTPUT_URL + voice_assistant_udp_server_v1.handle_event = handle_event + with patch( "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", new=async_pipeline_from_audio_stream, ): voice_assistant_udp_server_v1.transport = Mock() - await voice_assistant_udp_server_v1.run_pipeline(handle_event) + await voice_assistant_udp_server_v1.run_pipeline() async def test_udp_server( @@ -114,10 +148,61 @@ async def test_udp_server( assert voice_assistant_udp_server_v1.queue.qsize() == 1 voice_assistant_udp_server_v1.stop() + voice_assistant_udp_server_v1.close() assert voice_assistant_udp_server_v1.transport.is_closing() +async def test_udp_server_queue( + hass: HomeAssistant, + voice_assistant_udp_server_v1: VoiceAssistantUDPServer, +) -> None: + """Test the UDP server queues incoming data.""" + + voice_assistant_udp_server_v1.started = True + + assert voice_assistant_udp_server_v1.queue.qsize() == 0 + + voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0)) + assert voice_assistant_udp_server_v1.queue.qsize() == 1 + + voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0)) + assert voice_assistant_udp_server_v1.queue.qsize() == 2 + + async for data in voice_assistant_udp_server_v1._iterate_packets(): + assert data == bytes(1024) + break + assert voice_assistant_udp_server_v1.queue.qsize() == 1 # One message removed + + voice_assistant_udp_server_v1.stop() + assert ( + voice_assistant_udp_server_v1.queue.qsize() == 2 + ) # An empty message added by stop + + voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0)) + assert ( + voice_assistant_udp_server_v1.queue.qsize() == 2 + ) # No new messages added after stop + + voice_assistant_udp_server_v1.close() + + with pytest.raises(RuntimeError): + async for data in voice_assistant_udp_server_v1._iterate_packets(): + assert data == bytes(1024) + + +async def test_error_calls_handle_finished( + hass: HomeAssistant, + voice_assistant_udp_server_v1: VoiceAssistantUDPServer, +) -> None: + """Test that the handle_finished callback is called when an error occurs.""" + voice_assistant_udp_server_v1.handle_finished = Mock() + + voice_assistant_udp_server_v1.error_received(Exception()) + + voice_assistant_udp_server_v1.handle_finished.assert_called() + + async def test_udp_server_multiple( hass: HomeAssistant, socket_enabled, @@ -146,9 +231,107 @@ async def test_udp_server_after_stopped( voice_assistant_udp_server_v1: VoiceAssistantUDPServer, ) -> None: """Test that the UDP server raises an error if started after stopped.""" - voice_assistant_udp_server_v1.stop() + voice_assistant_udp_server_v1.close() with patch( "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=unused_udp_port_factory(), ), pytest.raises(RuntimeError): await voice_assistant_udp_server_v1.start_server() + + +async def test_unknown_event_type( + hass: HomeAssistant, + voice_assistant_udp_server_v1: VoiceAssistantUDPServer, +) -> None: + """Test the UDP server does not call handle_event for unknown events.""" + voice_assistant_udp_server_v1._event_callback( + PipelineEvent( + type="unknown-event", + data={}, + ) + ) + + assert not voice_assistant_udp_server_v1.handle_event.called + + +async def test_error_event_type( + hass: HomeAssistant, + voice_assistant_udp_server_v1: VoiceAssistantUDPServer, +) -> None: + """Test the UDP server calls event handler with error.""" + voice_assistant_udp_server_v1._event_callback( + PipelineEvent( + type=PipelineEventType.ERROR, + data={"code": "code", "message": "message"}, + ) + ) + + assert voice_assistant_udp_server_v1.handle_event.called_with( + VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, + {"code": "code", "message": "message"}, + ) + + +async def test_send_tts_not_called( + hass: HomeAssistant, + voice_assistant_udp_server_v1: VoiceAssistantUDPServer, +) -> None: + """Test the UDP server with a v1 device does not call _send_tts.""" + with patch( + "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts" + ) as mock_send_tts: + voice_assistant_udp_server_v1._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} + }, + ) + ) + + mock_send_tts.assert_not_called() + + +async def test_send_tts_called( + hass: HomeAssistant, + voice_assistant_udp_server_v2: VoiceAssistantUDPServer, +) -> None: + """Test the UDP server with a v2 device calls _send_tts.""" + with patch( + "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts" + ) as mock_send_tts: + voice_assistant_udp_server_v2._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} + }, + ) + ) + + mock_send_tts.assert_called_with(_TEST_MEDIA_ID) + + +async def test_send_tts( + hass: HomeAssistant, + voice_assistant_udp_server_v2: VoiceAssistantUDPServer, +) -> None: + """Test the UDP server calls sendto to transmit audio data to device.""" + with patch( + "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", + return_value=("raw", bytes(1024)), + ): + voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) + + voice_assistant_udp_server_v2._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} + }, + ) + ) + + await voice_assistant_udp_server_v2._tts_done.wait() + + voice_assistant_udp_server_v2.transport.sendto.assert_called()