diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index a68dd562af1..f95763d3a6c 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -302,7 +302,7 @@ async def async_setup_entry( # noqa: C901 voice_assistant_udp_server.close() voice_assistant_udp_server = None - async def _handle_pipeline_start() -> int | None: + async def _handle_pipeline_start(conversation_id: str, use_vad: bool) -> int | None: """Start a voice assistant pipeline.""" nonlocal voice_assistant_udp_server @@ -315,7 +315,10 @@ async def async_setup_entry( # noqa: C901 port = await voice_assistant_udp_server.start_server() hass.async_create_background_task( - voice_assistant_udp_server.run_pipeline(), + voice_assistant_udp_server.run_pipeline( + conversation_id=conversation_id or None, + use_vad=use_vad, + ), "esphome.voice_assistant_udp_server.run_pipeline", ) entry_data.async_set_assist_pipeline_state(True) diff --git a/homeassistant/components/esphome/manifest.json b/homeassistant/components/esphome/manifest.json index 49057080469..c6e430d7845 100644 --- a/homeassistant/components/esphome/manifest.json +++ b/homeassistant/components/esphome/manifest.json @@ -15,7 +15,7 @@ "iot_class": "local_push", "loggers": ["aioesphomeapi", "noiseprotocol"], "requirements": [ - "aioesphomeapi==13.7.5", + "aioesphomeapi==13.9.0", "bluetooth-data-tools==0.4.0", "esphome-dashboard-api==1.2.3" ], diff --git a/homeassistant/components/esphome/voice_assistant.py b/homeassistant/components/esphome/voice_assistant.py index aaa2dc80a78..5cd9c1d931f 100644 --- a/homeassistant/components/esphome/voice_assistant.py +++ b/homeassistant/components/esphome/voice_assistant.py @@ -2,7 +2,8 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncIterable, Callable +from collections import deque +from collections.abc import AsyncIterable, Callable, MutableSequence, Sequence import logging import socket from typing import cast @@ -17,6 +18,7 @@ from homeassistant.components.assist_pipeline import ( async_pipeline_from_audio_stream, select as pipeline_select, ) +from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter from homeassistant.components.media_player import async_process_play_media_url from homeassistant.core import Context, HomeAssistant, callback @@ -50,7 +52,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): """Receive UDP packets and forward them to the voice assistant.""" started = False - queue: asyncio.Queue[bytes] | None = None + stopped = False transport: asyncio.DatagramTransport | None = None remote_addr: tuple[str, int] | None = None @@ -60,6 +62,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): entry_data: RuntimeEntryData, handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], handle_finished: Callable[[], None], + audio_timeout: float = 2.0, ) -> None: """Initialize UDP receiver.""" self.context = Context() @@ -68,10 +71,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): assert entry_data.device_info is not None self.device_info = entry_data.device_info - self.queue = asyncio.Queue() + self.queue: asyncio.Queue[bytes] = asyncio.Queue() self.handle_event = handle_event self.handle_finished = handle_finished self._tts_done = asyncio.Event() + self.audio_timeout = audio_timeout async def start_server(self) -> int: """Start accepting connections.""" @@ -80,7 +84,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): """Accept connection.""" if self.started: raise RuntimeError("Can only start once") - if self.queue is None: + if self.stopped: raise RuntimeError("No longer accepting connections") self.started = True @@ -105,12 +109,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): @callback def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: """Handle incoming UDP packet.""" - if not self.started: + if not self.started or self.stopped: return if self.remote_addr is None: self.remote_addr = addr - if self.queue is not None: - self.queue.put_nowait(data) + self.queue.put_nowait(data) def error_received(self, exc: Exception) -> None: """Handle when a send or receive operation raises an OSError. @@ -123,21 +126,21 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): @callback def stop(self) -> None: """Stop the receiver.""" - if self.queue is not None: - self.queue.put_nowait(b"") + self.queue.put_nowait(b"") self.started = False + self.stopped = True def close(self) -> None: """Close the receiver.""" - if self.queue is not None: - self.queue = None + self.started = False + self.stopped = True if self.transport is not None: self.transport.close() async def _iterate_packets(self) -> AsyncIterable[bytes]: """Iterate over incoming packets.""" - if self.queue is None: - raise RuntimeError("Already stopped") + if not self.started or self.stopped: + raise RuntimeError("Not running") while data := await self.queue.get(): yield data @@ -152,9 +155,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): return data_to_send = None + error = False if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: assert event.data is not None data_to_send = {"text": event.data["stt_output"]["text"]} + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: + assert event.data is not None + data_to_send = { + "conversation_id": event.data["intent_output"]["conversation_id"] or "", + } elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: assert event.data is not None data_to_send = {"text": event.data["tts_input"]} @@ -177,19 +186,132 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): "code": event.data["code"], "message": event.data["message"], } - self.handle_finished() + self._tts_done.set() + error = True self.handle_event(event_type, data_to_send) + if error: + self.handle_finished() + + async def _wait_for_speech( + self, + segmenter: VoiceCommandSegmenter, + chunk_buffer: MutableSequence[bytes], + ) -> bool: + """Buffer audio chunks until speech is detected. + + Raises asyncio.TimeoutError if no audio data is retrievable from the queue (device stops sending packets / networking issue). + + Returns True if speech was detected + Returns False if the connection was stopped gracefully (b"" put onto the queue). + """ + # Timeout if no audio comes in for a while. + async with async_timeout.timeout(self.audio_timeout): + chunk = await self.queue.get() + + while chunk: + segmenter.process(chunk) + # Buffer the data we have taken from the queue + chunk_buffer.append(chunk) + if segmenter.in_command: + return True + + async with async_timeout.timeout(self.audio_timeout): + chunk = await self.queue.get() + + # If chunk is falsey, `stop()` was called + return False + + async def _segment_audio( + self, + segmenter: VoiceCommandSegmenter, + chunk_buffer: Sequence[bytes], + ) -> AsyncIterable[bytes]: + """Yield audio chunks until voice command has finished. + + Raises asyncio.TimeoutError if no audio data is retrievable from the queue. + """ + # Buffered chunks first + for buffered_chunk in chunk_buffer: + yield buffered_chunk + + # Timeout if no audio comes in for a while. + async with async_timeout.timeout(self.audio_timeout): + chunk = await self.queue.get() + + while chunk: + if not segmenter.process(chunk): + # Voice command is finished + break + + yield chunk + + async with async_timeout.timeout(self.audio_timeout): + chunk = await self.queue.get() + + async def _iterate_packets_with_vad( + self, pipeline_timeout: float + ) -> Callable[[], AsyncIterable[bytes]] | None: + segmenter = VoiceCommandSegmenter() + chunk_buffer: deque[bytes] = deque(maxlen=100) + try: + async with async_timeout.timeout(pipeline_timeout): + speech_detected = await self._wait_for_speech(segmenter, chunk_buffer) + if not speech_detected: + _LOGGER.debug( + "Device stopped sending audio before speech was detected" + ) + self.handle_finished() + return None + except asyncio.TimeoutError: + self.handle_event( + VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, + { + "code": "speech-timeout", + "message": "Timed out waiting for speech", + }, + ) + self.handle_finished() + return None + + async def _stream_packets() -> AsyncIterable[bytes]: + try: + async for chunk in self._segment_audio(segmenter, chunk_buffer): + yield chunk + except asyncio.TimeoutError: + self.handle_event( + VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, + { + "code": "speech-timeout", + "message": "No speech detected", + }, + ) + self.handle_finished() + + return _stream_packets async def run_pipeline( self, + conversation_id: str | None, + use_vad: bool = False, pipeline_timeout: float = 30.0, ) -> None: """Run the Voice Assistant pipeline.""" + + tts_audio_output = ( + "raw" if self.device_info.voice_assistant_version >= 2 else "mp3" + ) + + if use_vad: + stt_stream = await self._iterate_packets_with_vad(pipeline_timeout) + # Error or timeout occurred and was handled already + if stt_stream is None: + return + else: + stt_stream = self._iterate_packets + + _LOGGER.debug("Starting pipeline") try: - tts_audio_output = ( - "raw" if self.device_info.voice_assistant_version >= 2 else "mp3" - ) async with async_timeout.timeout(pipeline_timeout): await async_pipeline_from_audio_stream( self.hass, @@ -203,10 +325,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, channel=stt.AudioChannels.CHANNEL_MONO, ), - stt_stream=self._iterate_packets(), + stt_stream=stt_stream(), pipeline_id=pipeline_select.get_chosen_pipeline( self.hass, DOMAIN, self.device_info.mac_address ), + conversation_id=conversation_id, tts_audio_output=tts_audio_output, ) @@ -215,6 +338,13 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): _LOGGER.debug("Pipeline finished") except asyncio.TimeoutError: + self.handle_event( + VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, + { + "code": "pipeline-timeout", + "message": "Pipeline timeout", + }, + ) _LOGGER.warning("Pipeline timeout") finally: self.handle_finished() diff --git a/requirements_all.txt b/requirements_all.txt index c15ae28cf24..aa99c0725a5 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -159,7 +159,7 @@ aioecowitt==2023.5.0 aioemonitor==1.0.5 # homeassistant.components.esphome -aioesphomeapi==13.7.5 +aioesphomeapi==13.9.0 # homeassistant.components.flo aioflo==2021.11.0 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 2f15bcea2b5..6ed7a309aa2 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -149,7 +149,7 @@ aioecowitt==2023.5.0 aioemonitor==1.0.5 # homeassistant.components.esphome -aioesphomeapi==13.7.5 +aioesphomeapi==13.9.0 # homeassistant.components.flo aioflo==2021.11.0 diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index 4c01c2f7c5a..23f140587c7 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -132,70 +132,51 @@ async def mock_dashboard(hass): @pytest.fixture -async def mock_voice_assistant_v1_entry( +async def mock_voice_assistant_entry( hass: HomeAssistant, mock_client, ) -> MockConfigEntry: """Set up an ESPHome entry with voice assistant.""" - entry = MockConfigEntry( - domain=DOMAIN, - data={ - CONF_HOST: "test.local", - CONF_PORT: 6053, - CONF_PASSWORD: "", - }, - ) - entry.add_to_hass(hass) - device_info = DeviceInfo( - name="test", - friendly_name="Test", - voice_assistant_version=1, - mac_address="11:22:33:44:55:aa", - esphome_version="1.0.0", - ) + async def _mock_voice_assistant_entry(version: int): + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_HOST: "test.local", + CONF_PORT: 6053, + CONF_PASSWORD: "", + }, + ) + entry.add_to_hass(hass) - mock_client.device_info = AsyncMock(return_value=device_info) - mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock()) + device_info = DeviceInfo( + name="test", + friendly_name="Test", + voice_assistant_version=version, + mac_address="11:22:33:44:55:aa", + esphome_version="1.0.0", + ) - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() - await hass.async_block_till_done() - await hass.async_block_till_done() + mock_client.device_info = AsyncMock(return_value=device_info) + mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock()) - return entry + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + await hass.async_block_till_done() + await hass.async_block_till_done() + + return entry + + return _mock_voice_assistant_entry @pytest.fixture -async def mock_voice_assistant_v2_entry( - hass: HomeAssistant, - mock_client, -) -> MockConfigEntry: +async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfigEntry: """Set up an ESPHome entry with voice assistant.""" - entry = MockConfigEntry( - domain=DOMAIN, - data={ - CONF_HOST: "test.local", - CONF_PORT: 6053, - CONF_PASSWORD: "", - }, - ) - entry.add_to_hass(hass) + return await mock_voice_assistant_entry(version=1) - device_info = DeviceInfo( - name="test", - friendly_name="Test", - voice_assistant_version=2, - mac_address="11:22:33:44:55:aa", - esphome_version="1.0.0", - ) - mock_client.device_info = AsyncMock(return_value=device_info) - mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock()) - - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() - await hass.async_block_till_done() - await hass.async_block_till_done() - - return entry +@pytest.fixture +async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry: + """Set up an ESPHome entry with voice assistant.""" + return await mock_voice_assistant_entry(version=2) diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py index fed83f8ab10..f8c2d62d095 100644 --- a/tests/components/esphome/test_voice_assistant.py +++ b/tests/components/esphome/test_voice_assistant.py @@ -19,43 +19,47 @@ _TEST_OUTPUT_TEXT = "This is an output test" _TEST_OUTPUT_URL = "output.mp3" _TEST_MEDIA_ID = "12345" +_ONE_SECOND = 16000 * 2 # 16Khz 16-bit + + +@pytest.fixture +def voice_assistant_udp_server( + hass: HomeAssistant, +) -> VoiceAssistantUDPServer: + """Return the UDP server factory.""" + + def _voice_assistant_udp_server(entry): + entry_data = DomainData.get(hass).get_entry_data(entry) + + server: VoiceAssistantUDPServer = None + + def handle_finished(): + nonlocal server + assert server is not None + server.close() + + server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished) + return server + + return _voice_assistant_udp_server + @pytest.fixture def voice_assistant_udp_server_v1( - hass: HomeAssistant, + voice_assistant_udp_server, mock_voice_assistant_v1_entry, ) -> VoiceAssistantUDPServer: """Return the UDP server.""" - entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_v1_entry) - - server: VoiceAssistantUDPServer = None - - def handle_finished(): - nonlocal server - assert server is not None - server.close() - - server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished) - return server + return voice_assistant_udp_server(entry=mock_voice_assistant_v1_entry) @pytest.fixture def voice_assistant_udp_server_v2( - hass: HomeAssistant, + voice_assistant_udp_server, mock_voice_assistant_v2_entry, ) -> VoiceAssistantUDPServer: """Return the UDP server.""" - entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_v2_entry) - - server: VoiceAssistantUDPServer = None - - def handle_finished(): - nonlocal server - assert server is not None - server.close() - - server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished) - return server + return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry) async def test_pipeline_events( @@ -117,7 +121,7 @@ async def test_pipeline_events( ): voice_assistant_udp_server_v1.transport = Mock() - await voice_assistant_udp_server_v1.run_pipeline() + await voice_assistant_udp_server_v1.run_pipeline(conversation_id=None) async def test_udp_server( @@ -335,3 +339,136 @@ async def test_send_tts( await voice_assistant_udp_server_v2._tts_done.wait() voice_assistant_udp_server_v2.transport.sendto.assert_called() + + +async def test_speech_detection( + hass: HomeAssistant, + voice_assistant_udp_server_v2: VoiceAssistantUDPServer, +) -> None: + """Test the UDP server queues incoming data.""" + + def is_speech(self, chunk, sample_rate): + """Anything non-zero is speech.""" + return sum(chunk) > 0 + + async def async_pipeline_from_audio_stream(*args, **kwargs): + stt_stream = kwargs["stt_stream"] + event_callback = kwargs["event_callback"] + async for _chunk in stt_stream: + pass + + # Test empty data + event_callback( + PipelineEvent( + type=PipelineEventType.STT_END, + data={"stt_output": {"text": _TEST_INPUT_TEXT}}, + ) + ) + + with patch( + "webrtcvad.Vad.is_speech", + new=is_speech, + ), patch( + "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ): + voice_assistant_udp_server_v2.started = True + + voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND)) + voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2)) + voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2)) + voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND)) + + await voice_assistant_udp_server_v2.run_pipeline( + conversation_id=None, use_vad=True, pipeline_timeout=1.0 + ) + + +async def test_no_speech( + hass: HomeAssistant, + voice_assistant_udp_server_v2: VoiceAssistantUDPServer, +) -> None: + """Test there is no speech.""" + + def is_speech(self, chunk, sample_rate): + """Anything non-zero is speech.""" + return sum(chunk) > 0 + + def handle_event( + event_type: esphome.VoiceAssistantEventType, data: dict[str, str] | None + ) -> None: + assert event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_ERROR + assert data is not None + assert data["code"] == "speech-timeout" + + voice_assistant_udp_server_v2.handle_event = handle_event + + with patch( + "webrtcvad.Vad.is_speech", + new=is_speech, + ): + voice_assistant_udp_server_v2.started = True + + voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND)) + + await voice_assistant_udp_server_v2.run_pipeline( + conversation_id=None, use_vad=True, pipeline_timeout=1.0 + ) + + +async def test_speech_timeout( + hass: HomeAssistant, + voice_assistant_udp_server_v2: VoiceAssistantUDPServer, +) -> None: + """Test when speech was detected, but the pipeline times out.""" + + def is_speech(self, chunk, sample_rate): + """Anything non-zero is speech.""" + return sum(chunk) > 255 + + async def async_pipeline_from_audio_stream(*args, **kwargs): + stt_stream = kwargs["stt_stream"] + async for _chunk in stt_stream: + # Stream will end when VAD detects end of "speech" + pass + + async def segment_audio(*args, **kwargs): + raise asyncio.TimeoutError() + async for chunk in []: + yield chunk + + with patch( + "webrtcvad.Vad.is_speech", + new=is_speech, + ), patch( + "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), patch( + "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._segment_audio", + new=segment_audio, + ): + voice_assistant_udp_server_v2.started = True + + voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2))) + + await voice_assistant_udp_server_v2.run_pipeline( + conversation_id=None, use_vad=True, pipeline_timeout=1.0 + ) + + +async def test_cancelled( + hass: HomeAssistant, + voice_assistant_udp_server_v2: VoiceAssistantUDPServer, +) -> None: + """Test when the server is stopped while waiting for speech.""" + + voice_assistant_udp_server_v2.started = True + + voice_assistant_udp_server_v2.queue.put_nowait(b"") + + await voice_assistant_udp_server_v2.run_pipeline( + conversation_id=None, use_vad=True, pipeline_timeout=1.0 + ) + + # No events should be sent if cancelled while waiting for speech + voice_assistant_udp_server_v2.handle_event.assert_not_called()