From a42d975c490a315ab42a7a9754627bbf95eff200 Mon Sep 17 00:00:00 2001
From: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
Date: Tue, 22 Aug 2023 04:13:02 +1200
Subject: [PATCH] ESPHome Wake Word support (#98544)

* ESPHome Wake Word support

* Remove all vad code from esphome integration

* Catch exception when no wake word provider found

* Remove import

* Remove esphome vad tests

* Add tests

* More tests
---
 homeassistant/components/esphome/manager.py   |   6 +-
 .../components/esphome/voice_assistant.py     | 163 ++++----------
 .../esphome/test_voice_assistant.py           | 161 +++++++--------
 3 files changed, 105 insertions(+), 225 deletions(-)

diff --git a/homeassistant/components/esphome/manager.py b/homeassistant/components/esphome/manager.py
index 35939dc9b1f..fb3e0a1e79a 100644
--- a/homeassistant/components/esphome/manager.py
+++ b/homeassistant/components/esphome/manager.py
@@ -18,7 +18,6 @@ from aioesphomeapi import (
     UserServiceArgType,
     VoiceAssistantEventType,
 )
-from aioesphomeapi.model import VoiceAssistantCommandFlag
 from awesomeversion import AwesomeVersion
 import voluptuous as vol
 
@@ -320,7 +319,7 @@ class ESPHomeManager:
         self.voice_assistant_udp_server = None
 
     async def _handle_pipeline_start(
-        self, conversation_id: str, use_vad: int
+        self, conversation_id: str, flags: int
     ) -> int | None:
         """Start a voice assistant pipeline."""
         if self.voice_assistant_udp_server is not None:
@@ -340,8 +339,7 @@ class ESPHomeManager:
             voice_assistant_udp_server.run_pipeline(
                 device_id=self.device_id,
                 conversation_id=conversation_id or None,
-                use_vad=VoiceAssistantCommandFlag(use_vad)
-                == VoiceAssistantCommandFlag.USE_VAD,
+                flags=flags,
             ),
             "esphome.voice_assistant_udp_server.run_pipeline",
         )
diff --git a/homeassistant/components/esphome/voice_assistant.py b/homeassistant/components/esphome/voice_assistant.py
index f870f9e42f7..a9397eda935 100644
--- a/homeassistant/components/esphome/voice_assistant.py
+++ b/homeassistant/components/esphome/voice_assistant.py
@@ -2,26 +2,23 @@
 from __future__ import annotations
 
 import asyncio
-from collections import deque
-from collections.abc import AsyncIterable, Callable, MutableSequence, Sequence
+from collections.abc import AsyncIterable, Callable
 import logging
 import socket
 from typing import cast
 
-from aioesphomeapi import VoiceAssistantEventType
+from aioesphomeapi import VoiceAssistantCommandFlag, VoiceAssistantEventType
 
 from homeassistant.components import stt, tts
 from homeassistant.components.assist_pipeline import (
     PipelineEvent,
     PipelineEventType,
     PipelineNotFound,
+    PipelineStage,
     async_pipeline_from_audio_stream,
     select as pipeline_select,
 )
-from homeassistant.components.assist_pipeline.vad import (
-    VadSensitivity,
-    VoiceCommandSegmenter,
-)
+from homeassistant.components.assist_pipeline.error import WakeWordDetectionError
 from homeassistant.components.media_player import async_process_play_media_url
 from homeassistant.core import Context, HomeAssistant, callback
 
@@ -47,6 +44,8 @@ _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
         VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
         VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
         VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
+        VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
+        VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
     }
 )
 
@@ -183,121 +182,33 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
                 )
             else:
                 self._tts_done.set()
+        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
+            assert event.data is not None
+            if not event.data["wake_word_output"]:
+                event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
+                data_to_send = {
+                    "code": "no_wake_word",
+                    "message": "No wake word detected",
+                }
+                error = True
         elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
             assert event.data is not None
             data_to_send = {
                 "code": event.data["code"],
                 "message": event.data["message"],
             }
-            self._tts_done.set()
             error = True
 
         self.handle_event(event_type, data_to_send)
         if error:
+            self._tts_done.set()
             self.handle_finished()
 
-    async def _wait_for_speech(
-        self,
-        segmenter: VoiceCommandSegmenter,
-        chunk_buffer: MutableSequence[bytes],
-    ) -> bool:
-        """Buffer audio chunks until speech is detected.
-
-        Raises asyncio.TimeoutError if no audio data is retrievable from the queue (device stops sending packets / networking issue).
-
-        Returns True if speech was detected
-        Returns False if the connection was stopped gracefully (b"" put onto the queue).
-        """
-        # Timeout if no audio comes in for a while.
-        async with asyncio.timeout(self.audio_timeout):
-            chunk = await self.queue.get()
-
-        while chunk:
-            segmenter.process(chunk)
-            # Buffer the data we have taken from the queue
-            chunk_buffer.append(chunk)
-            if segmenter.in_command:
-                return True
-
-            async with asyncio.timeout(self.audio_timeout):
-                chunk = await self.queue.get()
-
-        # If chunk is falsey, `stop()` was called
-        return False
-
-    async def _segment_audio(
-        self,
-        segmenter: VoiceCommandSegmenter,
-        chunk_buffer: Sequence[bytes],
-    ) -> AsyncIterable[bytes]:
-        """Yield audio chunks until voice command has finished.
-
-        Raises asyncio.TimeoutError if no audio data is retrievable from the queue.
-        """
-        # Buffered chunks first
-        for buffered_chunk in chunk_buffer:
-            yield buffered_chunk
-
-        # Timeout if no audio comes in for a while.
-        async with asyncio.timeout(self.audio_timeout):
-            chunk = await self.queue.get()
-
-        while chunk:
-            if not segmenter.process(chunk):
-                # Voice command is finished
-                break
-
-            yield chunk
-
-            async with asyncio.timeout(self.audio_timeout):
-                chunk = await self.queue.get()
-
-    async def _iterate_packets_with_vad(
-        self, pipeline_timeout: float, silence_seconds: float
-    ) -> Callable[[], AsyncIterable[bytes]] | None:
-        segmenter = VoiceCommandSegmenter(silence_seconds=silence_seconds)
-        chunk_buffer: deque[bytes] = deque(maxlen=100)
-        try:
-            async with asyncio.timeout(pipeline_timeout):
-                speech_detected = await self._wait_for_speech(segmenter, chunk_buffer)
-                if not speech_detected:
-                    _LOGGER.debug(
-                        "Device stopped sending audio before speech was detected"
-                    )
-                    self.handle_finished()
-                    return None
-        except asyncio.TimeoutError:
-            self.handle_event(
-                VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
-                {
-                    "code": "speech-timeout",
-                    "message": "Timed out waiting for speech",
-                },
-            )
-            self.handle_finished()
-            return None
-
-        async def _stream_packets() -> AsyncIterable[bytes]:
-            try:
-                async for chunk in self._segment_audio(segmenter, chunk_buffer):
-                    yield chunk
-            except asyncio.TimeoutError:
-                self.handle_event(
-                    VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
-                    {
-                        "code": "speech-timeout",
-                        "message": "No speech detected",
-                    },
-                )
-                self.handle_finished()
-
-        return _stream_packets
-
     async def run_pipeline(
         self,
         device_id: str,
         conversation_id: str | None,
-        use_vad: bool = False,
+        flags: int = 0,
         pipeline_timeout: float = 30.0,
     ) -> None:
         """Run the Voice Assistant pipeline."""
@@ -306,24 +217,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
             "raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
         )
 
-        if use_vad:
-            stt_stream = await self._iterate_packets_with_vad(
-                pipeline_timeout,
-                silence_seconds=VadSensitivity.to_seconds(
-                    pipeline_select.get_vad_sensitivity(
-                        self.hass,
-                        DOMAIN,
-                        self.device_info.mac_address,
-                    )
-                ),
-            )
-            # Error or timeout occurred and was handled already
-            if stt_stream is None:
-                return
-        else:
-            stt_stream = self._iterate_packets
-
         _LOGGER.debug("Starting pipeline")
+        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
+            start_stage = PipelineStage.WAKE_WORD
+        else:
+            start_stage = PipelineStage.STT
         try:
             async with asyncio.timeout(pipeline_timeout):
                 await async_pipeline_from_audio_stream(
@@ -338,13 +236,14 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
                         sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                         channel=stt.AudioChannels.CHANNEL_MONO,
                     ),
-                    stt_stream=stt_stream(),
+                    stt_stream=self._iterate_packets(),
                     pipeline_id=pipeline_select.get_chosen_pipeline(
                         self.hass, DOMAIN, self.device_info.mac_address
                     ),
                     conversation_id=conversation_id,
                     device_id=device_id,
                     tts_audio_output=tts_audio_output,
+                    start_stage=start_stage,
                 )
 
                 # Block until TTS is done sending
@@ -356,11 +255,23 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
                 VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
                 {
                     "code": "pipeline not found",
-                    "message": "Selected pipeline timeout",
+                    "message": "Selected pipeline not found",
                 },
             )
             _LOGGER.warning("Pipeline not found")
+        except WakeWordDetectionError as e:
+            self.handle_event(
+                VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
+                {
+                    "code": e.code,
+                    "message": e.message,
+                },
+            )
+            _LOGGER.warning("No Wake word provider found")
         except asyncio.TimeoutError:
+            if self.stopped:
+                # The pipeline was stopped gracefully
+                return
             self.handle_event(
                 VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
                 {
@@ -397,7 +308,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
 
             self.transport.sendto(chunk, self.remote_addr)
             await asyncio.sleep(
-                samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.99
+                samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
             )
 
             sample_offset += samples_in_chunk
diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py
index d6562651f0b..b7ce5670441 100644
--- a/tests/components/esphome/test_voice_assistant.py
+++ b/tests/components/esphome/test_voice_assistant.py
@@ -7,7 +7,13 @@ from unittest.mock import Mock, patch
 from aioesphomeapi import VoiceAssistantEventType
 import pytest
 
-from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
+from homeassistant.components.assist_pipeline import (
+    PipelineEvent,
+    PipelineEventType,
+    PipelineNotFound,
+    PipelineStage,
+)
+from homeassistant.components.assist_pipeline.error import WakeWordDetectionError
 from homeassistant.components.esphome import DomainData
 from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer
 from homeassistant.core import HomeAssistant
@@ -71,6 +77,13 @@ async def test_pipeline_events(
 
         event_callback = kwargs["event_callback"]
 
+        event_callback(
+            PipelineEvent(
+                type=PipelineEventType.WAKE_WORD_END,
+                data={"wake_word_output": {}},
+            )
+        )
+
         # Fake events
         event_callback(
             PipelineEvent(
@@ -112,6 +125,8 @@ async def test_pipeline_events(
         elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
             assert data is not None
             assert data["url"] == _TEST_OUTPUT_URL
+        elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
+            assert data is None
 
     voice_assistant_udp_server_v1.handle_event = handle_event
 
@@ -343,134 +358,90 @@ async def test_send_tts(
         voice_assistant_udp_server_v2.transport.sendto.assert_called()
 
 
-async def test_speech_detection(
+async def test_wake_word(
     hass: HomeAssistant,
     voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
 ) -> None:
-    """Test the UDP server queues incoming data."""
+    """Test that the pipeline is set to start with Wake word."""
 
-    def is_speech(self, chunk, sample_rate):
-        """Anything non-zero is speech."""
-        return sum(chunk) > 0
-
-    async def async_pipeline_from_audio_stream(*args, **kwargs):
-        stt_stream = kwargs["stt_stream"]
-        event_callback = kwargs["event_callback"]
-        async for _chunk in stt_stream:
-            pass
-
-        # Test empty data
-        event_callback(
-            PipelineEvent(
-                type=PipelineEventType.STT_END,
-                data={"stt_output": {"text": _TEST_INPUT_TEXT}},
-            )
-        )
+    async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs):
+        assert start_stage == PipelineStage.WAKE_WORD
 
     with patch(
-        "webrtcvad.Vad.is_speech",
-        new=is_speech,
-    ), patch(
         "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
         new=async_pipeline_from_audio_stream,
     ):
-        voice_assistant_udp_server_v2.started = True
-
-        voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
-        voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2))
-        voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2))
-        voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
+        voice_assistant_udp_server_v2.transport = Mock()
 
         await voice_assistant_udp_server_v2.run_pipeline(
-            device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
+            device_id="mock-device-id",
+            conversation_id=None,
+            flags=2,
+            pipeline_timeout=1,
         )
 
 
-async def test_no_speech(
+async def test_wake_word_exception(
     hass: HomeAssistant,
     voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
 ) -> None:
-    """Test there is no speech."""
-
-    def is_speech(self, chunk, sample_rate):
-        """Anything non-zero is speech."""
-        return sum(chunk) > 0
-
-    def handle_event(
-        event_type: VoiceAssistantEventType, data: dict[str, str] | None
-    ) -> None:
-        assert event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
-        assert data is not None
-        assert data["code"] == "speech-timeout"
-
-    voice_assistant_udp_server_v2.handle_event = handle_event
-
-    with patch(
-        "webrtcvad.Vad.is_speech",
-        new=is_speech,
-    ):
-        voice_assistant_udp_server_v2.started = True
-
-        voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
-
-        await voice_assistant_udp_server_v2.run_pipeline(
-            device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
-        )
-
-
-async def test_speech_timeout(
-    hass: HomeAssistant,
-    voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
-) -> None:
-    """Test when speech was detected, but the pipeline times out."""
-
-    def is_speech(self, chunk, sample_rate):
-        """Anything non-zero is speech."""
-        return sum(chunk) > 255
+    """Test that wake word detection errors are handled."""
 
     async def async_pipeline_from_audio_stream(*args, **kwargs):
-        stt_stream = kwargs["stt_stream"]
-        async for _chunk in stt_stream:
-            # Stream will end when VAD detects end of "speech"
-            pass
-
-    async def segment_audio(*args, **kwargs):
-        raise asyncio.TimeoutError()
-        async for chunk in []:
-            yield chunk
+        raise WakeWordDetectionError("pipeline-not-found", "Pipeline not found")
 
     with patch(
-        "webrtcvad.Vad.is_speech",
-        new=is_speech,
-    ), patch(
         "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
         new=async_pipeline_from_audio_stream,
-    ), patch(
-        "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._segment_audio",
-        new=segment_audio,
     ):
-        voice_assistant_udp_server_v2.started = True
+        voice_assistant_udp_server_v2.transport = Mock()
 
-        voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2)))
+        def handle_event(
+            event_type: VoiceAssistantEventType, data: dict[str, str] | None
+        ) -> None:
+            if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
+                assert data is not None
+                assert data["code"] == "pipeline-not-found"
+                assert data["message"] == "Pipeline not found"
+
+        voice_assistant_udp_server_v2.handle_event = handle_event
 
         await voice_assistant_udp_server_v2.run_pipeline(
-            device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
+            device_id="mock-device-id",
+            conversation_id=None,
+            flags=2,
+            pipeline_timeout=1,
        )
 
 
-async def test_cancelled(
+async def test_pipeline_timeout(
     hass: HomeAssistant,
     voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
 ) -> None:
-    """Test when the server is stopped while waiting for speech."""
+    """Test that a PipelineNotFound error is handled."""
 
-    voice_assistant_udp_server_v2.started = True
+    async def async_pipeline_from_audio_stream(*args, **kwargs):
+        raise PipelineNotFound("not-found", "Pipeline not found")
 
-    voice_assistant_udp_server_v2.queue.put_nowait(b"")
+    with patch(
+        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
+        new=async_pipeline_from_audio_stream,
+    ):
+        voice_assistant_udp_server_v2.transport = Mock()
 
-    await voice_assistant_udp_server_v2.run_pipeline(
-        device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
-    )
+        def handle_event(
+            event_type: VoiceAssistantEventType, data: dict[str, str] | None
+        ) -> None:
+            if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
+                assert data is not None
+                assert data["code"] == "pipeline not found"
+                assert data["message"] == "Selected pipeline not found"
 
-    # No events should be sent if cancelled while waiting for speech
-    voice_assistant_udp_server_v2.handle_event.assert_not_called()
+        voice_assistant_udp_server_v2.handle_event = handle_event
+
+        await voice_assistant_udp_server_v2.run_pipeline(
+            device_id="mock-device-id",
+            conversation_id=None,
+            flags=2,
+            pipeline_timeout=1,
+        )
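
For reference, a minimal sketch (not part of the diff above) of how the integer flags bitmask sent by the ESPHome device selects the pipeline start stage, mirroring the logic this patch adds to run_pipeline(). The helper name start_stage_for_flags is illustrative only; it assumes aioesphomeapi's VoiceAssistantCommandFlag defines USE_WAKE_WORD with value 2, which is consistent with the new tests calling run_pipeline(..., flags=2) to exercise the wake word path.

    from aioesphomeapi import VoiceAssistantCommandFlag

    from homeassistant.components.assist_pipeline import PipelineStage


    def start_stage_for_flags(flags: int) -> PipelineStage:
        """Map the device-supplied command flag bitmask to a pipeline start stage."""
        if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
            # The device asked Home Assistant to run wake word detection on the
            # incoming audio stream before speech-to-text.
            return PipelineStage.WAKE_WORD
        # Without the flag, the pipeline starts at speech-to-text, as before.
        return PipelineStage.STT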