ESPHome Wake Word support (#98544)

* ESPHome Wake Word support

* Remove all vad code from esphome integration

* Catch exception when no wake word provider found

* Remove import

* Remove esphome vad tests

* Add tests

* More tests
Jesse Hills 2023-08-22 04:13:02 +12:00 committed by GitHub
parent c86565b9bc
commit a42d975c49
3 changed files with 105 additions and 225 deletions


@@ -18,7 +18,6 @@ from aioesphomeapi import (
UserServiceArgType,
VoiceAssistantEventType,
)
from aioesphomeapi.model import VoiceAssistantCommandFlag
from awesomeversion import AwesomeVersion
import voluptuous as vol
@@ -320,7 +319,7 @@ class ESPHomeManager:
self.voice_assistant_udp_server = None
async def _handle_pipeline_start(
self, conversation_id: str, use_vad: int
self, conversation_id: str, flags: int
) -> int | None:
"""Start a voice assistant pipeline."""
if self.voice_assistant_udp_server is not None:
@@ -340,8 +339,7 @@ class ESPHomeManager:
voice_assistant_udp_server.run_pipeline(
device_id=self.device_id,
conversation_id=conversation_id or None,
use_vad=VoiceAssistantCommandFlag(use_vad)
== VoiceAssistantCommandFlag.USE_VAD,
flags=flags,
),
"esphome.voice_assistant_udp_server.run_pipeline",
)
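The handler above now hands the device's raw flags bitmask to the UDP server instead of collapsing it to a use_vad boolean. Below is a minimal sketch, not part of this diff, of how such a bitmask is interpreted; it assumes VoiceAssistantCommandFlag is an IntFlag with USE_VAD = 1 and USE_WAKE_WORD = 2 (the value the tests further down pass as flags=2), and the helper name start_stage_for is hypothetical.

from enum import IntFlag

class VoiceAssistantCommandFlag(IntFlag):
    # Assumed values, inferred from flags=2 in the tests below.
    USE_VAD = 1 << 0        # 1
    USE_WAKE_WORD = 1 << 1  # 2

def start_stage_for(flags: int) -> str:
    """Pick a pipeline start stage from the device-supplied bitmask (sketch only)."""
    if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
        return "wake_word"  # PipelineStage.WAKE_WORD in voice_assistant.py
    return "stt"            # PipelineStage.STT otherwise

assert start_stage_for(2) == "wake_word"
assert start_stage_for(0) == "stt"
assert start_stage_for(3) == "wake_word"  # USE_VAD | USE_WAKE_WORD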


@@ -2,26 +2,23 @@
from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import AsyncIterable, Callable, MutableSequence, Sequence
from collections.abc import AsyncIterable, Callable
import logging
import socket
from typing import cast
from aioesphomeapi import VoiceAssistantEventType
from aioesphomeapi import VoiceAssistantCommandFlag, VoiceAssistantEventType
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
PipelineStage,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.vad import (
VadSensitivity,
VoiceCommandSegmenter,
)
from homeassistant.components.assist_pipeline.error import WakeWordDetectionError
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.core import Context, HomeAssistant, callback
@@ -47,6 +44,8 @@ _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
}
)
@@ -183,121 +182,33 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
)
else:
self._tts_done.set()
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
assert event.data is not None
if not event.data["wake_word_output"]:
event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
data_to_send = {
"code": "no_wake_word",
"message": "No wake word detected",
}
error = True
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
assert event.data is not None
data_to_send = {
"code": event.data["code"],
"message": event.data["message"],
}
self._tts_done.set()
error = True
self.handle_event(event_type, data_to_send)
if error:
self._tts_done.set()
self.handle_finished()
async def _wait_for_speech(
self,
segmenter: VoiceCommandSegmenter,
chunk_buffer: MutableSequence[bytes],
) -> bool:
"""Buffer audio chunks until speech is detected.
Raises asyncio.TimeoutError if no audio data is retrievable from the queue (device stops sending packets / networking issue).
Returns True if speech was detected
Returns False if the connection was stopped gracefully (b"" put onto the queue).
"""
# Timeout if no audio comes in for a while.
async with asyncio.timeout(self.audio_timeout):
chunk = await self.queue.get()
while chunk:
segmenter.process(chunk)
# Buffer the data we have taken from the queue
chunk_buffer.append(chunk)
if segmenter.in_command:
return True
async with asyncio.timeout(self.audio_timeout):
chunk = await self.queue.get()
# If chunk is falsey, `stop()` was called
return False
async def _segment_audio(
self,
segmenter: VoiceCommandSegmenter,
chunk_buffer: Sequence[bytes],
) -> AsyncIterable[bytes]:
"""Yield audio chunks until voice command has finished.
Raises asyncio.TimeoutError if no audio data is retrievable from the queue.
"""
# Buffered chunks first
for buffered_chunk in chunk_buffer:
yield buffered_chunk
# Timeout if no audio comes in for a while.
async with asyncio.timeout(self.audio_timeout):
chunk = await self.queue.get()
while chunk:
if not segmenter.process(chunk):
# Voice command is finished
break
yield chunk
async with asyncio.timeout(self.audio_timeout):
chunk = await self.queue.get()
async def _iterate_packets_with_vad(
self, pipeline_timeout: float, silence_seconds: float
) -> Callable[[], AsyncIterable[bytes]] | None:
segmenter = VoiceCommandSegmenter(silence_seconds=silence_seconds)
chunk_buffer: deque[bytes] = deque(maxlen=100)
try:
async with asyncio.timeout(pipeline_timeout):
speech_detected = await self._wait_for_speech(segmenter, chunk_buffer)
if not speech_detected:
_LOGGER.debug(
"Device stopped sending audio before speech was detected"
)
self.handle_finished()
return None
except asyncio.TimeoutError:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "speech-timeout",
"message": "Timed out waiting for speech",
},
)
self.handle_finished()
return None
async def _stream_packets() -> AsyncIterable[bytes]:
try:
async for chunk in self._segment_audio(segmenter, chunk_buffer):
yield chunk
except asyncio.TimeoutError:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "speech-timeout",
"message": "No speech detected",
},
)
self.handle_finished()
return _stream_packets
async def run_pipeline(
self,
device_id: str,
conversation_id: str | None,
use_vad: bool = False,
flags: int = 0,
pipeline_timeout: float = 30.0,
) -> None:
"""Run the Voice Assistant pipeline."""
@@ -306,24 +217,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
)
if use_vad:
stt_stream = await self._iterate_packets_with_vad(
pipeline_timeout,
silence_seconds=VadSensitivity.to_seconds(
pipeline_select.get_vad_sensitivity(
self.hass,
DOMAIN,
self.device_info.mac_address,
)
),
)
# Error or timeout occurred and was handled already
if stt_stream is None:
return
else:
stt_stream = self._iterate_packets
_LOGGER.debug("Starting pipeline")
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
start_stage = PipelineStage.WAKE_WORD
else:
start_stage = PipelineStage.STT
try:
async with asyncio.timeout(pipeline_timeout):
await async_pipeline_from_audio_stream(
@@ -338,13 +236,14 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream(),
stt_stream=self._iterate_packets(),
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.device_info.mac_address
),
conversation_id=conversation_id,
device_id=device_id,
tts_audio_output=tts_audio_output,
start_stage=start_stage,
)
# Block until TTS is done sending
@@ -356,11 +255,23 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "pipeline not found",
"message": "Selected pipeline timeout",
"message": "Selected pipeline not found",
},
)
_LOGGER.warning("Pipeline not found")
except WakeWordDetectionError as e:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": e.code,
"message": e.message,
},
)
_LOGGER.warning("No Wake word provider found")
except asyncio.TimeoutError:
if self.stopped:
# The pipeline was stopped gracefully
return
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
@@ -397,7 +308,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.transport.sendto(chunk, self.remote_addr)
await asyncio.sleep(
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.99
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
)
sample_offset += samples_in_chunk
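With the VAD segmentation removed, run_pipeline is driven entirely by the device-supplied flags. A hedged usage sketch of the new call follows; it assumes server is an already initialized VoiceAssistantUDPServer, and the wrapper start_voice_pipeline is hypothetical (the real manager schedules run_pipeline as a background task, as shown in the first file).

from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer

async def start_voice_pipeline(
    server: VoiceAssistantUDPServer, device_id: str, flags: int
) -> None:
    """Kick off one pipeline run for a voice request (usage sketch only)."""
    await server.run_pipeline(
        device_id=device_id,
        conversation_id=None,   # let the pipeline start a new conversation
        flags=flags,            # raw bitmask from the device, e.g. 2 for wake word
        pipeline_timeout=30.0,  # matches the default in the signature above
    )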


@@ -7,7 +7,13 @@ from unittest.mock import Mock, patch
from aioesphomeapi import VoiceAssistantEventType
import pytest
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
PipelineStage,
)
from homeassistant.components.assist_pipeline.error import WakeWordDetectionError
from homeassistant.components.esphome import DomainData
from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer
from homeassistant.core import HomeAssistant
@@ -71,6 +77,13 @@ async def test_pipeline_events(
event_callback = kwargs["event_callback"]
event_callback(
PipelineEvent(
type=PipelineEventType.WAKE_WORD_END,
data={"wake_word_output": {}},
)
)
# Fake events
event_callback(
PipelineEvent(
@@ -112,6 +125,8 @@ async def test_pipeline_events(
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
assert data is not None
assert data["url"] == _TEST_OUTPUT_URL
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
assert data is None
voice_assistant_udp_server_v1.handle_event = handle_event
@@ -343,134 +358,90 @@ async def test_send_tts(
voice_assistant_udp_server_v2.transport.sendto.assert_called()
async def test_speech_detection(
async def test_wake_word(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test the UDP server queues incoming data."""
"""Test that the pipeline is set to start with Wake word."""
def is_speech(self, chunk, sample_rate):
"""Anything non-zero is speech."""
return sum(chunk) > 0
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
event_callback = kwargs["event_callback"]
async for _chunk in stt_stream:
pass
# Test empty data
event_callback(
PipelineEvent(
type=PipelineEventType.STT_END,
data={"stt_output": {"text": _TEST_INPUT_TEXT}},
)
)
async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs):
assert start_stage == PipelineStage.WAKE_WORD
with patch(
"webrtcvad.Vad.is_speech",
new=is_speech,
), patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
voice_assistant_udp_server_v2.started = True
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2))
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2))
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
voice_assistant_udp_server_v2.transport = Mock()
await voice_assistant_udp_server_v2.run_pipeline(
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
device_id="mock-device-id",
conversation_id=None,
flags=2,
pipeline_timeout=1,
)
async def test_no_speech(
async def test_wake_word_exception(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test there is no speech."""
def is_speech(self, chunk, sample_rate):
"""Anything non-zero is speech."""
return sum(chunk) > 0
def handle_event(
event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
assert event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
assert data is not None
assert data["code"] == "speech-timeout"
voice_assistant_udp_server_v2.handle_event = handle_event
with patch(
"webrtcvad.Vad.is_speech",
new=is_speech,
):
voice_assistant_udp_server_v2.started = True
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
await voice_assistant_udp_server_v2.run_pipeline(
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
)
async def test_speech_timeout(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test when speech was detected, but the pipeline times out."""
def is_speech(self, chunk, sample_rate):
"""Anything non-zero is speech."""
return sum(chunk) > 255
"""Test that the pipeline is set to start with Wake word."""
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
async for _chunk in stt_stream:
# Stream will end when VAD detects end of "speech"
pass
async def segment_audio(*args, **kwargs):
raise asyncio.TimeoutError()
async for chunk in []:
yield chunk
raise WakeWordDetectionError("pipeline-not-found", "Pipeline not found")
with patch(
"webrtcvad.Vad.is_speech",
new=is_speech,
), patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
), patch(
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._segment_audio",
new=segment_audio,
):
voice_assistant_udp_server_v2.started = True
voice_assistant_udp_server_v2.transport = Mock()
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2)))
def handle_event(
event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
assert data is not None
assert data["code"] == "pipeline-not-found"
assert data["message"] == "Pipeline not found"
voice_assistant_udp_server_v2.handle_event = handle_event
await voice_assistant_udp_server_v2.run_pipeline(
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
device_id="mock-device-id",
conversation_id=None,
flags=2,
pipeline_timeout=1,
)
async def test_cancelled(
async def test_pipeline_timeout(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test when the server is stopped while waiting for speech."""
"""Test that the pipeline is set to start with Wake word."""
voice_assistant_udp_server_v2.started = True
async def async_pipeline_from_audio_stream(*args, **kwargs):
raise PipelineNotFound("not-found", "Pipeline not found")
voice_assistant_udp_server_v2.queue.put_nowait(b"")
with patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
):
voice_assistant_udp_server_v2.transport = Mock()
await voice_assistant_udp_server_v2.run_pipeline(
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
)
def handle_event(
event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
assert data is not None
assert data["code"] == "pipeline not found"
assert data["message"] == "Selected pipeline not found"
# No events should be sent if cancelled while waiting for speech
voice_assistant_udp_server_v2.handle_event.assert_not_called()
voice_assistant_udp_server_v2.handle_event = handle_event
await voice_assistant_udp_server_v2.run_pipeline(
device_id="mock-device-id",
conversation_id=None,
flags=2,
pipeline_timeout=1,
)
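Not part of this commit: a sketch of a complementary test one might add, checking that leaving the wake word flag clear keeps the pipeline starting at STT. It reuses the voice_assistant_udp_server_v2 fixture, imports, and patch target from the test module above, and the test name is hypothetical.

async def test_no_wake_word_flag(
    hass: HomeAssistant,
    voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
    """Test that the pipeline starts at STT when the wake word flag is not set."""

    async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs):
        # flags=0 carries no USE_WAKE_WORD bit, so the STT start stage is expected.
        assert start_stage == PipelineStage.STT

    with patch(
        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ):
        voice_assistant_udp_server_v2.transport = Mock()

        await voice_assistant_udp_server_v2.run_pipeline(
            device_id="mock-device-id",
            conversation_id=None,
            flags=0,
            pipeline_timeout=1,
        )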