mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
ESPHome Wake Word support (#98544)
* ESPHome Wake Word support
* Remove all vad code from esphome integration
* Catch exception when no wake word provider found
* Remove import
* Remove esphome vad tests
* Add tests
* More tests
This commit is contained in:
parent
c86565b9bc
commit
a42d975c49
@ -18,7 +18,6 @@ from aioesphomeapi import (
|
|||||||
UserServiceArgType,
|
UserServiceArgType,
|
||||||
VoiceAssistantEventType,
|
VoiceAssistantEventType,
|
||||||
)
|
)
|
||||||
from aioesphomeapi.model import VoiceAssistantCommandFlag
|
|
||||||
from awesomeversion import AwesomeVersion
|
from awesomeversion import AwesomeVersion
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -320,7 +319,7 @@ class ESPHomeManager:
|
|||||||
self.voice_assistant_udp_server = None
|
self.voice_assistant_udp_server = None
|
||||||
|
|
||||||
async def _handle_pipeline_start(
|
async def _handle_pipeline_start(
|
||||||
self, conversation_id: str, use_vad: int
|
self, conversation_id: str, flags: int
|
||||||
) -> int | None:
|
) -> int | None:
|
||||||
"""Start a voice assistant pipeline."""
|
"""Start a voice assistant pipeline."""
|
||||||
if self.voice_assistant_udp_server is not None:
|
if self.voice_assistant_udp_server is not None:
|
||||||
@ -340,8 +339,7 @@ class ESPHomeManager:
|
|||||||
voice_assistant_udp_server.run_pipeline(
|
voice_assistant_udp_server.run_pipeline(
|
||||||
device_id=self.device_id,
|
device_id=self.device_id,
|
||||||
conversation_id=conversation_id or None,
|
conversation_id=conversation_id or None,
|
||||||
use_vad=VoiceAssistantCommandFlag(use_vad)
|
flags=flags,
|
||||||
== VoiceAssistantCommandFlag.USE_VAD,
|
|
||||||
),
|
),
|
||||||
"esphome.voice_assistant_udp_server.run_pipeline",
|
"esphome.voice_assistant_udp_server.run_pipeline",
|
||||||
)
|
)
|
||||||
|
@ -2,26 +2,23 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque
|
from collections.abc import AsyncIterable, Callable
|
||||||
from collections.abc import AsyncIterable, Callable, MutableSequence, Sequence
|
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from aioesphomeapi import VoiceAssistantEventType
|
from aioesphomeapi import VoiceAssistantCommandFlag, VoiceAssistantEventType
|
||||||
|
|
||||||
from homeassistant.components import stt, tts
|
from homeassistant.components import stt, tts
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
PipelineEventType,
|
PipelineEventType,
|
||||||
PipelineNotFound,
|
PipelineNotFound,
|
||||||
|
PipelineStage,
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
select as pipeline_select,
|
select as pipeline_select,
|
||||||
)
|
)
|
||||||
from homeassistant.components.assist_pipeline.vad import (
|
from homeassistant.components.assist_pipeline.error import WakeWordDetectionError
|
||||||
VadSensitivity,
|
|
||||||
VoiceCommandSegmenter,
|
|
||||||
)
|
|
||||||
from homeassistant.components.media_player import async_process_play_media_url
|
from homeassistant.components.media_player import async_process_play_media_url
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
|
|
||||||
@ -47,6 +44,8 @@ _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
|
|||||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
|
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -183,121 +182,33 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self._tts_done.set()
|
self._tts_done.set()
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||||
|
assert event.data is not None
|
||||||
|
if not event.data["wake_word_output"]:
|
||||||
|
event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
|
||||||
|
data_to_send = {
|
||||||
|
"code": "no_wake_word",
|
||||||
|
"message": "No wake word detected",
|
||||||
|
}
|
||||||
|
error = True
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||||
assert event.data is not None
|
assert event.data is not None
|
||||||
data_to_send = {
|
data_to_send = {
|
||||||
"code": event.data["code"],
|
"code": event.data["code"],
|
||||||
"message": event.data["message"],
|
"message": event.data["message"],
|
||||||
}
|
}
|
||||||
self._tts_done.set()
|
|
||||||
error = True
|
error = True
|
||||||
|
|
||||||
self.handle_event(event_type, data_to_send)
|
self.handle_event(event_type, data_to_send)
|
||||||
if error:
|
if error:
|
||||||
|
self._tts_done.set()
|
||||||
self.handle_finished()
|
self.handle_finished()
|
||||||
|
|
||||||
async def _wait_for_speech(
|
|
||||||
self,
|
|
||||||
segmenter: VoiceCommandSegmenter,
|
|
||||||
chunk_buffer: MutableSequence[bytes],
|
|
||||||
) -> bool:
|
|
||||||
"""Buffer audio chunks until speech is detected.
|
|
||||||
|
|
||||||
Raises asyncio.TimeoutError if no audio data is retrievable from the queue (device stops sending packets / networking issue).
|
|
||||||
|
|
||||||
Returns True if speech was detected
|
|
||||||
Returns False if the connection was stopped gracefully (b"" put onto the queue).
|
|
||||||
"""
|
|
||||||
# Timeout if no audio comes in for a while.
|
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
|
||||||
chunk = await self.queue.get()
|
|
||||||
|
|
||||||
while chunk:
|
|
||||||
segmenter.process(chunk)
|
|
||||||
# Buffer the data we have taken from the queue
|
|
||||||
chunk_buffer.append(chunk)
|
|
||||||
if segmenter.in_command:
|
|
||||||
return True
|
|
||||||
|
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
|
||||||
chunk = await self.queue.get()
|
|
||||||
|
|
||||||
# If chunk is falsey, `stop()` was called
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _segment_audio(
|
|
||||||
self,
|
|
||||||
segmenter: VoiceCommandSegmenter,
|
|
||||||
chunk_buffer: Sequence[bytes],
|
|
||||||
) -> AsyncIterable[bytes]:
|
|
||||||
"""Yield audio chunks until voice command has finished.
|
|
||||||
|
|
||||||
Raises asyncio.TimeoutError if no audio data is retrievable from the queue.
|
|
||||||
"""
|
|
||||||
# Buffered chunks first
|
|
||||||
for buffered_chunk in chunk_buffer:
|
|
||||||
yield buffered_chunk
|
|
||||||
|
|
||||||
# Timeout if no audio comes in for a while.
|
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
|
||||||
chunk = await self.queue.get()
|
|
||||||
|
|
||||||
while chunk:
|
|
||||||
if not segmenter.process(chunk):
|
|
||||||
# Voice command is finished
|
|
||||||
break
|
|
||||||
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
async with asyncio.timeout(self.audio_timeout):
|
|
||||||
chunk = await self.queue.get()
|
|
||||||
|
|
||||||
async def _iterate_packets_with_vad(
|
|
||||||
self, pipeline_timeout: float, silence_seconds: float
|
|
||||||
) -> Callable[[], AsyncIterable[bytes]] | None:
|
|
||||||
segmenter = VoiceCommandSegmenter(silence_seconds=silence_seconds)
|
|
||||||
chunk_buffer: deque[bytes] = deque(maxlen=100)
|
|
||||||
try:
|
|
||||||
async with asyncio.timeout(pipeline_timeout):
|
|
||||||
speech_detected = await self._wait_for_speech(segmenter, chunk_buffer)
|
|
||||||
if not speech_detected:
|
|
||||||
_LOGGER.debug(
|
|
||||||
"Device stopped sending audio before speech was detected"
|
|
||||||
)
|
|
||||||
self.handle_finished()
|
|
||||||
return None
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self.handle_event(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
|
||||||
{
|
|
||||||
"code": "speech-timeout",
|
|
||||||
"message": "Timed out waiting for speech",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.handle_finished()
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _stream_packets() -> AsyncIterable[bytes]:
|
|
||||||
try:
|
|
||||||
async for chunk in self._segment_audio(segmenter, chunk_buffer):
|
|
||||||
yield chunk
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
self.handle_event(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
|
||||||
{
|
|
||||||
"code": "speech-timeout",
|
|
||||||
"message": "No speech detected",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
self.handle_finished()
|
|
||||||
|
|
||||||
return _stream_packets
|
|
||||||
|
|
||||||
async def run_pipeline(
|
async def run_pipeline(
|
||||||
self,
|
self,
|
||||||
device_id: str,
|
device_id: str,
|
||||||
conversation_id: str | None,
|
conversation_id: str | None,
|
||||||
use_vad: bool = False,
|
flags: int = 0,
|
||||||
pipeline_timeout: float = 30.0,
|
pipeline_timeout: float = 30.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Run the Voice Assistant pipeline."""
|
"""Run the Voice Assistant pipeline."""
|
||||||
@ -306,24 +217,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||||||
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
|
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_vad:
|
|
||||||
stt_stream = await self._iterate_packets_with_vad(
|
|
||||||
pipeline_timeout,
|
|
||||||
silence_seconds=VadSensitivity.to_seconds(
|
|
||||||
pipeline_select.get_vad_sensitivity(
|
|
||||||
self.hass,
|
|
||||||
DOMAIN,
|
|
||||||
self.device_info.mac_address,
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# Error or timeout occurred and was handled already
|
|
||||||
if stt_stream is None:
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
stt_stream = self._iterate_packets
|
|
||||||
|
|
||||||
_LOGGER.debug("Starting pipeline")
|
_LOGGER.debug("Starting pipeline")
|
||||||
|
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||||
|
start_stage = PipelineStage.WAKE_WORD
|
||||||
|
else:
|
||||||
|
start_stage = PipelineStage.STT
|
||||||
try:
|
try:
|
||||||
async with asyncio.timeout(pipeline_timeout):
|
async with asyncio.timeout(pipeline_timeout):
|
||||||
await async_pipeline_from_audio_stream(
|
await async_pipeline_from_audio_stream(
|
||||||
@ -338,13 +236,14 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
),
|
),
|
||||||
stt_stream=stt_stream(),
|
stt_stream=self._iterate_packets(),
|
||||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||||
self.hass, DOMAIN, self.device_info.mac_address
|
self.hass, DOMAIN, self.device_info.mac_address
|
||||||
),
|
),
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
tts_audio_output=tts_audio_output,
|
tts_audio_output=tts_audio_output,
|
||||||
|
start_stage=start_stage,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Block until TTS is done sending
|
# Block until TTS is done sending
|
||||||
@ -356,11 +255,23 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||||
{
|
{
|
||||||
"code": "pipeline not found",
|
"code": "pipeline not found",
|
||||||
"message": "Selected pipeline timeout",
|
"message": "Selected pipeline not found",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
_LOGGER.warning("Pipeline not found")
|
_LOGGER.warning("Pipeline not found")
|
||||||
|
except WakeWordDetectionError as e:
|
||||||
|
self.handle_event(
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||||
|
{
|
||||||
|
"code": e.code,
|
||||||
|
"message": e.message,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
_LOGGER.warning("No Wake word provider found")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
if self.stopped:
|
||||||
|
# The pipeline was stopped gracefully
|
||||||
|
return
|
||||||
self.handle_event(
|
self.handle_event(
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||||
{
|
{
|
||||||
@ -397,7 +308,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||||||
|
|
||||||
self.transport.sendto(chunk, self.remote_addr)
|
self.transport.sendto(chunk, self.remote_addr)
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(
|
||||||
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.99
|
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
|
||||||
)
|
)
|
||||||
|
|
||||||
sample_offset += samples_in_chunk
|
sample_offset += samples_in_chunk
|
||||||
|
@ -7,7 +7,13 @@ from unittest.mock import Mock, patch
|
|||||||
from aioesphomeapi import VoiceAssistantEventType
|
from aioesphomeapi import VoiceAssistantEventType
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineNotFound,
|
||||||
|
PipelineStage,
|
||||||
|
)
|
||||||
|
from homeassistant.components.assist_pipeline.error import WakeWordDetectionError
|
||||||
from homeassistant.components.esphome import DomainData
|
from homeassistant.components.esphome import DomainData
|
||||||
from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer
|
from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -71,6 +77,13 @@ async def test_pipeline_events(
|
|||||||
|
|
||||||
event_callback = kwargs["event_callback"]
|
event_callback = kwargs["event_callback"]
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.WAKE_WORD_END,
|
||||||
|
data={"wake_word_output": {}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Fake events
|
# Fake events
|
||||||
event_callback(
|
event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
@ -112,6 +125,8 @@ async def test_pipeline_events(
|
|||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
||||||
assert data is not None
|
assert data is not None
|
||||||
assert data["url"] == _TEST_OUTPUT_URL
|
assert data["url"] == _TEST_OUTPUT_URL
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||||
|
assert data is None
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.handle_event = handle_event
|
voice_assistant_udp_server_v1.handle_event = handle_event
|
||||||
|
|
||||||
@ -343,134 +358,90 @@ async def test_send_tts(
|
|||||||
voice_assistant_udp_server_v2.transport.sendto.assert_called()
|
voice_assistant_udp_server_v2.transport.sendto.assert_called()
|
||||||
|
|
||||||
|
|
||||||
async def test_speech_detection(
|
async def test_wake_word(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server queues incoming data."""
|
"""Test that the pipeline is set to start with Wake word."""
|
||||||
|
|
||||||
def is_speech(self, chunk, sample_rate):
|
async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs):
|
||||||
"""Anything non-zero is speech."""
|
assert start_stage == PipelineStage.WAKE_WORD
|
||||||
return sum(chunk) > 0
|
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
|
||||||
stt_stream = kwargs["stt_stream"]
|
|
||||||
event_callback = kwargs["event_callback"]
|
|
||||||
async for _chunk in stt_stream:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Test empty data
|
|
||||||
event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.STT_END,
|
|
||||||
data={"stt_output": {"text": _TEST_INPUT_TEXT}},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"webrtcvad.Vad.is_speech",
|
|
||||||
new=is_speech,
|
|
||||||
), patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2.started = True
|
voice_assistant_udp_server_v2.transport = Mock()
|
||||||
|
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2))
|
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * _ONE_SECOND * 2))
|
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
await voice_assistant_udp_server_v2.run_pipeline(
|
||||||
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
device_id="mock-device-id",
|
||||||
|
conversation_id=None,
|
||||||
|
flags=2,
|
||||||
|
pipeline_timeout=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_no_speech(
|
async def test_wake_word_exception(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test there is no speech."""
|
"""Test that the pipeline is set to start with Wake word."""
|
||||||
|
|
||||||
def is_speech(self, chunk, sample_rate):
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
"""Anything non-zero is speech."""
|
raise WakeWordDetectionError("pipeline-not-found", "Pipeline not found")
|
||||||
return sum(chunk) > 0
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
):
|
||||||
|
voice_assistant_udp_server_v2.transport = Mock()
|
||||||
|
|
||||||
def handle_event(
|
def handle_event(
|
||||||
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
||||||
) -> None:
|
) -> None:
|
||||||
assert event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
|
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||||
assert data is not None
|
assert data is not None
|
||||||
assert data["code"] == "speech-timeout"
|
assert data["code"] == "pipeline-not-found"
|
||||||
|
assert data["message"] == "Pipeline not found"
|
||||||
|
|
||||||
voice_assistant_udp_server_v2.handle_event = handle_event
|
voice_assistant_udp_server_v2.handle_event = handle_event
|
||||||
|
|
||||||
with patch(
|
|
||||||
"webrtcvad.Vad.is_speech",
|
|
||||||
new=is_speech,
|
|
||||||
):
|
|
||||||
voice_assistant_udp_server_v2.started = True
|
|
||||||
|
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes(_ONE_SECOND))
|
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
await voice_assistant_udp_server_v2.run_pipeline(
|
||||||
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
device_id="mock-device-id",
|
||||||
|
conversation_id=None,
|
||||||
|
flags=2,
|
||||||
|
pipeline_timeout=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_speech_timeout(
|
async def test_pipeline_timeout(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test when speech was detected, but the pipeline times out."""
|
"""Test that the pipeline is set to start with Wake word."""
|
||||||
|
|
||||||
def is_speech(self, chunk, sample_rate):
|
|
||||||
"""Anything non-zero is speech."""
|
|
||||||
return sum(chunk) > 255
|
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
stt_stream = kwargs["stt_stream"]
|
raise PipelineNotFound("not-found", "Pipeline not found")
|
||||||
async for _chunk in stt_stream:
|
|
||||||
# Stream will end when VAD detects end of "speech"
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def segment_audio(*args, **kwargs):
|
|
||||||
raise asyncio.TimeoutError()
|
|
||||||
async for chunk in []:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"webrtcvad.Vad.is_speech",
|
|
||||||
new=is_speech,
|
|
||||||
), patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
), patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._segment_audio",
|
|
||||||
new=segment_audio,
|
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2.started = True
|
voice_assistant_udp_server_v2.transport = Mock()
|
||||||
|
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(bytes([255] * (_ONE_SECOND * 2)))
|
def handle_event(
|
||||||
|
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
|
||||||
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_cancelled(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test when the server is stopped while waiting for speech."""
|
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||||
|
assert data is not None
|
||||||
|
assert data["code"] == "pipeline not found"
|
||||||
|
assert data["message"] == "Selected pipeline not found"
|
||||||
|
|
||||||
voice_assistant_udp_server_v2.started = True
|
voice_assistant_udp_server_v2.handle_event = handle_event
|
||||||
|
|
||||||
voice_assistant_udp_server_v2.queue.put_nowait(b"")
|
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
await voice_assistant_udp_server_v2.run_pipeline(
|
||||||
device_id="", conversation_id=None, use_vad=True, pipeline_timeout=1.0
|
device_id="mock-device-id",
|
||||||
|
conversation_id=None,
|
||||||
|
flags=2,
|
||||||
|
pipeline_timeout=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# No events should be sent if cancelled while waiting for speech
|
|
||||||
voice_assistant_udp_server_v2.handle_event.assert_not_called()
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user