mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 22:27:07 +00:00
Migrate ESPHome to assist satellite (#125383)
* Migrate ESPHome to assist satellite * Address comments
This commit is contained in:
parent
b6d45a5a07
commit
f126a6024e
504
homeassistant/components/esphome/assist_satellite.py
Normal file
504
homeassistant/components/esphome/assist_satellite.py
Normal file
@ -0,0 +1,504 @@
|
|||||||
|
"""Support for assist satellites in ESPHome."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncIterable
|
||||||
|
from functools import partial
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
from typing import Any, cast
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from aioesphomeapi import (
|
||||||
|
VoiceAssistantAudioSettings,
|
||||||
|
VoiceAssistantCommandFlag,
|
||||||
|
VoiceAssistantEventType,
|
||||||
|
VoiceAssistantFeature,
|
||||||
|
VoiceAssistantTimerEventType,
|
||||||
|
)
|
||||||
|
|
||||||
|
from homeassistant.components import assist_satellite, tts
|
||||||
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
PipelineEvent,
|
||||||
|
PipelineEventType,
|
||||||
|
PipelineStage,
|
||||||
|
)
|
||||||
|
from homeassistant.components.intent import async_register_timer_handler
|
||||||
|
from homeassistant.components.intent.timers import TimerEventType, TimerInfo
|
||||||
|
from homeassistant.components.media_player import async_process_play_media_url
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import EntityCategory, Platform
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import entity_registry as er
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .entity import EsphomeAssistEntity
|
||||||
|
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
|
||||||
|
from .enum_mapper import EsphomeEnumMapper
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
|
||||||
|
VoiceAssistantEventType, PipelineEventType
|
||||||
|
] = EsphomeEnumMapper(
|
||||||
|
{
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
|
||||||
|
EsphomeEnumMapper(
|
||||||
|
{
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
entry: ESPHomeConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up Assist satellite entity."""
|
||||||
|
entry_data = entry.runtime_data
|
||||||
|
assert entry_data.device_info is not None
|
||||||
|
if entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
entry_data.api_version
|
||||||
|
):
|
||||||
|
async_add_entities(
|
||||||
|
[
|
||||||
|
EsphomeAssistSatellite(entry, entry_data),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EsphomeAssistSatellite(
|
||||||
|
EsphomeAssistEntity, assist_satellite.AssistSatelliteEntity
|
||||||
|
):
|
||||||
|
"""Satellite running ESPHome."""
|
||||||
|
|
||||||
|
entity_description = assist_satellite.AssistSatelliteEntityDescription(
|
||||||
|
key="assist_satellite",
|
||||||
|
translation_key="assist_satellite",
|
||||||
|
entity_category=EntityCategory.CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
entry_data: RuntimeEntryData,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize satellite."""
|
||||||
|
super().__init__(entry_data)
|
||||||
|
|
||||||
|
self.config_entry = config_entry
|
||||||
|
self.entry_data = entry_data
|
||||||
|
self.cli = self.entry_data.client
|
||||||
|
|
||||||
|
self._is_running: bool = True
|
||||||
|
self._pipeline_task: asyncio.Task | None = None
|
||||||
|
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||||
|
self._tts_streaming_task: asyncio.Task | None = None
|
||||||
|
self._udp_server: VoiceAssistantUDPServer | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pipeline_entity_id(self) -> str | None:
|
||||||
|
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||||
|
assert self.entry_data.device_info is not None
|
||||||
|
ent_reg = er.async_get(self.hass)
|
||||||
|
return ent_reg.async_get_entity_id(
|
||||||
|
Platform.SELECT,
|
||||||
|
DOMAIN,
|
||||||
|
f"{self.entry_data.device_info.mac_address}-pipeline",
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vad_sensitivity_entity_id(self) -> str | None:
|
||||||
|
"""Return the entity ID of the VAD sensitivity to use for the next conversation."""
|
||||||
|
assert self.entry_data.device_info is not None
|
||||||
|
ent_reg = er.async_get(self.hass)
|
||||||
|
return ent_reg.async_get_entity_id(
|
||||||
|
Platform.SELECT,
|
||||||
|
DOMAIN,
|
||||||
|
f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> None:
|
||||||
|
"""Run when entity about to be added to hass."""
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
|
||||||
|
assert self.entry_data.device_info is not None
|
||||||
|
feature_flags = (
|
||||||
|
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
self.entry_data.api_version
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if feature_flags & VoiceAssistantFeature.API_AUDIO:
|
||||||
|
# TCP audio
|
||||||
|
self.entry_data.disconnect_callbacks.add(
|
||||||
|
self.cli.subscribe_voice_assistant(
|
||||||
|
handle_start=self.handle_pipeline_start,
|
||||||
|
handle_stop=self.handle_pipeline_stop,
|
||||||
|
handle_audio=self.handle_audio,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# UDP audio
|
||||||
|
self.entry_data.disconnect_callbacks.add(
|
||||||
|
self.cli.subscribe_voice_assistant(
|
||||||
|
handle_start=self.handle_pipeline_start,
|
||||||
|
handle_stop=self.handle_pipeline_stop,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if feature_flags & VoiceAssistantFeature.TIMERS:
|
||||||
|
# Device supports timers
|
||||||
|
assert (self.registry_entry is not None) and (
|
||||||
|
self.registry_entry.device_id is not None
|
||||||
|
)
|
||||||
|
self.entry_data.disconnect_callbacks.add(
|
||||||
|
async_register_timer_handler(
|
||||||
|
self.hass, self.registry_entry.device_id, self.handle_timer_event
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_will_remove_from_hass(self) -> None:
|
||||||
|
"""Run when entity will be removed from hass."""
|
||||||
|
await super().async_will_remove_from_hass()
|
||||||
|
|
||||||
|
self._is_running = False
|
||||||
|
self._stop_pipeline()
|
||||||
|
|
||||||
|
def on_pipeline_event(self, event: PipelineEvent) -> None:
|
||||||
|
"""Handle pipeline events."""
|
||||||
|
try:
|
||||||
|
event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
|
||||||
|
except KeyError:
|
||||||
|
_LOGGER.debug("Received unknown pipeline event type: %s", event.type)
|
||||||
|
return
|
||||||
|
|
||||||
|
data_to_send: dict[str, Any] = {}
|
||||||
|
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
|
||||||
|
self.entry_data.async_set_assist_pipeline_state(True)
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
|
||||||
|
assert event.data is not None
|
||||||
|
data_to_send = {"text": event.data["stt_output"]["text"]}
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
|
||||||
|
assert event.data is not None
|
||||||
|
data_to_send = {
|
||||||
|
"conversation_id": event.data["intent_output"]["conversation_id"] or "",
|
||||||
|
}
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
||||||
|
assert event.data is not None
|
||||||
|
data_to_send = {"text": event.data["tts_input"]}
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
||||||
|
assert event.data is not None
|
||||||
|
if tts_output := event.data["tts_output"]:
|
||||||
|
path = tts_output["url"]
|
||||||
|
url = async_process_play_media_url(self.hass, path)
|
||||||
|
data_to_send = {"url": url}
|
||||||
|
|
||||||
|
assert self.entry_data.device_info is not None
|
||||||
|
feature_flags = (
|
||||||
|
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
self.entry_data.api_version
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if feature_flags & VoiceAssistantFeature.SPEAKER:
|
||||||
|
media_id = tts_output["media_id"]
|
||||||
|
self._tts_streaming_task = (
|
||||||
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._stream_tts_audio(media_id),
|
||||||
|
"esphome_voice_assistant_tts",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||||
|
assert event.data is not None
|
||||||
|
if not event.data["wake_word_output"]:
|
||||||
|
event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
|
||||||
|
data_to_send = {
|
||||||
|
"code": "no_wake_word",
|
||||||
|
"message": "No wake word detected",
|
||||||
|
}
|
||||||
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
||||||
|
assert event.data is not None
|
||||||
|
data_to_send = {
|
||||||
|
"code": event.data["code"],
|
||||||
|
"message": event.data["message"],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.cli.send_voice_assistant_event(event_type, data_to_send)
|
||||||
|
|
||||||
|
async def handle_pipeline_start(
|
||||||
|
self,
|
||||||
|
conversation_id: str,
|
||||||
|
flags: int,
|
||||||
|
audio_settings: VoiceAssistantAudioSettings,
|
||||||
|
wake_word_phrase: str | None,
|
||||||
|
) -> int | None:
|
||||||
|
"""Handle pipeline run request."""
|
||||||
|
# Clear audio queue
|
||||||
|
while not self._audio_queue.empty():
|
||||||
|
await self._audio_queue.get()
|
||||||
|
|
||||||
|
if self._tts_streaming_task is not None:
|
||||||
|
# Cancel current TTS response
|
||||||
|
self._tts_streaming_task.cancel()
|
||||||
|
self._tts_streaming_task = None
|
||||||
|
|
||||||
|
# API or UDP output audio
|
||||||
|
port: int = 0
|
||||||
|
assert self.entry_data.device_info is not None
|
||||||
|
feature_flags = (
|
||||||
|
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
self.entry_data.api_version
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if (feature_flags & VoiceAssistantFeature.SPEAKER) and not (
|
||||||
|
feature_flags & VoiceAssistantFeature.API_AUDIO
|
||||||
|
):
|
||||||
|
port = await self._start_udp_server()
|
||||||
|
_LOGGER.debug("Started UDP server on port %s", port)
|
||||||
|
|
||||||
|
# Device triggered pipeline (wake word, etc.)
|
||||||
|
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||||
|
start_stage = PipelineStage.WAKE_WORD
|
||||||
|
else:
|
||||||
|
start_stage = PipelineStage.STT
|
||||||
|
|
||||||
|
end_stage = PipelineStage.TTS
|
||||||
|
|
||||||
|
# Run the pipeline
|
||||||
|
_LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage)
|
||||||
|
self.entry_data.async_set_assist_pipeline_state(True)
|
||||||
|
self._pipeline_task = self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self.async_accept_pipeline_from_satellite(
|
||||||
|
audio_stream=self._wrap_audio_stream(),
|
||||||
|
start_stage=start_stage,
|
||||||
|
end_stage=end_stage,
|
||||||
|
wake_word_phrase=wake_word_phrase,
|
||||||
|
),
|
||||||
|
"esphome_assist_satellite_pipeline",
|
||||||
|
)
|
||||||
|
self._pipeline_task.add_done_callback(
|
||||||
|
lambda _future: self.handle_pipeline_finished()
|
||||||
|
)
|
||||||
|
|
||||||
|
return port
|
||||||
|
|
||||||
|
async def handle_audio(self, data: bytes) -> None:
|
||||||
|
"""Handle incoming audio chunk from API."""
|
||||||
|
self._audio_queue.put_nowait(data)
|
||||||
|
|
||||||
|
async def handle_pipeline_stop(self) -> None:
|
||||||
|
"""Handle request for pipeline to stop."""
|
||||||
|
self._stop_pipeline()
|
||||||
|
|
||||||
|
def handle_pipeline_finished(self) -> None:
|
||||||
|
"""Handle when pipeline has finished running."""
|
||||||
|
self.entry_data.async_set_assist_pipeline_state(False)
|
||||||
|
self._stop_udp_server()
|
||||||
|
_LOGGER.debug("Pipeline finished")
|
||||||
|
|
||||||
|
def handle_timer_event(
|
||||||
|
self, event_type: TimerEventType, timer_info: TimerInfo
|
||||||
|
) -> None:
|
||||||
|
"""Handle timer events."""
|
||||||
|
try:
|
||||||
|
native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
|
||||||
|
except KeyError:
|
||||||
|
_LOGGER.debug("Received unknown timer event type: %s", event_type)
|
||||||
|
return
|
||||||
|
|
||||||
|
self.cli.send_voice_assistant_timer_event(
|
||||||
|
native_event_type,
|
||||||
|
timer_info.id,
|
||||||
|
timer_info.name,
|
||||||
|
timer_info.created_seconds,
|
||||||
|
timer_info.seconds_left,
|
||||||
|
timer_info.is_active,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _stream_tts_audio(
|
||||||
|
self,
|
||||||
|
media_id: str,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
sample_width: int = 2,
|
||||||
|
sample_channels: int = 1,
|
||||||
|
samples_per_chunk: int = 512,
|
||||||
|
) -> None:
|
||||||
|
"""Stream TTS audio chunks to device via API or UDP."""
|
||||||
|
self.cli.send_voice_assistant_event(
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not self._is_running:
|
||||||
|
return
|
||||||
|
|
||||||
|
extension, data = await tts.async_get_media_source_audio(
|
||||||
|
self.hass,
|
||||||
|
media_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if extension != "wav":
|
||||||
|
_LOGGER.error("Only WAV audio can be streamed, got %s", extension)
|
||||||
|
return
|
||||||
|
|
||||||
|
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||||
|
if (
|
||||||
|
(wav_file.getframerate() != sample_rate)
|
||||||
|
or (wav_file.getsampwidth() != sample_width)
|
||||||
|
or (wav_file.getnchannels() != sample_channels)
|
||||||
|
):
|
||||||
|
_LOGGER.error("Can only stream 16Khz 16-bit mono WAV")
|
||||||
|
return
|
||||||
|
|
||||||
|
_LOGGER.debug("Streaming %s audio samples", wav_file.getnframes())
|
||||||
|
|
||||||
|
while self._is_running:
|
||||||
|
chunk = wav_file.readframes(samples_per_chunk)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
|
||||||
|
if self._udp_server is not None:
|
||||||
|
self._udp_server.send_audio_bytes(chunk)
|
||||||
|
else:
|
||||||
|
self.cli.send_voice_assistant_audio(chunk)
|
||||||
|
|
||||||
|
# Wait for 90% of the duration of the audio that was
|
||||||
|
# sent for it to be played. This will overrun the
|
||||||
|
# device's buffer for very long audio, so using a media
|
||||||
|
# player is preferred.
|
||||||
|
samples_in_chunk = len(chunk) // (sample_width * sample_channels)
|
||||||
|
seconds_in_chunk = samples_in_chunk / sample_rate
|
||||||
|
await asyncio.sleep(seconds_in_chunk * 0.9)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return # Don't trigger state change
|
||||||
|
finally:
|
||||||
|
self.cli.send_voice_assistant_event(
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
|
||||||
|
)
|
||||||
|
|
||||||
|
# State change
|
||||||
|
self.tts_response_finished()
|
||||||
|
|
||||||
|
async def _wrap_audio_stream(self) -> AsyncIterable[bytes]:
|
||||||
|
"""Yield audio chunks from the queue until None."""
|
||||||
|
while True:
|
||||||
|
chunk = await self._audio_queue.get()
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
def _stop_pipeline(self) -> None:
|
||||||
|
"""Request pipeline to be stopped."""
|
||||||
|
self._audio_queue.put_nowait(None)
|
||||||
|
_LOGGER.debug("Requested pipeline stop")
|
||||||
|
|
||||||
|
async def _start_udp_server(self) -> int:
|
||||||
|
"""Start a UDP server on a random free port."""
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
sock.setblocking(False)
|
||||||
|
sock.bind(("", 0)) # random free port
|
||||||
|
|
||||||
|
(
|
||||||
|
_transport,
|
||||||
|
protocol,
|
||||||
|
) = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||||
|
partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(protocol, VoiceAssistantUDPServer)
|
||||||
|
self._udp_server = protocol
|
||||||
|
|
||||||
|
# Return port
|
||||||
|
return cast(int, sock.getsockname()[1])
|
||||||
|
|
||||||
|
def _stop_udp_server(self) -> None:
|
||||||
|
"""Stop the UDP server if it's running."""
|
||||||
|
if self._udp_server is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._udp_server.close()
|
||||||
|
finally:
|
||||||
|
self._udp_server = None
|
||||||
|
|
||||||
|
_LOGGER.debug("Stopped UDP server")
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
|
"""Receive UDP packets and forward them to the audio queue."""
|
||||||
|
|
||||||
|
transport: asyncio.DatagramTransport | None = None
|
||||||
|
remote_addr: tuple[str, int] | None = None
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any
|
||||||
|
) -> None:
|
||||||
|
"""Initialize protocol."""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._audio_queue = audio_queue
|
||||||
|
|
||||||
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
|
"""Store transport for later use."""
|
||||||
|
self.transport = cast(asyncio.DatagramTransport, transport)
|
||||||
|
|
||||||
|
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||||
|
"""Handle incoming UDP packet."""
|
||||||
|
if self.remote_addr is None:
|
||||||
|
self.remote_addr = addr
|
||||||
|
|
||||||
|
self._audio_queue.put_nowait(data)
|
||||||
|
|
||||||
|
def error_received(self, exc: Exception) -> None:
|
||||||
|
"""Handle when a send or receive operation raises an OSError.
|
||||||
|
|
||||||
|
(Other than BlockingIOError or InterruptedError.)
|
||||||
|
"""
|
||||||
|
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
|
||||||
|
|
||||||
|
# Stop pipeline
|
||||||
|
self._audio_queue.put_nowait(None)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the receiver."""
|
||||||
|
if self.transport is not None:
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
self.remote_addr = None
|
||||||
|
|
||||||
|
def send_audio_bytes(self, data: bytes) -> None:
|
||||||
|
"""Send bytes to the device via UDP."""
|
||||||
|
if self.transport is None:
|
||||||
|
_LOGGER.error("No transport to send audio to")
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.remote_addr is None:
|
||||||
|
_LOGGER.error("No address to send audio to")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.transport.sendto(data, self.remote_addr)
|
@ -20,19 +20,17 @@ from aioesphomeapi import (
|
|||||||
RequiresEncryptionAPIError,
|
RequiresEncryptionAPIError,
|
||||||
UserService,
|
UserService,
|
||||||
UserServiceArgType,
|
UserServiceArgType,
|
||||||
VoiceAssistantAudioSettings,
|
|
||||||
VoiceAssistantFeature,
|
|
||||||
)
|
)
|
||||||
from awesomeversion import AwesomeVersion
|
from awesomeversion import AwesomeVersion
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components import tag, zeroconf
|
from homeassistant.components import tag, zeroconf
|
||||||
from homeassistant.components.intent import async_register_timer_handler
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
ATTR_DEVICE_ID,
|
ATTR_DEVICE_ID,
|
||||||
CONF_MODE,
|
CONF_MODE,
|
||||||
EVENT_HOMEASSISTANT_CLOSE,
|
EVENT_HOMEASSISTANT_CLOSE,
|
||||||
EVENT_LOGGING_CHANGED,
|
EVENT_LOGGING_CHANGED,
|
||||||
|
Platform,
|
||||||
)
|
)
|
||||||
from homeassistant.core import (
|
from homeassistant.core import (
|
||||||
Event,
|
Event,
|
||||||
@ -73,12 +71,6 @@ from .domain_data import DomainData
|
|||||||
|
|
||||||
# Import config flow so that it's added to the registry
|
# Import config flow so that it's added to the registry
|
||||||
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
|
from .entry_data import ESPHomeConfigEntry, RuntimeEntryData
|
||||||
from .voice_assistant import (
|
|
||||||
VoiceAssistantAPIPipeline,
|
|
||||||
VoiceAssistantPipeline,
|
|
||||||
VoiceAssistantUDPPipeline,
|
|
||||||
handle_timer_event,
|
|
||||||
)
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -149,7 +141,6 @@ class ESPHomeManager:
|
|||||||
"cli",
|
"cli",
|
||||||
"device_id",
|
"device_id",
|
||||||
"domain_data",
|
"domain_data",
|
||||||
"voice_assistant_pipeline",
|
|
||||||
"reconnect_logic",
|
"reconnect_logic",
|
||||||
"zeroconf_instance",
|
"zeroconf_instance",
|
||||||
"entry_data",
|
"entry_data",
|
||||||
@ -173,7 +164,6 @@ class ESPHomeManager:
|
|||||||
self.cli = cli
|
self.cli = cli
|
||||||
self.device_id: str | None = None
|
self.device_id: str | None = None
|
||||||
self.domain_data = domain_data
|
self.domain_data = domain_data
|
||||||
self.voice_assistant_pipeline: VoiceAssistantPipeline | None = None
|
|
||||||
self.reconnect_logic: ReconnectLogic | None = None
|
self.reconnect_logic: ReconnectLogic | None = None
|
||||||
self.zeroconf_instance = zeroconf_instance
|
self.zeroconf_instance = zeroconf_instance
|
||||||
self.entry_data = entry.runtime_data
|
self.entry_data = entry.runtime_data
|
||||||
@ -338,77 +328,6 @@ class ESPHomeManager:
|
|||||||
entity_id, attribute, self.hass.states.get(entity_id)
|
entity_id, attribute, self.hass.states.get(entity_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_pipeline_finished(self) -> None:
|
|
||||||
self.entry_data.async_set_assist_pipeline_state(False)
|
|
||||||
|
|
||||||
if self.voice_assistant_pipeline is not None:
|
|
||||||
if isinstance(self.voice_assistant_pipeline, VoiceAssistantUDPPipeline):
|
|
||||||
self.voice_assistant_pipeline.close()
|
|
||||||
self.voice_assistant_pipeline = None
|
|
||||||
|
|
||||||
async def _handle_pipeline_start(
|
|
||||||
self,
|
|
||||||
conversation_id: str,
|
|
||||||
flags: int,
|
|
||||||
audio_settings: VoiceAssistantAudioSettings,
|
|
||||||
wake_word_phrase: str | None,
|
|
||||||
) -> int | None:
|
|
||||||
"""Start a voice assistant pipeline."""
|
|
||||||
if self.voice_assistant_pipeline is not None:
|
|
||||||
_LOGGER.warning("Previous Voice assistant pipeline was not stopped")
|
|
||||||
self.voice_assistant_pipeline.stop()
|
|
||||||
self.voice_assistant_pipeline = None
|
|
||||||
|
|
||||||
hass = self.hass
|
|
||||||
assert self.entry_data.device_info is not None
|
|
||||||
if (
|
|
||||||
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
|
||||||
self.entry_data.api_version
|
|
||||||
)
|
|
||||||
& VoiceAssistantFeature.API_AUDIO
|
|
||||||
):
|
|
||||||
self.voice_assistant_pipeline = VoiceAssistantAPIPipeline(
|
|
||||||
hass,
|
|
||||||
self.entry_data,
|
|
||||||
self.cli.send_voice_assistant_event,
|
|
||||||
self._handle_pipeline_finished,
|
|
||||||
self.cli,
|
|
||||||
)
|
|
||||||
port = 0
|
|
||||||
else:
|
|
||||||
self.voice_assistant_pipeline = VoiceAssistantUDPPipeline(
|
|
||||||
hass,
|
|
||||||
self.entry_data,
|
|
||||||
self.cli.send_voice_assistant_event,
|
|
||||||
self._handle_pipeline_finished,
|
|
||||||
)
|
|
||||||
port = await self.voice_assistant_pipeline.start_server()
|
|
||||||
|
|
||||||
assert self.device_id is not None, "Device ID must be set"
|
|
||||||
hass.async_create_background_task(
|
|
||||||
self.voice_assistant_pipeline.run_pipeline(
|
|
||||||
device_id=self.device_id,
|
|
||||||
conversation_id=conversation_id or None,
|
|
||||||
flags=flags,
|
|
||||||
audio_settings=audio_settings,
|
|
||||||
wake_word_phrase=wake_word_phrase,
|
|
||||||
),
|
|
||||||
"esphome.voice_assistant_pipeline.run_pipeline",
|
|
||||||
)
|
|
||||||
|
|
||||||
return port
|
|
||||||
|
|
||||||
async def _handle_pipeline_stop(self) -> None:
|
|
||||||
"""Stop a voice assistant pipeline."""
|
|
||||||
if self.voice_assistant_pipeline is not None:
|
|
||||||
self.voice_assistant_pipeline.stop()
|
|
||||||
|
|
||||||
async def _handle_audio(self, data: bytes) -> None:
|
|
||||||
if self.voice_assistant_pipeline is None:
|
|
||||||
return
|
|
||||||
assert isinstance(self.voice_assistant_pipeline, VoiceAssistantAPIPipeline)
|
|
||||||
self.voice_assistant_pipeline.receive_audio_bytes(data)
|
|
||||||
|
|
||||||
async def on_connect(self) -> None:
|
async def on_connect(self) -> None:
|
||||||
"""Subscribe to states and list entities on successful API login."""
|
"""Subscribe to states and list entities on successful API login."""
|
||||||
try:
|
try:
|
||||||
@ -509,29 +428,14 @@ class ESPHomeManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
flags = device_info.voice_assistant_feature_flags_compat(api_version)
|
if device_info.voice_assistant_feature_flags_compat(api_version) and (
|
||||||
if flags:
|
Platform.ASSIST_SATELLITE not in entry_data.loaded_platforms
|
||||||
if flags & VoiceAssistantFeature.API_AUDIO:
|
):
|
||||||
entry_data.disconnect_callbacks.add(
|
# Create assist satellite entity
|
||||||
cli.subscribe_voice_assistant(
|
await self.hass.config_entries.async_forward_entry_setups(
|
||||||
handle_start=self._handle_pipeline_start,
|
self.entry, [Platform.ASSIST_SATELLITE]
|
||||||
handle_stop=self._handle_pipeline_stop,
|
|
||||||
handle_audio=self._handle_audio,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
entry_data.disconnect_callbacks.add(
|
|
||||||
cli.subscribe_voice_assistant(
|
|
||||||
handle_start=self._handle_pipeline_start,
|
|
||||||
handle_stop=self._handle_pipeline_stop,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if flags & VoiceAssistantFeature.TIMERS:
|
|
||||||
entry_data.disconnect_callbacks.add(
|
|
||||||
async_register_timer_handler(
|
|
||||||
hass, self.device_id, partial(handle_timer_event, cli)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
entry_data.loaded_platforms.add(Platform.ASSIST_SATELLITE)
|
||||||
|
|
||||||
cli.subscribe_states(entry_data.async_update_state)
|
cli.subscribe_states(entry_data.async_update_state)
|
||||||
cli.subscribe_service_calls(self.async_on_service_call)
|
cli.subscribe_service_calls(self.async_on_service_call)
|
||||||
|
@ -1,479 +0,0 @@
|
|||||||
"""ESPHome voice assistant support."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from collections.abc import AsyncIterable, Callable
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
import socket
|
|
||||||
from typing import cast
|
|
||||||
import wave
|
|
||||||
|
|
||||||
from aioesphomeapi import (
|
|
||||||
APIClient,
|
|
||||||
VoiceAssistantAudioSettings,
|
|
||||||
VoiceAssistantCommandFlag,
|
|
||||||
VoiceAssistantEventType,
|
|
||||||
VoiceAssistantFeature,
|
|
||||||
VoiceAssistantTimerEventType,
|
|
||||||
)
|
|
||||||
|
|
||||||
from homeassistant.components import stt, tts
|
|
||||||
from homeassistant.components.assist_pipeline import (
|
|
||||||
AudioSettings,
|
|
||||||
PipelineEvent,
|
|
||||||
PipelineEventType,
|
|
||||||
PipelineNotFound,
|
|
||||||
PipelineStage,
|
|
||||||
WakeWordSettings,
|
|
||||||
async_pipeline_from_audio_stream,
|
|
||||||
select as pipeline_select,
|
|
||||||
)
|
|
||||||
from homeassistant.components.assist_pipeline.error import (
|
|
||||||
WakeWordDetectionAborted,
|
|
||||||
WakeWordDetectionError,
|
|
||||||
)
|
|
||||||
from homeassistant.components.assist_pipeline.vad import VadSensitivity
|
|
||||||
from homeassistant.components.intent.timers import TimerEventType, TimerInfo
|
|
||||||
from homeassistant.components.media_player import async_process_play_media_url
|
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
|
||||||
|
|
||||||
from .const import DOMAIN
|
|
||||||
from .entry_data import RuntimeEntryData
|
|
||||||
from .enum_mapper import EsphomeEnumMapper
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
UDP_PORT = 0 # Set to 0 to let the OS pick a free random port
|
|
||||||
UDP_MAX_PACKET_SIZE = 1024
|
|
||||||
|
|
||||||
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
|
|
||||||
VoiceAssistantEventType, PipelineEventType
|
|
||||||
] = EsphomeEnumMapper(
|
|
||||||
{
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START,
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = (
|
|
||||||
EsphomeEnumMapper(
|
|
||||||
{
|
|
||||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED,
|
|
||||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED,
|
|
||||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED,
|
|
||||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class VoiceAssistantPipeline:
|
|
||||||
"""Base abstract pipeline class."""
|
|
||||||
|
|
||||||
started = False
|
|
||||||
stop_requested = False
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
entry_data: RuntimeEntryData,
|
|
||||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
|
||||||
handle_finished: Callable[[], None],
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the pipeline."""
|
|
||||||
self.context = Context()
|
|
||||||
self.hass = hass
|
|
||||||
self.entry_data = entry_data
|
|
||||||
assert entry_data.device_info is not None
|
|
||||||
self.device_info = entry_data.device_info
|
|
||||||
|
|
||||||
self.queue: asyncio.Queue[bytes] = asyncio.Queue()
|
|
||||||
self.handle_event = handle_event
|
|
||||||
self.handle_finished = handle_finished
|
|
||||||
self._tts_done = asyncio.Event()
|
|
||||||
self._tts_task: asyncio.Task | None = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_running(self) -> bool:
|
|
||||||
"""True if the pipeline is started and hasn't been asked to stop."""
|
|
||||||
return self.started and (not self.stop_requested)
|
|
||||||
|
|
||||||
async def _iterate_packets(self) -> AsyncIterable[bytes]:
|
|
||||||
"""Iterate over incoming packets."""
|
|
||||||
while data := await self.queue.get():
|
|
||||||
if not self.is_running:
|
|
||||||
break
|
|
||||||
|
|
||||||
yield data
|
|
||||||
|
|
||||||
def _event_callback(self, event: PipelineEvent) -> None:
|
|
||||||
"""Handle pipeline events."""
|
|
||||||
|
|
||||||
try:
|
|
||||||
event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
|
|
||||||
except KeyError:
|
|
||||||
_LOGGER.debug("Received unknown pipeline event type: %s", event.type)
|
|
||||||
return
|
|
||||||
|
|
||||||
data_to_send = None
|
|
||||||
error = False
|
|
||||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START:
|
|
||||||
self.entry_data.async_set_assist_pipeline_state(True)
|
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
|
|
||||||
assert event.data is not None
|
|
||||||
data_to_send = {"text": event.data["stt_output"]["text"]}
|
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
|
|
||||||
assert event.data is not None
|
|
||||||
data_to_send = {
|
|
||||||
"conversation_id": event.data["intent_output"]["conversation_id"] or "",
|
|
||||||
}
|
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
|
||||||
assert event.data is not None
|
|
||||||
data_to_send = {"text": event.data["tts_input"]}
|
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
|
||||||
assert event.data is not None
|
|
||||||
tts_output = event.data["tts_output"]
|
|
||||||
if tts_output:
|
|
||||||
path = tts_output["url"]
|
|
||||||
url = async_process_play_media_url(self.hass, path)
|
|
||||||
data_to_send = {"url": url}
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.device_info.voice_assistant_feature_flags_compat(
|
|
||||||
self.entry_data.api_version
|
|
||||||
)
|
|
||||||
& VoiceAssistantFeature.SPEAKER
|
|
||||||
):
|
|
||||||
media_id = tts_output["media_id"]
|
|
||||||
self._tts_task = self.hass.async_create_background_task(
|
|
||||||
self._send_tts(media_id), "esphome_voice_assistant_tts"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._tts_done.set()
|
|
||||||
else:
|
|
||||||
# Empty TTS response
|
|
||||||
data_to_send = {}
|
|
||||||
self._tts_done.set()
|
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
|
||||||
assert event.data is not None
|
|
||||||
if not event.data["wake_word_output"]:
|
|
||||||
event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR
|
|
||||||
data_to_send = {
|
|
||||||
"code": "no_wake_word",
|
|
||||||
"message": "No wake word detected",
|
|
||||||
}
|
|
||||||
error = True
|
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
|
||||||
assert event.data is not None
|
|
||||||
data_to_send = {
|
|
||||||
"code": event.data["code"],
|
|
||||||
"message": event.data["message"],
|
|
||||||
}
|
|
||||||
error = True
|
|
||||||
|
|
||||||
self.handle_event(event_type, data_to_send)
|
|
||||||
if error:
|
|
||||||
self._tts_done.set()
|
|
||||||
self.handle_finished()
|
|
||||||
|
|
||||||
async def run_pipeline(
|
|
||||||
self,
|
|
||||||
device_id: str,
|
|
||||||
conversation_id: str | None,
|
|
||||||
flags: int = 0,
|
|
||||||
audio_settings: VoiceAssistantAudioSettings | None = None,
|
|
||||||
wake_word_phrase: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Run the Voice Assistant pipeline."""
|
|
||||||
if audio_settings is None or audio_settings.volume_multiplier == 0:
|
|
||||||
audio_settings = VoiceAssistantAudioSettings()
|
|
||||||
|
|
||||||
if (
|
|
||||||
self.device_info.voice_assistant_feature_flags_compat(
|
|
||||||
self.entry_data.api_version
|
|
||||||
)
|
|
||||||
& VoiceAssistantFeature.SPEAKER
|
|
||||||
):
|
|
||||||
tts_audio_output = "wav"
|
|
||||||
else:
|
|
||||||
tts_audio_output = "mp3"
|
|
||||||
|
|
||||||
_LOGGER.debug("Starting pipeline")
|
|
||||||
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
|
||||||
start_stage = PipelineStage.WAKE_WORD
|
|
||||||
else:
|
|
||||||
start_stage = PipelineStage.STT
|
|
||||||
try:
|
|
||||||
await async_pipeline_from_audio_stream(
|
|
||||||
self.hass,
|
|
||||||
context=self.context,
|
|
||||||
event_callback=self._event_callback,
|
|
||||||
stt_metadata=stt.SpeechMetadata(
|
|
||||||
language="", # set in async_pipeline_from_audio_stream
|
|
||||||
format=stt.AudioFormats.WAV,
|
|
||||||
codec=stt.AudioCodecs.PCM,
|
|
||||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
|
||||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
|
||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
|
||||||
),
|
|
||||||
stt_stream=self._iterate_packets(),
|
|
||||||
pipeline_id=pipeline_select.get_chosen_pipeline(
|
|
||||||
self.hass, DOMAIN, self.device_info.mac_address
|
|
||||||
),
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
device_id=device_id,
|
|
||||||
tts_audio_output=tts_audio_output,
|
|
||||||
start_stage=start_stage,
|
|
||||||
wake_word_settings=WakeWordSettings(timeout=5),
|
|
||||||
wake_word_phrase=wake_word_phrase,
|
|
||||||
audio_settings=AudioSettings(
|
|
||||||
noise_suppression_level=audio_settings.noise_suppression_level,
|
|
||||||
auto_gain_dbfs=audio_settings.auto_gain,
|
|
||||||
volume_multiplier=audio_settings.volume_multiplier,
|
|
||||||
is_vad_enabled=bool(flags & VoiceAssistantCommandFlag.USE_VAD),
|
|
||||||
silence_seconds=VadSensitivity.to_seconds(
|
|
||||||
pipeline_select.get_vad_sensitivity(
|
|
||||||
self.hass, DOMAIN, self.device_info.mac_address
|
|
||||||
)
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Block until TTS is done sending
|
|
||||||
await self._tts_done.wait()
|
|
||||||
|
|
||||||
_LOGGER.debug("Pipeline finished")
|
|
||||||
except PipelineNotFound as e:
|
|
||||||
self.handle_event(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
|
||||||
{
|
|
||||||
"code": e.code,
|
|
||||||
"message": e.message,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
_LOGGER.warning("Pipeline not found")
|
|
||||||
except WakeWordDetectionAborted:
|
|
||||||
pass # Wake word detection was aborted and `handle_finished` is enough.
|
|
||||||
except WakeWordDetectionError as e:
|
|
||||||
self.handle_event(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
|
||||||
{
|
|
||||||
"code": e.code,
|
|
||||||
"message": e.message,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
self.handle_finished()
|
|
||||||
|
|
||||||
async def _send_tts(self, media_id: str) -> None:
|
|
||||||
"""Send TTS audio to device via UDP."""
|
|
||||||
# Always send stream start/end events
|
|
||||||
self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not self.is_running:
|
|
||||||
return
|
|
||||||
|
|
||||||
extension, data = await tts.async_get_media_source_audio(
|
|
||||||
self.hass,
|
|
||||||
media_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if extension != "wav":
|
|
||||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
|
||||||
|
|
||||||
with io.BytesIO(data) as wav_io:
|
|
||||||
with wave.open(wav_io, "rb") as wav_file:
|
|
||||||
sample_rate = wav_file.getframerate()
|
|
||||||
sample_width = wav_file.getsampwidth()
|
|
||||||
sample_channels = wav_file.getnchannels()
|
|
||||||
|
|
||||||
if (
|
|
||||||
(sample_rate != 16000)
|
|
||||||
or (sample_width != 2)
|
|
||||||
or (sample_channels != 1)
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Expected rate/width/channels as 16000/2/1,"
|
|
||||||
" got {sample_rate}/{sample_width}/{sample_channels}}"
|
|
||||||
)
|
|
||||||
|
|
||||||
audio_bytes = wav_file.readframes(wav_file.getnframes())
|
|
||||||
|
|
||||||
audio_bytes_size = len(audio_bytes)
|
|
||||||
|
|
||||||
_LOGGER.debug("Sending %d bytes of audio", audio_bytes_size)
|
|
||||||
|
|
||||||
bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
|
|
||||||
sample_offset = 0
|
|
||||||
samples_left = audio_bytes_size // bytes_per_sample
|
|
||||||
|
|
||||||
while (samples_left > 0) and self.is_running:
|
|
||||||
bytes_offset = sample_offset * bytes_per_sample
|
|
||||||
chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024]
|
|
||||||
samples_in_chunk = len(chunk) // bytes_per_sample
|
|
||||||
samples_left -= samples_in_chunk
|
|
||||||
|
|
||||||
self.send_audio_bytes(chunk)
|
|
||||||
await asyncio.sleep(
|
|
||||||
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
|
|
||||||
)
|
|
||||||
|
|
||||||
sample_offset += samples_in_chunk
|
|
||||||
finally:
|
|
||||||
self.handle_event(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
|
|
||||||
)
|
|
||||||
self._tts_task = None
|
|
||||||
self._tts_done.set()
|
|
||||||
|
|
||||||
def send_audio_bytes(self, data: bytes) -> None:
|
|
||||||
"""Send bytes to the device."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the pipeline."""
|
|
||||||
self.queue.put_nowait(b"")
|
|
||||||
|
|
||||||
|
|
||||||
class VoiceAssistantUDPPipeline(asyncio.DatagramProtocol, VoiceAssistantPipeline):
|
|
||||||
"""Receive UDP packets and forward them to the voice assistant."""
|
|
||||||
|
|
||||||
transport: asyncio.DatagramTransport | None = None
|
|
||||||
remote_addr: tuple[str, int] | None = None
|
|
||||||
|
|
||||||
async def start_server(self) -> int:
|
|
||||||
"""Start accepting connections."""
|
|
||||||
|
|
||||||
def accept_connection() -> VoiceAssistantUDPPipeline:
|
|
||||||
"""Accept connection."""
|
|
||||||
if self.started:
|
|
||||||
raise RuntimeError("Can only start once")
|
|
||||||
if self.stop_requested:
|
|
||||||
raise RuntimeError("No longer accepting connections")
|
|
||||||
|
|
||||||
self.started = True
|
|
||||||
return self
|
|
||||||
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
sock.setblocking(False)
|
|
||||||
|
|
||||||
sock.bind(("", UDP_PORT))
|
|
||||||
|
|
||||||
await asyncio.get_running_loop().create_datagram_endpoint(
|
|
||||||
accept_connection, sock=sock
|
|
||||||
)
|
|
||||||
|
|
||||||
return cast(int, sock.getsockname()[1])
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
||||||
"""Store transport for later use."""
|
|
||||||
self.transport = cast(asyncio.DatagramTransport, transport)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
|
||||||
"""Handle incoming UDP packet."""
|
|
||||||
if not self.is_running:
|
|
||||||
return
|
|
||||||
if self.remote_addr is None:
|
|
||||||
self.remote_addr = addr
|
|
||||||
self.queue.put_nowait(data)
|
|
||||||
|
|
||||||
def error_received(self, exc: Exception) -> None:
|
|
||||||
"""Handle when a send or receive operation raises an OSError.
|
|
||||||
|
|
||||||
(Other than BlockingIOError or InterruptedError.)
|
|
||||||
"""
|
|
||||||
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
|
|
||||||
self.handle_finished()
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the receiver."""
|
|
||||||
super().stop()
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
"""Close the receiver."""
|
|
||||||
self.started = False
|
|
||||||
self.stop_requested = True
|
|
||||||
|
|
||||||
if self.transport is not None:
|
|
||||||
self.transport.close()
|
|
||||||
|
|
||||||
def send_audio_bytes(self, data: bytes) -> None:
|
|
||||||
"""Send bytes to the device via UDP."""
|
|
||||||
if self.transport is None:
|
|
||||||
_LOGGER.error("No transport to send audio to")
|
|
||||||
return
|
|
||||||
self.transport.sendto(data, self.remote_addr)
|
|
||||||
|
|
||||||
|
|
||||||
class VoiceAssistantAPIPipeline(VoiceAssistantPipeline):
|
|
||||||
"""Send audio to the voice assistant via the API."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hass: HomeAssistant,
|
|
||||||
entry_data: RuntimeEntryData,
|
|
||||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
|
||||||
handle_finished: Callable[[], None],
|
|
||||||
api_client: APIClient,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the pipeline."""
|
|
||||||
super().__init__(hass, entry_data, handle_event, handle_finished)
|
|
||||||
self.api_client = api_client
|
|
||||||
self.started = True
|
|
||||||
|
|
||||||
def send_audio_bytes(self, data: bytes) -> None:
|
|
||||||
"""Send bytes to the device via the API."""
|
|
||||||
self.api_client.send_voice_assistant_audio(data)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def receive_audio_bytes(self, data: bytes) -> None:
|
|
||||||
"""Receive audio bytes from the device."""
|
|
||||||
if not self.is_running:
|
|
||||||
return
|
|
||||||
self.queue.put_nowait(data)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the pipeline."""
|
|
||||||
super().stop()
|
|
||||||
|
|
||||||
self.started = False
|
|
||||||
self.stop_requested = True
|
|
||||||
|
|
||||||
|
|
||||||
def handle_timer_event(
|
|
||||||
api_client: APIClient, event_type: TimerEventType, timer_info: TimerInfo
|
|
||||||
) -> None:
|
|
||||||
"""Handle timer events."""
|
|
||||||
try:
|
|
||||||
native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type)
|
|
||||||
except KeyError:
|
|
||||||
_LOGGER.debug("Received unknown timer event type: %s", event_type)
|
|
||||||
return
|
|
||||||
|
|
||||||
api_client.send_voice_assistant_timer_event(
|
|
||||||
native_event_type,
|
|
||||||
timer_info.id,
|
|
||||||
timer_info.name,
|
|
||||||
timer_info.created_seconds,
|
|
||||||
timer_info.seconds_left,
|
|
||||||
timer_info.is_active,
|
|
||||||
)
|
|
@ -20,7 +20,6 @@ from aioesphomeapi import (
|
|||||||
ReconnectLogic,
|
ReconnectLogic,
|
||||||
UserService,
|
UserService,
|
||||||
VoiceAssistantAudioSettings,
|
VoiceAssistantAudioSettings,
|
||||||
VoiceAssistantEventType,
|
|
||||||
VoiceAssistantFeature,
|
VoiceAssistantFeature,
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
@ -34,11 +33,6 @@ from homeassistant.components.esphome.const import (
|
|||||||
DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS,
|
DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
from homeassistant.components.esphome.entry_data import RuntimeEntryData
|
|
||||||
from homeassistant.components.esphome.voice_assistant import (
|
|
||||||
VoiceAssistantAPIPipeline,
|
|
||||||
VoiceAssistantUDPPipeline,
|
|
||||||
)
|
|
||||||
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT
|
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
@ -625,57 +619,3 @@ async def mock_esphome_device(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return _mock_device
|
return _mock_device
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_voice_assistant_api_pipeline() -> VoiceAssistantAPIPipeline:
|
|
||||||
"""Return the API Pipeline factory."""
|
|
||||||
mock_pipeline = Mock(spec=VoiceAssistantAPIPipeline)
|
|
||||||
|
|
||||||
def mock_constructor(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
entry_data: RuntimeEntryData,
|
|
||||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
|
||||||
handle_finished: Callable[[], None],
|
|
||||||
api_client: APIClient,
|
|
||||||
):
|
|
||||||
"""Fake the constructor."""
|
|
||||||
mock_pipeline.hass = hass
|
|
||||||
mock_pipeline.entry_data = entry_data
|
|
||||||
mock_pipeline.handle_event = handle_event
|
|
||||||
mock_pipeline.handle_finished = handle_finished
|
|
||||||
mock_pipeline.api_client = api_client
|
|
||||||
return mock_pipeline
|
|
||||||
|
|
||||||
mock_pipeline.side_effect = mock_constructor
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantAPIPipeline",
|
|
||||||
new=mock_pipeline,
|
|
||||||
):
|
|
||||||
yield mock_pipeline
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_voice_assistant_udp_pipeline() -> VoiceAssistantUDPPipeline:
|
|
||||||
"""Return the API Pipeline factory."""
|
|
||||||
mock_pipeline = Mock(spec=VoiceAssistantUDPPipeline)
|
|
||||||
|
|
||||||
def mock_constructor(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
entry_data: RuntimeEntryData,
|
|
||||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
|
||||||
handle_finished: Callable[[], None],
|
|
||||||
):
|
|
||||||
"""Fake the constructor."""
|
|
||||||
mock_pipeline.hass = hass
|
|
||||||
mock_pipeline.entry_data = entry_data
|
|
||||||
mock_pipeline.handle_event = handle_event
|
|
||||||
mock_pipeline.handle_finished = handle_finished
|
|
||||||
return mock_pipeline
|
|
||||||
|
|
||||||
mock_pipeline.side_effect = mock_constructor
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPPipeline",
|
|
||||||
new=mock_pipeline,
|
|
||||||
):
|
|
||||||
yield mock_pipeline
|
|
||||||
|
822
tests/components/esphome/test_assist_satellite.py
Normal file
822
tests/components/esphome/test_assist_satellite.py
Normal file
@ -0,0 +1,822 @@
|
|||||||
|
"""Test ESPHome voice assistant server."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
import io
|
||||||
|
import socket
|
||||||
|
from unittest.mock import ANY, Mock, patch
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from aioesphomeapi import (
|
||||||
|
APIClient,
|
||||||
|
EntityInfo,
|
||||||
|
EntityState,
|
||||||
|
UserService,
|
||||||
|
VoiceAssistantAudioSettings,
|
||||||
|
VoiceAssistantCommandFlag,
|
||||||
|
VoiceAssistantEventType,
|
||||||
|
VoiceAssistantFeature,
|
||||||
|
VoiceAssistantTimerEventType,
|
||||||
|
)
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.components import assist_satellite
|
||||||
|
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
|
||||||
|
from homeassistant.components.assist_satellite.entity import (
|
||||||
|
AssistSatelliteEntity,
|
||||||
|
AssistSatelliteState,
|
||||||
|
)
|
||||||
|
from homeassistant.components.esphome import DOMAIN
|
||||||
|
from homeassistant.components.esphome.assist_satellite import (
|
||||||
|
EsphomeAssistSatellite,
|
||||||
|
VoiceAssistantUDPServer,
|
||||||
|
)
|
||||||
|
from homeassistant.const import Platform
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import entity_registry as er, intent as intent_helper
|
||||||
|
import homeassistant.helpers.device_registry as dr
|
||||||
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
|
||||||
|
from .conftest import MockESPHomeDevice
|
||||||
|
|
||||||
|
|
||||||
|
def get_satellite_entity(
|
||||||
|
hass: HomeAssistant, mac_address: str
|
||||||
|
) -> EsphomeAssistSatellite | None:
|
||||||
|
"""Get the satellite entity for a device."""
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
satellite_entity_id = ent_reg.async_get_entity_id(
|
||||||
|
Platform.ASSIST_SATELLITE, DOMAIN, f"{mac_address}-assist_satellite"
|
||||||
|
)
|
||||||
|
if satellite_entity_id is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
component: EntityComponent[AssistSatelliteEntity] = hass.data[
|
||||||
|
assist_satellite.DOMAIN
|
||||||
|
]
|
||||||
|
if (entity := component.get_entity(satellite_entity_id)) is not None:
|
||||||
|
assert isinstance(entity, EsphomeAssistSatellite)
|
||||||
|
return entity
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_wav() -> bytes:
|
||||||
|
"""Return test WAV audio."""
|
||||||
|
with io.BytesIO() as wav_io:
|
||||||
|
with wave.open(wav_io, "wb") as wav_file:
|
||||||
|
wav_file.setframerate(16000)
|
||||||
|
wav_file.setsampwidth(2)
|
||||||
|
wav_file.setnchannels(1)
|
||||||
|
wav_file.writeframes(b"test-wav")
|
||||||
|
|
||||||
|
return wav_io.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_no_satellite_without_voice_assistant(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_client: APIClient,
|
||||||
|
mock_esphome_device: Callable[
|
||||||
|
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||||
|
Awaitable[MockESPHomeDevice],
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""Test that an assist satellite entity is not created if a voice assistant is not present."""
|
||||||
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
|
mock_client=mock_client,
|
||||||
|
entity_info=[],
|
||||||
|
user_service=[],
|
||||||
|
states=[],
|
||||||
|
device_info={},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# No satellite entity should be created
|
||||||
|
assert get_satellite_entity(hass, mock_device.device_info.mac_address) is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_api_audio(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
mock_client: APIClient,
|
||||||
|
mock_esphome_device: Callable[
|
||||||
|
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||||
|
Awaitable[MockESPHomeDevice],
|
||||||
|
],
|
||||||
|
mock_wav: bytes,
|
||||||
|
) -> None:
|
||||||
|
"""Test a complete pipeline run with API audio (over the TCP connection)."""
|
||||||
|
conversation_id = "test-conversation-id"
|
||||||
|
media_url = "http://test.url"
|
||||||
|
media_id = "test-media-id"
|
||||||
|
|
||||||
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
|
mock_client=mock_client,
|
||||||
|
entity_info=[],
|
||||||
|
user_service=[],
|
||||||
|
states=[],
|
||||||
|
device_info={
|
||||||
|
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||||
|
| VoiceAssistantFeature.SPEAKER
|
||||||
|
| VoiceAssistantFeature.API_AUDIO
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
dev = device_registry.async_get_device(
|
||||||
|
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
||||||
|
)
|
||||||
|
|
||||||
|
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||||
|
assert satellite is not None
|
||||||
|
|
||||||
|
# Block TTS streaming until we're ready.
|
||||||
|
# This makes it easier to verify the order of pipeline events.
|
||||||
|
stream_tts_audio_ready = asyncio.Event()
|
||||||
|
original_stream_tts_audio = satellite._stream_tts_audio
|
||||||
|
|
||||||
|
async def _stream_tts_audio(*args, **kwargs):
|
||||||
|
await stream_tts_audio_ready.wait()
|
||||||
|
await original_stream_tts_audio(*args, **kwargs)
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||||
|
assert device_id == dev.id
|
||||||
|
|
||||||
|
stt_stream = kwargs["stt_stream"]
|
||||||
|
|
||||||
|
chunks = [chunk async for chunk in stt_stream]
|
||||||
|
|
||||||
|
# Verify test API audio
|
||||||
|
assert chunks == [b"test-mic"]
|
||||||
|
|
||||||
|
event_callback = kwargs["event_callback"]
|
||||||
|
|
||||||
|
# Test unknown event type
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type="unknown-event",
|
||||||
|
data={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.send_voice_assistant_event.assert_not_called()
|
||||||
|
|
||||||
|
# Test error event
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.ERROR,
|
||||||
|
data={"code": "test-error-code", "message": "test-error-message"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||||
|
{"code": "test-error-code", "message": "test-error-message"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wake word
|
||||||
|
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.WAKE_WORD_START,
|
||||||
|
data={
|
||||||
|
"entity_id": "test-wake-word-entity-id",
|
||||||
|
"metadata": {},
|
||||||
|
"timeout": 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test no wake word detected
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.WAKE_WORD_END, data={"wake_word_output": {}}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||||
|
{"code": "no_wake_word", "message": "No wake word detected"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Correct wake word detection
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.WAKE_WORD_END,
|
||||||
|
data={"wake_word_output": {"wake_word_phrase": "test-wake-word"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
# STT
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.STT_START,
|
||||||
|
data={"engine": "test-stt-engine", "metadata": {}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
assert satellite.state == AssistSatelliteState.LISTENING_COMMAND
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.STT_END,
|
||||||
|
data={"stt_output": {"text": "test-stt-text"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END,
|
||||||
|
{"text": "test-stt-text"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Intent
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.INTENT_START,
|
||||||
|
data={
|
||||||
|
"engine": "test-intent-engine",
|
||||||
|
"language": hass.config.language,
|
||||||
|
"intent_input": "test-intent-text",
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
assert satellite.state == AssistSatelliteState.PROCESSING
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.INTENT_END,
|
||||||
|
data={"intent_output": {"conversation_id": conversation_id}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END,
|
||||||
|
{"conversation_id": conversation_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# TTS
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.TTS_START,
|
||||||
|
data={
|
||||||
|
"engine": "test-stt-engine",
|
||||||
|
"language": hass.config.language,
|
||||||
|
"voice": "test-voice",
|
||||||
|
"tts_input": "test-tts-text",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START,
|
||||||
|
{"text": "test-tts-text"},
|
||||||
|
)
|
||||||
|
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||||
|
|
||||||
|
# Should return mock_wav audio
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.TTS_END,
|
||||||
|
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
|
||||||
|
{"url": media_url},
|
||||||
|
)
|
||||||
|
|
||||||
|
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allow TTS streaming to proceed
|
||||||
|
stream_tts_audio_ready.set()
|
||||||
|
|
||||||
|
pipeline_finished = asyncio.Event()
|
||||||
|
original_handle_pipeline_finished = satellite.handle_pipeline_finished
|
||||||
|
|
||||||
|
def handle_pipeline_finished():
|
||||||
|
original_handle_pipeline_finished()
|
||||||
|
pipeline_finished.set()
|
||||||
|
|
||||||
|
async def async_get_media_source_audio(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
media_source_id: str,
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
return ("wav", mock_wav)
|
||||||
|
|
||||||
|
tts_finished = asyncio.Event()
|
||||||
|
original_tts_response_finished = satellite.tts_response_finished
|
||||||
|
|
||||||
|
def tts_response_finished():
|
||||||
|
original_tts_response_finished()
|
||||||
|
tts_finished.set()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_get_media_source_audio",
|
||||||
|
new=async_get_media_source_audio,
|
||||||
|
),
|
||||||
|
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||||
|
patch.object(satellite, "_stream_tts_audio", _stream_tts_audio),
|
||||||
|
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||||
|
):
|
||||||
|
# Should be cleared at pipeline start
|
||||||
|
satellite._audio_queue.put_nowait(b"leftover-data")
|
||||||
|
|
||||||
|
# Should be cancelled at pipeline start
|
||||||
|
mock_tts_streaming_task = Mock()
|
||||||
|
satellite._tts_streaming_task = mock_tts_streaming_task
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await satellite.handle_pipeline_start(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
flags=VoiceAssistantCommandFlag.USE_WAKE_WORD,
|
||||||
|
audio_settings=VoiceAssistantAudioSettings(),
|
||||||
|
wake_word_phrase="",
|
||||||
|
)
|
||||||
|
mock_tts_streaming_task.cancel.assert_called_once()
|
||||||
|
await satellite.handle_audio(b"test-mic")
|
||||||
|
await satellite.handle_pipeline_stop()
|
||||||
|
await pipeline_finished.wait()
|
||||||
|
|
||||||
|
await tts_finished.wait()
|
||||||
|
|
||||||
|
# Verify TTS streaming events.
|
||||||
|
# These are definitely the last two events because we blocked TTS streaming
|
||||||
|
# until after RUN_END above.
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify TTS WAV audio chunk came through
|
||||||
|
mock_client.send_voice_assistant_audio.assert_called_once_with(b"test-wav")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("socket_enabled")
|
||||||
|
async def test_pipeline_udp_audio(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_client: APIClient,
|
||||||
|
mock_esphome_device: Callable[
|
||||||
|
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||||
|
Awaitable[MockESPHomeDevice],
|
||||||
|
],
|
||||||
|
mock_wav: bytes,
|
||||||
|
) -> None:
|
||||||
|
"""Test a complete pipeline run with legacy UDP audio.
|
||||||
|
|
||||||
|
This test is not as comprehensive as test_pipeline_api_audio since we're
|
||||||
|
mainly focused on the UDP server.
|
||||||
|
"""
|
||||||
|
conversation_id = "test-conversation-id"
|
||||||
|
media_url = "http://test.url"
|
||||||
|
media_id = "test-media-id"
|
||||||
|
|
||||||
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
|
mock_client=mock_client,
|
||||||
|
entity_info=[],
|
||||||
|
user_service=[],
|
||||||
|
states=[],
|
||||||
|
device_info={
|
||||||
|
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||||
|
| VoiceAssistantFeature.SPEAKER
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||||
|
assert satellite is not None
|
||||||
|
|
||||||
|
mic_audio_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
||||||
|
stt_stream = kwargs["stt_stream"]
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
async for chunk in stt_stream:
|
||||||
|
chunks.append(chunk)
|
||||||
|
mic_audio_event.set()
|
||||||
|
|
||||||
|
# Verify test UDP audio
|
||||||
|
assert chunks == [b"test-mic"]
|
||||||
|
|
||||||
|
event_callback = kwargs["event_callback"]
|
||||||
|
|
||||||
|
# STT
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.STT_START,
|
||||||
|
data={"engine": "test-stt-engine", "metadata": {}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.STT_END,
|
||||||
|
data={"stt_output": {"text": "test-stt-text"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Intent
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.INTENT_START,
|
||||||
|
data={
|
||||||
|
"engine": "test-intent-engine",
|
||||||
|
"language": hass.config.language,
|
||||||
|
"intent_input": "test-intent-text",
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.INTENT_END,
|
||||||
|
data={"intent_output": {"conversation_id": conversation_id}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# TTS
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.TTS_START,
|
||||||
|
data={
|
||||||
|
"engine": "test-stt-engine",
|
||||||
|
"language": hass.config.language,
|
||||||
|
"voice": "test-voice",
|
||||||
|
"tts_input": "test-tts-text",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return mock_wav audio
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.TTS_END,
|
||||||
|
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
|
||||||
|
|
||||||
|
pipeline_finished = asyncio.Event()
|
||||||
|
original_handle_pipeline_finished = satellite.handle_pipeline_finished
|
||||||
|
|
||||||
|
def handle_pipeline_finished():
|
||||||
|
original_handle_pipeline_finished()
|
||||||
|
pipeline_finished.set()
|
||||||
|
|
||||||
|
async def async_get_media_source_audio(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
media_source_id: str,
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
return ("wav", mock_wav)
|
||||||
|
|
||||||
|
tts_finished = asyncio.Event()
|
||||||
|
original_tts_response_finished = satellite.tts_response_finished
|
||||||
|
|
||||||
|
def tts_response_finished():
|
||||||
|
original_tts_response_finished()
|
||||||
|
tts_finished.set()
|
||||||
|
|
||||||
|
class TestProtocol(asyncio.DatagramProtocol):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.transport = None
|
||||||
|
self.data_received: list[bytes] = []
|
||||||
|
|
||||||
|
def connection_made(self, transport):
|
||||||
|
self.transport = transport
|
||||||
|
|
||||||
|
def datagram_received(self, data: bytes, addr):
|
||||||
|
self.data_received.append(data)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
|
new=async_pipeline_from_audio_stream,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"homeassistant.components.tts.async_get_media_source_audio",
|
||||||
|
new=async_get_media_source_audio,
|
||||||
|
),
|
||||||
|
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||||
|
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||||
|
):
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
port = await satellite.handle_pipeline_start(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
flags=VoiceAssistantCommandFlag(0), # stt
|
||||||
|
audio_settings=VoiceAssistantAudioSettings(),
|
||||||
|
wake_word_phrase="",
|
||||||
|
)
|
||||||
|
assert (port is not None) and (port > 0)
|
||||||
|
|
||||||
|
(
|
||||||
|
transport,
|
||||||
|
protocol,
|
||||||
|
) = await asyncio.get_running_loop().create_datagram_endpoint(
|
||||||
|
TestProtocol, remote_addr=("127.0.0.1", port)
|
||||||
|
)
|
||||||
|
assert isinstance(protocol, TestProtocol)
|
||||||
|
|
||||||
|
# Send audio over UDP
|
||||||
|
transport.sendto(b"test-mic")
|
||||||
|
|
||||||
|
# Wait for audio chunk to be delivered
|
||||||
|
await mic_audio_event.wait()
|
||||||
|
|
||||||
|
await satellite.handle_pipeline_stop()
|
||||||
|
await pipeline_finished.wait()
|
||||||
|
|
||||||
|
await tts_finished.wait()
|
||||||
|
|
||||||
|
# Verify TTS audio (from UDP)
|
||||||
|
assert protocol.data_received == [b"test-wav"]
|
||||||
|
|
||||||
|
# Check that UDP server was stopped
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
sock.setblocking(False)
|
||||||
|
sock.bind(("", port)) # will fail if UDP server is still running
|
||||||
|
sock.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_udp_errors() -> None:
|
||||||
|
"""Test UDP protocol error conditions."""
|
||||||
|
audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||||
|
protocol = VoiceAssistantUDPServer(audio_queue)
|
||||||
|
|
||||||
|
protocol.datagram_received(b"test", ("", 0))
|
||||||
|
assert audio_queue.qsize() == 1
|
||||||
|
assert (await audio_queue.get()) == b"test"
|
||||||
|
|
||||||
|
# None will stop the pipeline
|
||||||
|
protocol.error_received(RuntimeError())
|
||||||
|
assert audio_queue.qsize() == 1
|
||||||
|
assert (await audio_queue.get()) is None
|
||||||
|
|
||||||
|
# No transport
|
||||||
|
assert protocol.transport is None
|
||||||
|
protocol.send_audio_bytes(b"test")
|
||||||
|
|
||||||
|
# No remote address
|
||||||
|
protocol.transport = Mock()
|
||||||
|
protocol.remote_addr = None
|
||||||
|
protocol.send_audio_bytes(b"test")
|
||||||
|
protocol.transport.sendto.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_timer_events(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
mock_client: APIClient,
|
||||||
|
mock_esphome_device: Callable[
|
||||||
|
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||||
|
Awaitable[MockESPHomeDevice],
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""Test that injecting timer events results in the correct api client calls."""
|
||||||
|
|
||||||
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
|
mock_client=mock_client,
|
||||||
|
entity_info=[],
|
||||||
|
user_service=[],
|
||||||
|
states=[],
|
||||||
|
device_info={
|
||||||
|
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||||
|
| VoiceAssistantFeature.TIMERS
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
dev = device_registry.async_get_device(
|
||||||
|
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
||||||
|
)
|
||||||
|
|
||||||
|
total_seconds = (1 * 60 * 60) + (2 * 60) + 3
|
||||||
|
await intent_helper.async_handle(
|
||||||
|
hass,
|
||||||
|
"test",
|
||||||
|
intent_helper.INTENT_START_TIMER,
|
||||||
|
{
|
||||||
|
"name": {"value": "test timer"},
|
||||||
|
"hours": {"value": 1},
|
||||||
|
"minutes": {"value": 2},
|
||||||
|
"seconds": {"value": 3},
|
||||||
|
},
|
||||||
|
device_id=dev.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.send_voice_assistant_timer_event.assert_called_with(
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED,
|
||||||
|
ANY,
|
||||||
|
"test timer",
|
||||||
|
total_seconds,
|
||||||
|
total_seconds,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Increase timer beyond original time and check total_seconds has increased
|
||||||
|
mock_client.send_voice_assistant_timer_event.reset_mock()
|
||||||
|
|
||||||
|
total_seconds += 5 * 60
|
||||||
|
await intent_helper.async_handle(
|
||||||
|
hass,
|
||||||
|
"test",
|
||||||
|
intent_helper.INTENT_INCREASE_TIMER,
|
||||||
|
{
|
||||||
|
"name": {"value": "test timer"},
|
||||||
|
"minutes": {"value": 5},
|
||||||
|
},
|
||||||
|
device_id=dev.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.send_voice_assistant_timer_event.assert_called_with(
|
||||||
|
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED,
|
||||||
|
ANY,
|
||||||
|
"test timer",
|
||||||
|
total_seconds,
|
||||||
|
ANY,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_unknown_timer_event(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
mock_client: APIClient,
|
||||||
|
mock_esphome_device: Callable[
|
||||||
|
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||||
|
Awaitable[MockESPHomeDevice],
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""Test that unknown (new) timer event types do not result in api calls."""
|
||||||
|
|
||||||
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
|
mock_client=mock_client,
|
||||||
|
entity_info=[],
|
||||||
|
user_service=[],
|
||||||
|
states=[],
|
||||||
|
device_info={
|
||||||
|
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||||
|
| VoiceAssistantFeature.TIMERS
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert mock_device.entry.unique_id is not None
|
||||||
|
dev = device_registry.async_get_device(
|
||||||
|
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
||||||
|
)
|
||||||
|
assert dev is not None
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.esphome.assist_satellite._TIMER_EVENT_TYPES.from_hass",
|
||||||
|
side_effect=KeyError,
|
||||||
|
):
|
||||||
|
await intent_helper.async_handle(
|
||||||
|
hass,
|
||||||
|
"test",
|
||||||
|
intent_helper.INTENT_START_TIMER,
|
||||||
|
{
|
||||||
|
"name": {"value": "test timer"},
|
||||||
|
"hours": {"value": 1},
|
||||||
|
"minutes": {"value": 2},
|
||||||
|
"seconds": {"value": 3},
|
||||||
|
},
|
||||||
|
device_id=dev.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.send_voice_assistant_timer_event.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_streaming_tts_errors(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_client: APIClient,
|
||||||
|
mock_esphome_device: Callable[
|
||||||
|
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||||
|
Awaitable[MockESPHomeDevice],
|
||||||
|
],
|
||||||
|
mock_wav: bytes,
|
||||||
|
) -> None:
|
||||||
|
"""Test error conditions for _stream_tts_audio function."""
|
||||||
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
|
mock_client=mock_client,
|
||||||
|
entity_info=[],
|
||||||
|
user_service=[],
|
||||||
|
states=[],
|
||||||
|
device_info={
|
||||||
|
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||||
|
assert satellite is not None
|
||||||
|
|
||||||
|
# Should not stream if not running
|
||||||
|
satellite._is_running = False
|
||||||
|
await satellite._stream_tts_audio("test-media-id")
|
||||||
|
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||||
|
satellite._is_running = True
|
||||||
|
|
||||||
|
# Should only stream WAV
|
||||||
|
async def get_mp3(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
media_source_id: str,
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
return ("mp3", b"")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.tts.async_get_media_source_audio", new=get_mp3
|
||||||
|
):
|
||||||
|
await satellite._stream_tts_audio("test-media-id")
|
||||||
|
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||||
|
|
||||||
|
# Needs to be the correct sample rate, etc.
|
||||||
|
async def get_bad_wav(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
media_source_id: str,
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
with io.BytesIO() as wav_io:
|
||||||
|
with wave.open(wav_io, "wb") as wav_file:
|
||||||
|
wav_file.setframerate(48000)
|
||||||
|
wav_file.setsampwidth(2)
|
||||||
|
wav_file.setnchannels(1)
|
||||||
|
wav_file.writeframes(b"test-wav")
|
||||||
|
|
||||||
|
return ("wav", wav_io.getvalue())
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.tts.async_get_media_source_audio", new=get_bad_wav
|
||||||
|
):
|
||||||
|
await satellite._stream_tts_audio("test-media-id")
|
||||||
|
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||||
|
|
||||||
|
# Check that TTS_STREAM_* events still get sent after cancel
|
||||||
|
media_fetched = asyncio.Event()
|
||||||
|
|
||||||
|
async def get_slow_wav(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
media_source_id: str,
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
media_fetched.set()
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
return ("wav", mock_wav)
|
||||||
|
|
||||||
|
mock_client.send_voice_assistant_event.reset_mock()
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.tts.async_get_media_source_audio", new=get_slow_wav
|
||||||
|
):
|
||||||
|
task = asyncio.create_task(satellite._stream_tts_audio("test-media-id"))
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
# Wait for media to be fetched
|
||||||
|
await media_fetched.wait()
|
||||||
|
|
||||||
|
# Cancel task
|
||||||
|
task.cancel()
|
||||||
|
await task
|
||||||
|
|
||||||
|
# No audio should have gone out
|
||||||
|
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||||
|
assert len(mock_client.send_voice_assistant_event.call_args_list) == 2
|
||||||
|
|
||||||
|
# The TTS_STREAM_* events should have gone out
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
||||||
|
{},
|
||||||
|
)
|
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from unittest.mock import AsyncMock, call, patch
|
from unittest.mock import AsyncMock, call
|
||||||
|
|
||||||
from aioesphomeapi import (
|
from aioesphomeapi import (
|
||||||
APIClient,
|
APIClient,
|
||||||
@ -17,7 +17,6 @@ from aioesphomeapi import (
|
|||||||
UserService,
|
UserService,
|
||||||
UserServiceArg,
|
UserServiceArg,
|
||||||
UserServiceArgType,
|
UserServiceArgType,
|
||||||
VoiceAssistantFeature,
|
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -29,10 +28,6 @@ from homeassistant.components.esphome.const import (
|
|||||||
DOMAIN,
|
DOMAIN,
|
||||||
STABLE_BLE_VERSION_STR,
|
STABLE_BLE_VERSION_STR,
|
||||||
)
|
)
|
||||||
from homeassistant.components.esphome.voice_assistant import (
|
|
||||||
VoiceAssistantAPIPipeline,
|
|
||||||
VoiceAssistantUDPPipeline,
|
|
||||||
)
|
|
||||||
from homeassistant.const import (
|
from homeassistant.const import (
|
||||||
CONF_HOST,
|
CONF_HOST,
|
||||||
CONF_PASSWORD,
|
CONF_PASSWORD,
|
||||||
@ -44,7 +39,7 @@ from homeassistant.data_entry_flow import FlowResultType
|
|||||||
from homeassistant.helpers import device_registry as dr, issue_registry as ir
|
from homeassistant.helpers import device_registry as dr, issue_registry as ir
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from .conftest import _ONE_SECOND, MockESPHomeDevice
|
from .conftest import MockESPHomeDevice
|
||||||
|
|
||||||
from tests.common import MockConfigEntry, async_capture_events, async_mock_service
|
from tests.common import MockConfigEntry, async_capture_events, async_mock_service
|
||||||
|
|
||||||
@ -1214,102 +1209,3 @@ async def test_entry_missing_unique_id(
|
|||||||
await mock_esphome_device(mock_client=mock_client, mock_storage=True)
|
await mock_esphome_device(mock_client=mock_client, mock_storage=True)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
assert entry.unique_id == "11:22:33:44:55:aa"
|
assert entry.unique_id == "11:22:33:44:55:aa"
|
||||||
|
|
||||||
|
|
||||||
async def test_manager_voice_assistant_handlers_api(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_client: APIClient,
|
|
||||||
mock_esphome_device: Callable[
|
|
||||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
|
||||||
Awaitable[MockESPHomeDevice],
|
|
||||||
],
|
|
||||||
caplog: pytest.LogCaptureFixture,
|
|
||||||
mock_voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the handlers are correctly executed in manager.py."""
|
|
||||||
|
|
||||||
device: MockESPHomeDevice = await mock_esphome_device(
|
|
||||||
mock_client=mock_client,
|
|
||||||
entity_info=[],
|
|
||||||
user_service=[],
|
|
||||||
states=[],
|
|
||||||
device_info={
|
|
||||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
|
||||||
| VoiceAssistantFeature.API_AUDIO
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.esphome.manager.VoiceAssistantAPIPipeline",
|
|
||||||
new=mock_voice_assistant_api_pipeline,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
port: int | None = await device.mock_voice_assistant_handle_start(
|
|
||||||
"", 0, None, None
|
|
||||||
)
|
|
||||||
|
|
||||||
assert port == 0
|
|
||||||
|
|
||||||
port: int | None = await device.mock_voice_assistant_handle_start(
|
|
||||||
"", 0, None, None
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "Previous Voice assistant pipeline was not stopped" in caplog.text
|
|
||||||
|
|
||||||
await device.mock_voice_assistant_handle_audio(bytes(_ONE_SECOND))
|
|
||||||
|
|
||||||
mock_voice_assistant_api_pipeline.receive_audio_bytes.assert_called_with(
|
|
||||||
bytes(_ONE_SECOND)
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_voice_assistant_api_pipeline.receive_audio_bytes.reset_mock()
|
|
||||||
|
|
||||||
await device.mock_voice_assistant_handle_stop()
|
|
||||||
mock_voice_assistant_api_pipeline.handle_finished()
|
|
||||||
|
|
||||||
await device.mock_voice_assistant_handle_audio(bytes(_ONE_SECOND))
|
|
||||||
|
|
||||||
mock_voice_assistant_api_pipeline.receive_audio_bytes.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_manager_voice_assistant_handlers_udp(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_client: APIClient,
|
|
||||||
mock_esphome_device: Callable[
|
|
||||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
|
||||||
Awaitable[MockESPHomeDevice],
|
|
||||||
],
|
|
||||||
mock_voice_assistant_udp_pipeline: VoiceAssistantUDPPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the handlers are correctly executed in manager.py."""
|
|
||||||
|
|
||||||
device: MockESPHomeDevice = await mock_esphome_device(
|
|
||||||
mock_client=mock_client,
|
|
||||||
entity_info=[],
|
|
||||||
user_service=[],
|
|
||||||
states=[],
|
|
||||||
device_info={
|
|
||||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.esphome.manager.VoiceAssistantUDPPipeline",
|
|
||||||
new=mock_voice_assistant_udp_pipeline,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
await device.mock_voice_assistant_handle_start("", 0, None, None)
|
|
||||||
|
|
||||||
mock_voice_assistant_udp_pipeline.run_pipeline.assert_called()
|
|
||||||
|
|
||||||
await device.mock_voice_assistant_handle_stop()
|
|
||||||
mock_voice_assistant_udp_pipeline.handle_finished()
|
|
||||||
|
|
||||||
mock_voice_assistant_udp_pipeline.stop.assert_called()
|
|
||||||
mock_voice_assistant_udp_pipeline.close.assert_called()
|
|
||||||
|
@ -1,964 +0,0 @@
|
|||||||
"""Test ESPHome voice assistant server."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
import io
|
|
||||||
import socket
|
|
||||||
from unittest.mock import ANY, Mock, patch
|
|
||||||
import wave
|
|
||||||
|
|
||||||
from aioesphomeapi import (
|
|
||||||
APIClient,
|
|
||||||
EntityInfo,
|
|
||||||
EntityState,
|
|
||||||
UserService,
|
|
||||||
VoiceAssistantEventType,
|
|
||||||
VoiceAssistantFeature,
|
|
||||||
VoiceAssistantTimerEventType,
|
|
||||||
)
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline import (
|
|
||||||
PipelineEvent,
|
|
||||||
PipelineEventType,
|
|
||||||
PipelineStage,
|
|
||||||
)
|
|
||||||
from homeassistant.components.assist_pipeline.error import (
|
|
||||||
PipelineNotFound,
|
|
||||||
WakeWordDetectionAborted,
|
|
||||||
WakeWordDetectionError,
|
|
||||||
)
|
|
||||||
from homeassistant.components.esphome import DomainData
|
|
||||||
from homeassistant.components.esphome.voice_assistant import (
|
|
||||||
VoiceAssistantAPIPipeline,
|
|
||||||
VoiceAssistantUDPPipeline,
|
|
||||||
)
|
|
||||||
from homeassistant.core import HomeAssistant
|
|
||||||
from homeassistant.helpers import intent as intent_helper
|
|
||||||
import homeassistant.helpers.device_registry as dr
|
|
||||||
|
|
||||||
from .conftest import _ONE_SECOND, MockESPHomeDevice
|
|
||||||
|
|
||||||
_TEST_INPUT_TEXT = "This is an input test"
|
|
||||||
_TEST_OUTPUT_TEXT = "This is an output test"
|
|
||||||
_TEST_OUTPUT_URL = "output.mp3"
|
|
||||||
_TEST_MEDIA_ID = "12345"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def voice_assistant_udp_pipeline(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
) -> VoiceAssistantUDPPipeline:
|
|
||||||
"""Return the UDP pipeline factory."""
|
|
||||||
|
|
||||||
def _voice_assistant_udp_server(entry):
|
|
||||||
entry_data = DomainData.get(hass).get_entry_data(entry)
|
|
||||||
|
|
||||||
server: VoiceAssistantUDPPipeline = None
|
|
||||||
|
|
||||||
def handle_finished():
|
|
||||||
nonlocal server
|
|
||||||
assert server is not None
|
|
||||||
server.close()
|
|
||||||
|
|
||||||
server = VoiceAssistantUDPPipeline(hass, entry_data, Mock(), handle_finished)
|
|
||||||
return server # noqa: RET504
|
|
||||||
|
|
||||||
return _voice_assistant_udp_server
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def voice_assistant_api_pipeline(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_client,
|
|
||||||
mock_voice_assistant_api_entry,
|
|
||||||
) -> VoiceAssistantAPIPipeline:
|
|
||||||
"""Return the API Pipeline factory."""
|
|
||||||
entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_api_entry)
|
|
||||||
return VoiceAssistantAPIPipeline(hass, entry_data, Mock(), Mock(), mock_client)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def voice_assistant_udp_pipeline_v1(
|
|
||||||
voice_assistant_udp_pipeline,
|
|
||||||
mock_voice_assistant_v1_entry,
|
|
||||||
) -> VoiceAssistantUDPPipeline:
|
|
||||||
"""Return the UDP pipeline."""
|
|
||||||
return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v1_entry)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def voice_assistant_udp_pipeline_v2(
|
|
||||||
voice_assistant_udp_pipeline,
|
|
||||||
mock_voice_assistant_v2_entry,
|
|
||||||
) -> VoiceAssistantUDPPipeline:
|
|
||||||
"""Return the UDP pipeline."""
|
|
||||||
return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v2_entry)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_wav() -> bytes:
|
|
||||||
"""Return one second of empty WAV audio."""
|
|
||||||
with io.BytesIO() as wav_io:
|
|
||||||
with wave.open(wav_io, "wb") as wav_file:
|
|
||||||
wav_file.setframerate(16000)
|
|
||||||
wav_file.setsampwidth(2)
|
|
||||||
wav_file.setnchannels(1)
|
|
||||||
wav_file.writeframes(bytes(_ONE_SECOND))
|
|
||||||
|
|
||||||
return wav_io.getvalue()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_events(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the pipeline function is called."""
|
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
|
|
||||||
assert device_id == "mock-device-id"
|
|
||||||
|
|
||||||
event_callback = kwargs["event_callback"]
|
|
||||||
|
|
||||||
event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.WAKE_WORD_END,
|
|
||||||
data={"wake_word_output": {}},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fake events
|
|
||||||
event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.STT_START,
|
|
||||||
data={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.STT_END,
|
|
||||||
data={"stt_output": {"text": _TEST_INPUT_TEXT}},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_START,
|
|
||||||
data={"tts_input": _TEST_OUTPUT_TEXT},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={"tts_output": {"url": _TEST_OUTPUT_URL}},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def handle_event(
|
|
||||||
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
|
||||||
) -> None:
|
|
||||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
|
|
||||||
assert data is not None
|
|
||||||
assert data["text"] == _TEST_INPUT_TEXT
|
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
|
|
||||||
assert data is not None
|
|
||||||
assert data["text"] == _TEST_OUTPUT_TEXT
|
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
|
|
||||||
assert data is not None
|
|
||||||
assert data["url"] == _TEST_OUTPUT_URL
|
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
|
||||||
assert data is None
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.handle_event = handle_event
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
):
|
|
||||||
voice_assistant_udp_pipeline_v1.transport = Mock()
|
|
||||||
|
|
||||||
await voice_assistant_udp_pipeline_v1.run_pipeline(
|
|
||||||
device_id="mock-device-id", conversation_id=None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("socket_enabled")
|
|
||||||
async def test_udp_server(
|
|
||||||
unused_udp_port_factory: Callable[[], int],
|
|
||||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the UDP server runs and queues incoming data."""
|
|
||||||
port_to_use = unused_udp_port_factory()
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
|
|
||||||
):
|
|
||||||
port = await voice_assistant_udp_pipeline_v1.start_server()
|
|
||||||
assert port == port_to_use
|
|
||||||
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
|
|
||||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0
|
|
||||||
sock.sendto(b"test", ("127.0.0.1", port))
|
|
||||||
|
|
||||||
# Give the socket some time to send/receive the data
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
while voice_assistant_udp_pipeline_v1.queue.qsize() == 0:
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.stop()
|
|
||||||
voice_assistant_udp_pipeline_v1.close()
|
|
||||||
|
|
||||||
assert voice_assistant_udp_pipeline_v1.transport.is_closing()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_udp_server_queue(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the UDP server queues incoming data."""
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.started = True
|
|
||||||
|
|
||||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
|
|
||||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
|
|
||||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 2
|
|
||||||
|
|
||||||
async for data in voice_assistant_udp_pipeline_v1._iterate_packets():
|
|
||||||
assert data == bytes(1024)
|
|
||||||
break
|
|
||||||
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 # One message removed
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.stop()
|
|
||||||
assert (
|
|
||||||
voice_assistant_udp_pipeline_v1.queue.qsize() == 2
|
|
||||||
) # An empty message added by stop
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
|
|
||||||
assert (
|
|
||||||
voice_assistant_udp_pipeline_v1.queue.qsize() == 2
|
|
||||||
) # No new messages added after stop
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.close()
|
|
||||||
|
|
||||||
# Stopping the UDP server should cause _iterate_packets to break out
|
|
||||||
# immediately without yielding any data.
|
|
||||||
has_data = False
|
|
||||||
async for _data in voice_assistant_udp_pipeline_v1._iterate_packets():
|
|
||||||
has_data = True
|
|
||||||
|
|
||||||
assert not has_data, "Server was stopped"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_api_pipeline_queue(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the API pipeline queues incoming data."""
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.started = True
|
|
||||||
|
|
||||||
assert voice_assistant_api_pipeline.queue.qsize() == 0
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024))
|
|
||||||
assert voice_assistant_api_pipeline.queue.qsize() == 1
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024))
|
|
||||||
assert voice_assistant_api_pipeline.queue.qsize() == 2
|
|
||||||
|
|
||||||
async for data in voice_assistant_api_pipeline._iterate_packets():
|
|
||||||
assert data == bytes(1024)
|
|
||||||
break
|
|
||||||
assert voice_assistant_api_pipeline.queue.qsize() == 1 # One message removed
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.stop()
|
|
||||||
assert (
|
|
||||||
voice_assistant_api_pipeline.queue.qsize() == 2
|
|
||||||
) # An empty message added by stop
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024))
|
|
||||||
assert (
|
|
||||||
voice_assistant_api_pipeline.queue.qsize() == 2
|
|
||||||
) # No new messages added after stop
|
|
||||||
|
|
||||||
# Stopping the API Pipeline should cause _iterate_packets to break out
|
|
||||||
# immediately without yielding any data.
|
|
||||||
has_data = False
|
|
||||||
async for _data in voice_assistant_api_pipeline._iterate_packets():
|
|
||||||
has_data = True
|
|
||||||
|
|
||||||
assert not has_data, "Pipeline was stopped"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_error_calls_handle_finished(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the handle_finished callback is called when an error occurs."""
|
|
||||||
voice_assistant_udp_pipeline_v1.handle_finished = Mock()
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.error_received(Exception())
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v1.handle_finished.assert_called()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("socket_enabled")
|
|
||||||
async def test_udp_server_multiple(
|
|
||||||
unused_udp_port_factory: Callable[[], int],
|
|
||||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the UDP server raises an error if started twice."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
|
||||||
new=unused_udp_port_factory(),
|
|
||||||
):
|
|
||||||
await voice_assistant_udp_pipeline_v1.start_server()
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
|
||||||
new=unused_udp_port_factory(),
|
|
||||||
),
|
|
||||||
pytest.raises(RuntimeError),
|
|
||||||
):
|
|
||||||
await voice_assistant_udp_pipeline_v1.start_server()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("socket_enabled")
|
|
||||||
async def test_udp_server_after_stopped(
|
|
||||||
unused_udp_port_factory: Callable[[], int],
|
|
||||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the UDP server raises an error if started after stopped."""
|
|
||||||
voice_assistant_udp_pipeline_v1.close()
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
|
||||||
new=unused_udp_port_factory(),
|
|
||||||
),
|
|
||||||
pytest.raises(RuntimeError),
|
|
||||||
):
|
|
||||||
await voice_assistant_udp_pipeline_v1.start_server()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_events_converted_correctly(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the pipeline events produce the correct data to send to the device."""
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts",
|
|
||||||
):
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.STT_START,
|
|
||||||
data={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, None
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.STT_END,
|
|
||||||
data={"stt_output": {"text": "text"}},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, {"text": "text"}
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.INTENT_START,
|
|
||||||
data={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, None
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.INTENT_END,
|
|
||||||
data={
|
|
||||||
"intent_output": {
|
|
||||||
"conversation_id": "conversation-id",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END,
|
|
||||||
{"conversation_id": "conversation-id"},
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_START,
|
|
||||||
data={"tts_input": "text"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, {"text": "text"}
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={"tts_output": {"url": "url", "media_id": "media-id"}},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END, {"url": "url"}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_unknown_event_type(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the API pipeline does not call handle_event for unknown events."""
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type="unknown-event",
|
|
||||||
data={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert not voice_assistant_api_pipeline.handle_event.called
|
|
||||||
|
|
||||||
|
|
||||||
async def test_error_event_type(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the API pipeline calls event handler with error."""
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.ERROR,
|
|
||||||
data={"code": "code", "message": "message"},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
|
||||||
{"code": "code", "message": "message"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_not_called(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the UDP server with a v1 device does not call _send_tts."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
|
||||||
) as mock_send_tts:
|
|
||||||
voice_assistant_udp_pipeline_v1._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={
|
|
||||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_send_tts.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_called_udp(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the UDP server with a v2 device calls _send_tts."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
|
||||||
) as mock_send_tts:
|
|
||||||
voice_assistant_udp_pipeline_v2._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={
|
|
||||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_called_api(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the API pipeline calls _send_tts."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
|
||||||
) as mock_send_tts:
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={
|
|
||||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_not_called_when_empty(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
|
||||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test the pipelines do not call _send_tts when the output is empty."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
|
||||||
) as mock_send_tts:
|
|
||||||
voice_assistant_udp_pipeline_v1._event_callback(
|
|
||||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_send_tts.assert_not_called()
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v2._event_callback(
|
|
||||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_send_tts.assert_not_called()
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_send_tts.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_udp(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
|
||||||
mock_wav: bytes,
|
|
||||||
) -> None:
|
|
||||||
"""Test the UDP server calls sendto to transmit audio data to device."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
|
||||||
return_value=("wav", mock_wav),
|
|
||||||
):
|
|
||||||
voice_assistant_udp_pipeline_v2.started = True
|
|
||||||
voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
|
||||||
with patch.object(
|
|
||||||
voice_assistant_udp_pipeline_v2.transport, "is_closing", return_value=False
|
|
||||||
):
|
|
||||||
voice_assistant_udp_pipeline_v2._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={
|
|
||||||
"tts_output": {
|
|
||||||
"media_id": _TEST_MEDIA_ID,
|
|
||||||
"url": _TEST_OUTPUT_URL,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await voice_assistant_udp_pipeline_v2._tts_done.wait()
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v2.transport.sendto.assert_called()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_api(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_client: APIClient,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
mock_wav: bytes,
|
|
||||||
) -> None:
|
|
||||||
"""Test the API pipeline calls cli.send_voice_assistant_audio to transmit audio data to device."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
|
||||||
return_value=("wav", mock_wav),
|
|
||||||
):
|
|
||||||
voice_assistant_api_pipeline.started = True
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={
|
|
||||||
"tts_output": {
|
|
||||||
"media_id": _TEST_MEDIA_ID,
|
|
||||||
"url": _TEST_OUTPUT_URL,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await voice_assistant_api_pipeline._tts_done.wait()
|
|
||||||
|
|
||||||
mock_client.send_voice_assistant_audio.assert_called()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_wrong_sample_rate(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that only 16000Hz audio will be streamed."""
|
|
||||||
with io.BytesIO() as wav_io:
|
|
||||||
with wave.open(wav_io, "wb") as wav_file:
|
|
||||||
wav_file.setframerate(22050)
|
|
||||||
wav_file.setsampwidth(2)
|
|
||||||
wav_file.setnchannels(1)
|
|
||||||
wav_file.writeframes(bytes(_ONE_SECOND))
|
|
||||||
|
|
||||||
wav_bytes = wav_io.getvalue()
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
|
||||||
return_value=("wav", wav_bytes),
|
|
||||||
):
|
|
||||||
voice_assistant_api_pipeline.started = True
|
|
||||||
voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={
|
|
||||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert voice_assistant_api_pipeline._tts_task is not None
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await voice_assistant_api_pipeline._tts_task
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_wrong_format(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that only WAV audio will be streamed."""
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
|
||||||
return_value=("raw", bytes(1024)),
|
|
||||||
),
|
|
||||||
):
|
|
||||||
voice_assistant_api_pipeline.started = True
|
|
||||||
voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={
|
|
||||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert voice_assistant_api_pipeline._tts_task is not None
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
await voice_assistant_api_pipeline._tts_task
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_not_started(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
|
||||||
mock_wav: bytes,
|
|
||||||
) -> None:
|
|
||||||
"""Test the UDP server does not call sendto when not started."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
|
||||||
return_value=("wav", mock_wav),
|
|
||||||
):
|
|
||||||
voice_assistant_udp_pipeline_v2.started = False
|
|
||||||
voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v2._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={
|
|
||||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await voice_assistant_udp_pipeline_v2._tts_done.wait()
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v2.transport.sendto.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_transport_none(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
|
||||||
mock_wav: bytes,
|
|
||||||
caplog: pytest.LogCaptureFixture,
|
|
||||||
) -> None:
|
|
||||||
"""Test the UDP server does not call sendto when transport is None."""
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
|
||||||
return_value=("wav", mock_wav),
|
|
||||||
):
|
|
||||||
voice_assistant_udp_pipeline_v2.started = True
|
|
||||||
voice_assistant_udp_pipeline_v2.transport = None
|
|
||||||
|
|
||||||
voice_assistant_udp_pipeline_v2._event_callback(
|
|
||||||
PipelineEvent(
|
|
||||||
type=PipelineEventType.TTS_END,
|
|
||||||
data={
|
|
||||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await voice_assistant_udp_pipeline_v2._tts_done.wait()
|
|
||||||
|
|
||||||
assert "No transport to send audio to" in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
async def test_wake_word(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the pipeline is set to start with Wake word."""
|
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs):
|
|
||||||
assert start_stage == PipelineStage.WAKE_WORD
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
),
|
|
||||||
patch("asyncio.Event.wait"), # TTS wait event
|
|
||||||
):
|
|
||||||
await voice_assistant_api_pipeline.run_pipeline(
|
|
||||||
device_id="mock-device-id",
|
|
||||||
conversation_id=None,
|
|
||||||
flags=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_wake_word_exception(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the pipeline is set to start with Wake word."""
|
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
|
||||||
raise WakeWordDetectionError("pipeline-not-found", "Pipeline not found")
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
):
|
|
||||||
|
|
||||||
def handle_event(
|
|
||||||
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
|
||||||
) -> None:
|
|
||||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
|
||||||
assert data is not None
|
|
||||||
assert data["code"] == "pipeline-not-found"
|
|
||||||
assert data["message"] == "Pipeline not found"
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.handle_event = handle_event
|
|
||||||
|
|
||||||
await voice_assistant_api_pipeline.run_pipeline(
|
|
||||||
device_id="mock-device-id",
|
|
||||||
conversation_id=None,
|
|
||||||
flags=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_wake_word_abort_exception(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the pipeline is set to start with Wake word."""
|
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
|
||||||
raise WakeWordDetectionAborted
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
),
|
|
||||||
patch.object(voice_assistant_api_pipeline, "handle_event") as mock_handle_event,
|
|
||||||
):
|
|
||||||
await voice_assistant_api_pipeline.run_pipeline(
|
|
||||||
device_id="mock-device-id",
|
|
||||||
conversation_id=None,
|
|
||||||
flags=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_handle_event.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_timer_events(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
device_registry: dr.DeviceRegistry,
|
|
||||||
mock_client: APIClient,
|
|
||||||
mock_esphome_device: Callable[
|
|
||||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
|
||||||
Awaitable[MockESPHomeDevice],
|
|
||||||
],
|
|
||||||
) -> None:
|
|
||||||
"""Test that injecting timer events results in the correct api client calls."""
|
|
||||||
|
|
||||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
|
||||||
mock_client=mock_client,
|
|
||||||
entity_info=[],
|
|
||||||
user_service=[],
|
|
||||||
states=[],
|
|
||||||
device_info={
|
|
||||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
|
||||||
| VoiceAssistantFeature.TIMERS
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
dev = device_registry.async_get_device(
|
|
||||||
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
|
||||||
)
|
|
||||||
|
|
||||||
total_seconds = (1 * 60 * 60) + (2 * 60) + 3
|
|
||||||
await intent_helper.async_handle(
|
|
||||||
hass,
|
|
||||||
"test",
|
|
||||||
intent_helper.INTENT_START_TIMER,
|
|
||||||
{
|
|
||||||
"name": {"value": "test timer"},
|
|
||||||
"hours": {"value": 1},
|
|
||||||
"minutes": {"value": 2},
|
|
||||||
"seconds": {"value": 3},
|
|
||||||
},
|
|
||||||
device_id=dev.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_client.send_voice_assistant_timer_event.assert_called_with(
|
|
||||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED,
|
|
||||||
ANY,
|
|
||||||
"test timer",
|
|
||||||
total_seconds,
|
|
||||||
total_seconds,
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Increase timer beyond original time and check total_seconds has increased
|
|
||||||
mock_client.send_voice_assistant_timer_event.reset_mock()
|
|
||||||
|
|
||||||
total_seconds += 5 * 60
|
|
||||||
await intent_helper.async_handle(
|
|
||||||
hass,
|
|
||||||
"test",
|
|
||||||
intent_helper.INTENT_INCREASE_TIMER,
|
|
||||||
{
|
|
||||||
"name": {"value": "test timer"},
|
|
||||||
"minutes": {"value": 5},
|
|
||||||
},
|
|
||||||
device_id=dev.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_client.send_voice_assistant_timer_event.assert_called_with(
|
|
||||||
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED,
|
|
||||||
ANY,
|
|
||||||
"test timer",
|
|
||||||
total_seconds,
|
|
||||||
ANY,
|
|
||||||
True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_unknown_timer_event(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
device_registry: dr.DeviceRegistry,
|
|
||||||
mock_client: APIClient,
|
|
||||||
mock_esphome_device: Callable[
|
|
||||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
|
||||||
Awaitable[MockESPHomeDevice],
|
|
||||||
],
|
|
||||||
) -> None:
|
|
||||||
"""Test that unknown (new) timer event types do not result in api calls."""
|
|
||||||
|
|
||||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
|
||||||
mock_client=mock_client,
|
|
||||||
entity_info=[],
|
|
||||||
user_service=[],
|
|
||||||
states=[],
|
|
||||||
device_info={
|
|
||||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
|
||||||
| VoiceAssistantFeature.TIMERS
|
|
||||||
},
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
dev = device_registry.async_get_device(
|
|
||||||
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant._TIMER_EVENT_TYPES.from_hass",
|
|
||||||
side_effect=KeyError,
|
|
||||||
):
|
|
||||||
await intent_helper.async_handle(
|
|
||||||
hass,
|
|
||||||
"test",
|
|
||||||
intent_helper.INTENT_START_TIMER,
|
|
||||||
{
|
|
||||||
"name": {"value": "test timer"},
|
|
||||||
"hours": {"value": 1},
|
|
||||||
"minutes": {"value": 2},
|
|
||||||
"seconds": {"value": 3},
|
|
||||||
},
|
|
||||||
device_id=dev.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_client.send_voice_assistant_timer_event.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
async def test_invalid_pipeline_id(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the pipeline is set to start with Wake word."""
|
|
||||||
|
|
||||||
invalid_pipeline_id = "invalid-pipeline-id"
|
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
|
||||||
raise PipelineNotFound(
|
|
||||||
"pipeline_not_found", f"Pipeline {invalid_pipeline_id} not found"
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
):
|
|
||||||
|
|
||||||
def handle_event(
|
|
||||||
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
|
||||||
) -> None:
|
|
||||||
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
|
|
||||||
assert data is not None
|
|
||||||
assert data["code"] == "pipeline_not_found"
|
|
||||||
assert data["message"] == f"Pipeline {invalid_pipeline_id} not found"
|
|
||||||
|
|
||||||
voice_assistant_api_pipeline.handle_event = handle_event
|
|
||||||
|
|
||||||
await voice_assistant_api_pipeline.run_pipeline(
|
|
||||||
device_id="mock-device-id",
|
|
||||||
conversation_id=None,
|
|
||||||
flags=2,
|
|
||||||
)
|
|
Loading…
x
Reference in New Issue
Block a user