diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py new file mode 100644 index 00000000000..48bb9ec5507 --- /dev/null +++ b/homeassistant/components/esphome/assist_satellite.py @@ -0,0 +1,504 @@ +"""Support for assist satellites in ESPHome.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterable +from functools import partial +import io +import logging +import socket +from typing import Any, cast +import wave + +from aioesphomeapi import ( + VoiceAssistantAudioSettings, + VoiceAssistantCommandFlag, + VoiceAssistantEventType, + VoiceAssistantFeature, + VoiceAssistantTimerEventType, +) + +from homeassistant.components import assist_satellite, tts +from homeassistant.components.assist_pipeline import ( + PipelineEvent, + PipelineEventType, + PipelineStage, +) +from homeassistant.components.intent import async_register_timer_handler +from homeassistant.components.intent.timers import TimerEventType, TimerInfo +from homeassistant.components.media_player import async_process_play_media_url +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import EntityCategory, Platform +from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN +from .entity import EsphomeAssistEntity +from .entry_data import ESPHomeConfigEntry, RuntimeEntryData +from .enum_mapper import EsphomeEnumMapper + +_LOGGER = logging.getLogger(__name__) + +_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[ + VoiceAssistantEventType, PipelineEventType +] = EsphomeEnumMapper( + { + VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR, + VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START, + VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END, + VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START, + VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END, + VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START, + VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END, + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START, + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END, + VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START, + VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END, + VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START, + VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END, + } +) + +_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = ( + EsphomeEnumMapper( + { + VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED, + VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED, + VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED, + VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED, + } + ) +) + + +async def async_setup_entry( + hass: HomeAssistant, + entry: ESPHomeConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Assist satellite entity.""" + entry_data = entry.runtime_data + assert entry_data.device_info is not None + if entry_data.device_info.voice_assistant_feature_flags_compat( + entry_data.api_version + ): + async_add_entities( + [ + EsphomeAssistSatellite(entry, entry_data), + ] + ) + + +class EsphomeAssistSatellite( + EsphomeAssistEntity, assist_satellite.AssistSatelliteEntity +): + """Satellite running ESPHome.""" + + entity_description = assist_satellite.AssistSatelliteEntityDescription( + key="assist_satellite", + translation_key="assist_satellite", + entity_category=EntityCategory.CONFIG, + ) + + def __init__( + self, + config_entry: ConfigEntry, + entry_data: RuntimeEntryData, + ) -> None: + """Initialize satellite.""" + super().__init__(entry_data) + + self.config_entry = config_entry + self.entry_data = entry_data + self.cli = self.entry_data.client + + self._is_running: bool = True + self._pipeline_task: asyncio.Task | None = None + self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() + self._tts_streaming_task: asyncio.Task | None = None + self._udp_server: VoiceAssistantUDPServer | None = None + + @property + def pipeline_entity_id(self) -> str | None: + """Return the entity ID of the pipeline to use for the next conversation.""" + assert self.entry_data.device_info is not None + ent_reg = er.async_get(self.hass) + return ent_reg.async_get_entity_id( + Platform.SELECT, + DOMAIN, + f"{self.entry_data.device_info.mac_address}-pipeline", + ) + + @property + def vad_sensitivity_entity_id(self) -> str | None: + """Return the entity ID of the VAD sensitivity to use for the next conversation.""" + assert self.entry_data.device_info is not None + ent_reg = er.async_get(self.hass) + return ent_reg.async_get_entity_id( + Platform.SELECT, + DOMAIN, + f"{self.entry_data.device_info.mac_address}-vad_sensitivity", + ) + + async def async_added_to_hass(self) -> None: + """Run when entity about to be added to hass.""" + await super().async_added_to_hass() + + assert self.entry_data.device_info is not None + feature_flags = ( + self.entry_data.device_info.voice_assistant_feature_flags_compat( + self.entry_data.api_version + ) + ) + if feature_flags & VoiceAssistantFeature.API_AUDIO: + # TCP audio + self.entry_data.disconnect_callbacks.add( + self.cli.subscribe_voice_assistant( + handle_start=self.handle_pipeline_start, + handle_stop=self.handle_pipeline_stop, + handle_audio=self.handle_audio, + ) + ) + else: + # UDP audio + self.entry_data.disconnect_callbacks.add( + self.cli.subscribe_voice_assistant( + handle_start=self.handle_pipeline_start, + handle_stop=self.handle_pipeline_stop, + ) + ) + + if feature_flags & VoiceAssistantFeature.TIMERS: + # Device supports timers + assert (self.registry_entry is not None) and ( + self.registry_entry.device_id is not None + ) + self.entry_data.disconnect_callbacks.add( + async_register_timer_handler( + self.hass, self.registry_entry.device_id, self.handle_timer_event + ) + ) + + async def async_will_remove_from_hass(self) -> None: + """Run when entity will be removed from hass.""" + await super().async_will_remove_from_hass() + + self._is_running = False + self._stop_pipeline() + + def on_pipeline_event(self, event: PipelineEvent) -> None: + """Handle pipeline events.""" + try: + event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type) + except KeyError: + _LOGGER.debug("Received unknown pipeline event type: %s", event.type) + return + + data_to_send: dict[str, Any] = {} + if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: + self.entry_data.async_set_assist_pipeline_state(True) + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: + assert event.data is not None + data_to_send = {"text": event.data["stt_output"]["text"]} + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: + assert event.data is not None + data_to_send = { + "conversation_id": event.data["intent_output"]["conversation_id"] or "", + } + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: + assert event.data is not None + data_to_send = {"text": event.data["tts_input"]} + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: + assert event.data is not None + if tts_output := event.data["tts_output"]: + path = tts_output["url"] + url = async_process_play_media_url(self.hass, path) + data_to_send = {"url": url} + + assert self.entry_data.device_info is not None + feature_flags = ( + self.entry_data.device_info.voice_assistant_feature_flags_compat( + self.entry_data.api_version + ) + ) + if feature_flags & VoiceAssistantFeature.SPEAKER: + media_id = tts_output["media_id"] + self._tts_streaming_task = ( + self.config_entry.async_create_background_task( + self.hass, + self._stream_tts_audio(media_id), + "esphome_voice_assistant_tts", + ) + ) + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: + assert event.data is not None + if not event.data["wake_word_output"]: + event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR + data_to_send = { + "code": "no_wake_word", + "message": "No wake word detected", + } + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: + assert event.data is not None + data_to_send = { + "code": event.data["code"], + "message": event.data["message"], + } + + self.cli.send_voice_assistant_event(event_type, data_to_send) + + async def handle_pipeline_start( + self, + conversation_id: str, + flags: int, + audio_settings: VoiceAssistantAudioSettings, + wake_word_phrase: str | None, + ) -> int | None: + """Handle pipeline run request.""" + # Clear audio queue + while not self._audio_queue.empty(): + await self._audio_queue.get() + + if self._tts_streaming_task is not None: + # Cancel current TTS response + self._tts_streaming_task.cancel() + self._tts_streaming_task = None + + # API or UDP output audio + port: int = 0 + assert self.entry_data.device_info is not None + feature_flags = ( + self.entry_data.device_info.voice_assistant_feature_flags_compat( + self.entry_data.api_version + ) + ) + if (feature_flags & VoiceAssistantFeature.SPEAKER) and not ( + feature_flags & VoiceAssistantFeature.API_AUDIO + ): + port = await self._start_udp_server() + _LOGGER.debug("Started UDP server on port %s", port) + + # Device triggered pipeline (wake word, etc.) + if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD: + start_stage = PipelineStage.WAKE_WORD + else: + start_stage = PipelineStage.STT + + end_stage = PipelineStage.TTS + + # Run the pipeline + _LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage) + self.entry_data.async_set_assist_pipeline_state(True) + self._pipeline_task = self.config_entry.async_create_background_task( + self.hass, + self.async_accept_pipeline_from_satellite( + audio_stream=self._wrap_audio_stream(), + start_stage=start_stage, + end_stage=end_stage, + wake_word_phrase=wake_word_phrase, + ), + "esphome_assist_satellite_pipeline", + ) + self._pipeline_task.add_done_callback( + lambda _future: self.handle_pipeline_finished() + ) + + return port + + async def handle_audio(self, data: bytes) -> None: + """Handle incoming audio chunk from API.""" + self._audio_queue.put_nowait(data) + + async def handle_pipeline_stop(self) -> None: + """Handle request for pipeline to stop.""" + self._stop_pipeline() + + def handle_pipeline_finished(self) -> None: + """Handle when pipeline has finished running.""" + self.entry_data.async_set_assist_pipeline_state(False) + self._stop_udp_server() + _LOGGER.debug("Pipeline finished") + + def handle_timer_event( + self, event_type: TimerEventType, timer_info: TimerInfo + ) -> None: + """Handle timer events.""" + try: + native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type) + except KeyError: + _LOGGER.debug("Received unknown timer event type: %s", event_type) + return + + self.cli.send_voice_assistant_timer_event( + native_event_type, + timer_info.id, + timer_info.name, + timer_info.created_seconds, + timer_info.seconds_left, + timer_info.is_active, + ) + + async def _stream_tts_audio( + self, + media_id: str, + sample_rate: int = 16000, + sample_width: int = 2, + sample_channels: int = 1, + samples_per_chunk: int = 512, + ) -> None: + """Stream TTS audio chunks to device via API or UDP.""" + self.cli.send_voice_assistant_event( + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {} + ) + + try: + if not self._is_running: + return + + extension, data = await tts.async_get_media_source_audio( + self.hass, + media_id, + ) + + if extension != "wav": + _LOGGER.error("Only WAV audio can be streamed, got %s", extension) + return + + with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: + if ( + (wav_file.getframerate() != sample_rate) + or (wav_file.getsampwidth() != sample_width) + or (wav_file.getnchannels() != sample_channels) + ): + _LOGGER.error("Can only stream 16Khz 16-bit mono WAV") + return + + _LOGGER.debug("Streaming %s audio samples", wav_file.getnframes()) + + while self._is_running: + chunk = wav_file.readframes(samples_per_chunk) + if not chunk: + break + + if self._udp_server is not None: + self._udp_server.send_audio_bytes(chunk) + else: + self.cli.send_voice_assistant_audio(chunk) + + # Wait for 90% of the duration of the audio that was + # sent for it to be played. This will overrun the + # device's buffer for very long audio, so using a media + # player is preferred. + samples_in_chunk = len(chunk) // (sample_width * sample_channels) + seconds_in_chunk = samples_in_chunk / sample_rate + await asyncio.sleep(seconds_in_chunk * 0.9) + except asyncio.CancelledError: + return # Don't trigger state change + finally: + self.cli.send_voice_assistant_event( + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {} + ) + + # State change + self.tts_response_finished() + + async def _wrap_audio_stream(self) -> AsyncIterable[bytes]: + """Yield audio chunks from the queue until None.""" + while True: + chunk = await self._audio_queue.get() + if not chunk: + break + + yield chunk + + def _stop_pipeline(self) -> None: + """Request pipeline to be stopped.""" + self._audio_queue.put_nowait(None) + _LOGGER.debug("Requested pipeline stop") + + async def _start_udp_server(self) -> int: + """Start a UDP server on a random free port.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setblocking(False) + sock.bind(("", 0)) # random free port + + ( + _transport, + protocol, + ) = await asyncio.get_running_loop().create_datagram_endpoint( + partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock + ) + + assert isinstance(protocol, VoiceAssistantUDPServer) + self._udp_server = protocol + + # Return port + return cast(int, sock.getsockname()[1]) + + def _stop_udp_server(self) -> None: + """Stop the UDP server if it's running.""" + if self._udp_server is None: + return + + try: + self._udp_server.close() + finally: + self._udp_server = None + + _LOGGER.debug("Stopped UDP server") + + +class VoiceAssistantUDPServer(asyncio.DatagramProtocol): + """Receive UDP packets and forward them to the audio queue.""" + + transport: asyncio.DatagramTransport | None = None + remote_addr: tuple[str, int] | None = None + + def __init__( + self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any + ) -> None: + """Initialize protocol.""" + super().__init__(*args, **kwargs) + self._audio_queue = audio_queue + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """Store transport for later use.""" + self.transport = cast(asyncio.DatagramTransport, transport) + + def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: + """Handle incoming UDP packet.""" + if self.remote_addr is None: + self.remote_addr = addr + + self._audio_queue.put_nowait(data) + + def error_received(self, exc: Exception) -> None: + """Handle when a send or receive operation raises an OSError. + + (Other than BlockingIOError or InterruptedError.) + """ + _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc) + + # Stop pipeline + self._audio_queue.put_nowait(None) + + def close(self) -> None: + """Close the receiver.""" + if self.transport is not None: + self.transport.close() + + self.remote_addr = None + + def send_audio_bytes(self, data: bytes) -> None: + """Send bytes to the device via UDP.""" + if self.transport is None: + _LOGGER.error("No transport to send audio to") + return + + if self.remote_addr is None: + _LOGGER.error("No address to send audio to") + return + + self.transport.sendto(data, self.remote_addr) diff --git a/homeassistant/components/esphome/manager.py b/homeassistant/components/esphome/manager.py index 93e8d7b5bc2..09c3cc3b7cb 100644 --- a/homeassistant/components/esphome/manager.py +++ b/homeassistant/components/esphome/manager.py @@ -20,19 +20,17 @@ from aioesphomeapi import ( RequiresEncryptionAPIError, UserService, UserServiceArgType, - VoiceAssistantAudioSettings, - VoiceAssistantFeature, ) from awesomeversion import AwesomeVersion import voluptuous as vol from homeassistant.components import tag, zeroconf -from homeassistant.components.intent import async_register_timer_handler from homeassistant.const import ( ATTR_DEVICE_ID, CONF_MODE, EVENT_HOMEASSISTANT_CLOSE, EVENT_LOGGING_CHANGED, + Platform, ) from homeassistant.core import ( Event, @@ -73,12 +71,6 @@ from .domain_data import DomainData # Import config flow so that it's added to the registry from .entry_data import ESPHomeConfigEntry, RuntimeEntryData -from .voice_assistant import ( - VoiceAssistantAPIPipeline, - VoiceAssistantPipeline, - VoiceAssistantUDPPipeline, - handle_timer_event, -) _LOGGER = logging.getLogger(__name__) @@ -149,7 +141,6 @@ class ESPHomeManager: "cli", "device_id", "domain_data", - "voice_assistant_pipeline", "reconnect_logic", "zeroconf_instance", "entry_data", @@ -173,7 +164,6 @@ class ESPHomeManager: self.cli = cli self.device_id: str | None = None self.domain_data = domain_data - self.voice_assistant_pipeline: VoiceAssistantPipeline | None = None self.reconnect_logic: ReconnectLogic | None = None self.zeroconf_instance = zeroconf_instance self.entry_data = entry.runtime_data @@ -338,77 +328,6 @@ class ESPHomeManager: entity_id, attribute, self.hass.states.get(entity_id) ) - def _handle_pipeline_finished(self) -> None: - self.entry_data.async_set_assist_pipeline_state(False) - - if self.voice_assistant_pipeline is not None: - if isinstance(self.voice_assistant_pipeline, VoiceAssistantUDPPipeline): - self.voice_assistant_pipeline.close() - self.voice_assistant_pipeline = None - - async def _handle_pipeline_start( - self, - conversation_id: str, - flags: int, - audio_settings: VoiceAssistantAudioSettings, - wake_word_phrase: str | None, - ) -> int | None: - """Start a voice assistant pipeline.""" - if self.voice_assistant_pipeline is not None: - _LOGGER.warning("Previous Voice assistant pipeline was not stopped") - self.voice_assistant_pipeline.stop() - self.voice_assistant_pipeline = None - - hass = self.hass - assert self.entry_data.device_info is not None - if ( - self.entry_data.device_info.voice_assistant_feature_flags_compat( - self.entry_data.api_version - ) - & VoiceAssistantFeature.API_AUDIO - ): - self.voice_assistant_pipeline = VoiceAssistantAPIPipeline( - hass, - self.entry_data, - self.cli.send_voice_assistant_event, - self._handle_pipeline_finished, - self.cli, - ) - port = 0 - else: - self.voice_assistant_pipeline = VoiceAssistantUDPPipeline( - hass, - self.entry_data, - self.cli.send_voice_assistant_event, - self._handle_pipeline_finished, - ) - port = await self.voice_assistant_pipeline.start_server() - - assert self.device_id is not None, "Device ID must be set" - hass.async_create_background_task( - self.voice_assistant_pipeline.run_pipeline( - device_id=self.device_id, - conversation_id=conversation_id or None, - flags=flags, - audio_settings=audio_settings, - wake_word_phrase=wake_word_phrase, - ), - "esphome.voice_assistant_pipeline.run_pipeline", - ) - - return port - - async def _handle_pipeline_stop(self) -> None: - """Stop a voice assistant pipeline.""" - if self.voice_assistant_pipeline is not None: - self.voice_assistant_pipeline.stop() - - async def _handle_audio(self, data: bytes) -> None: - if self.voice_assistant_pipeline is None: - return - assert isinstance(self.voice_assistant_pipeline, VoiceAssistantAPIPipeline) - self.voice_assistant_pipeline.receive_audio_bytes(data) - async def on_connect(self) -> None: """Subscribe to states and list entities on successful API login.""" try: @@ -509,29 +428,14 @@ class ESPHomeManager: ) ) - flags = device_info.voice_assistant_feature_flags_compat(api_version) - if flags: - if flags & VoiceAssistantFeature.API_AUDIO: - entry_data.disconnect_callbacks.add( - cli.subscribe_voice_assistant( - handle_start=self._handle_pipeline_start, - handle_stop=self._handle_pipeline_stop, - handle_audio=self._handle_audio, - ) - ) - else: - entry_data.disconnect_callbacks.add( - cli.subscribe_voice_assistant( - handle_start=self._handle_pipeline_start, - handle_stop=self._handle_pipeline_stop, - ) - ) - if flags & VoiceAssistantFeature.TIMERS: - entry_data.disconnect_callbacks.add( - async_register_timer_handler( - hass, self.device_id, partial(handle_timer_event, cli) - ) - ) + if device_info.voice_assistant_feature_flags_compat(api_version) and ( + Platform.ASSIST_SATELLITE not in entry_data.loaded_platforms + ): + # Create assist satellite entity + await self.hass.config_entries.async_forward_entry_setups( + self.entry, [Platform.ASSIST_SATELLITE] + ) + entry_data.loaded_platforms.add(Platform.ASSIST_SATELLITE) cli.subscribe_states(entry_data.async_update_state) cli.subscribe_service_calls(self.async_on_service_call) diff --git a/homeassistant/components/esphome/voice_assistant.py b/homeassistant/components/esphome/voice_assistant.py deleted file mode 100644 index eb55be2ced6..00000000000 --- a/homeassistant/components/esphome/voice_assistant.py +++ /dev/null @@ -1,479 +0,0 @@ -"""ESPHome voice assistant support.""" - -from __future__ import annotations - -import asyncio -from collections.abc import AsyncIterable, Callable -import io -import logging -import socket -from typing import cast -import wave - -from aioesphomeapi import ( - APIClient, - VoiceAssistantAudioSettings, - VoiceAssistantCommandFlag, - VoiceAssistantEventType, - VoiceAssistantFeature, - VoiceAssistantTimerEventType, -) - -from homeassistant.components import stt, tts -from homeassistant.components.assist_pipeline import ( - AudioSettings, - PipelineEvent, - PipelineEventType, - PipelineNotFound, - PipelineStage, - WakeWordSettings, - async_pipeline_from_audio_stream, - select as pipeline_select, -) -from homeassistant.components.assist_pipeline.error import ( - WakeWordDetectionAborted, - WakeWordDetectionError, -) -from homeassistant.components.assist_pipeline.vad import VadSensitivity -from homeassistant.components.intent.timers import TimerEventType, TimerInfo -from homeassistant.components.media_player import async_process_play_media_url -from homeassistant.core import Context, HomeAssistant, callback - -from .const import DOMAIN -from .entry_data import RuntimeEntryData -from .enum_mapper import EsphomeEnumMapper - -_LOGGER = logging.getLogger(__name__) - -UDP_PORT = 0 # Set to 0 to let the OS pick a free random port -UDP_MAX_PACKET_SIZE = 1024 - -_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[ - VoiceAssistantEventType, PipelineEventType -] = EsphomeEnumMapper( - { - VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR, - VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START, - VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END, - VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START, - VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END, - VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START, - VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END, - VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START, - VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END, - VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START, - VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END, - VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START, - VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END, - } -) - -_TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = ( - EsphomeEnumMapper( - { - VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED, - VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED, - VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED, - VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED, - } - ) -) - - -class VoiceAssistantPipeline: - """Base abstract pipeline class.""" - - started = False - stop_requested = False - - def __init__( - self, - hass: HomeAssistant, - entry_data: RuntimeEntryData, - handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], - handle_finished: Callable[[], None], - ) -> None: - """Initialize the pipeline.""" - self.context = Context() - self.hass = hass - self.entry_data = entry_data - assert entry_data.device_info is not None - self.device_info = entry_data.device_info - - self.queue: asyncio.Queue[bytes] = asyncio.Queue() - self.handle_event = handle_event - self.handle_finished = handle_finished - self._tts_done = asyncio.Event() - self._tts_task: asyncio.Task | None = None - - @property - def is_running(self) -> bool: - """True if the pipeline is started and hasn't been asked to stop.""" - return self.started and (not self.stop_requested) - - async def _iterate_packets(self) -> AsyncIterable[bytes]: - """Iterate over incoming packets.""" - while data := await self.queue.get(): - if not self.is_running: - break - - yield data - - def _event_callback(self, event: PipelineEvent) -> None: - """Handle pipeline events.""" - - try: - event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type) - except KeyError: - _LOGGER.debug("Received unknown pipeline event type: %s", event.type) - return - - data_to_send = None - error = False - if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: - self.entry_data.async_set_assist_pipeline_state(True) - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: - assert event.data is not None - data_to_send = {"text": event.data["stt_output"]["text"]} - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: - assert event.data is not None - data_to_send = { - "conversation_id": event.data["intent_output"]["conversation_id"] or "", - } - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: - assert event.data is not None - data_to_send = {"text": event.data["tts_input"]} - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: - assert event.data is not None - tts_output = event.data["tts_output"] - if tts_output: - path = tts_output["url"] - url = async_process_play_media_url(self.hass, path) - data_to_send = {"url": url} - - if ( - self.device_info.voice_assistant_feature_flags_compat( - self.entry_data.api_version - ) - & VoiceAssistantFeature.SPEAKER - ): - media_id = tts_output["media_id"] - self._tts_task = self.hass.async_create_background_task( - self._send_tts(media_id), "esphome_voice_assistant_tts" - ) - else: - self._tts_done.set() - else: - # Empty TTS response - data_to_send = {} - self._tts_done.set() - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: - assert event.data is not None - if not event.data["wake_word_output"]: - event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR - data_to_send = { - "code": "no_wake_word", - "message": "No wake word detected", - } - error = True - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: - assert event.data is not None - data_to_send = { - "code": event.data["code"], - "message": event.data["message"], - } - error = True - - self.handle_event(event_type, data_to_send) - if error: - self._tts_done.set() - self.handle_finished() - - async def run_pipeline( - self, - device_id: str, - conversation_id: str | None, - flags: int = 0, - audio_settings: VoiceAssistantAudioSettings | None = None, - wake_word_phrase: str | None = None, - ) -> None: - """Run the Voice Assistant pipeline.""" - if audio_settings is None or audio_settings.volume_multiplier == 0: - audio_settings = VoiceAssistantAudioSettings() - - if ( - self.device_info.voice_assistant_feature_flags_compat( - self.entry_data.api_version - ) - & VoiceAssistantFeature.SPEAKER - ): - tts_audio_output = "wav" - else: - tts_audio_output = "mp3" - - _LOGGER.debug("Starting pipeline") - if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD: - start_stage = PipelineStage.WAKE_WORD - else: - start_stage = PipelineStage.STT - try: - await async_pipeline_from_audio_stream( - self.hass, - context=self.context, - event_callback=self._event_callback, - stt_metadata=stt.SpeechMetadata( - language="", # set in async_pipeline_from_audio_stream - format=stt.AudioFormats.WAV, - codec=stt.AudioCodecs.PCM, - bit_rate=stt.AudioBitRates.BITRATE_16, - sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, - channel=stt.AudioChannels.CHANNEL_MONO, - ), - stt_stream=self._iterate_packets(), - pipeline_id=pipeline_select.get_chosen_pipeline( - self.hass, DOMAIN, self.device_info.mac_address - ), - conversation_id=conversation_id, - device_id=device_id, - tts_audio_output=tts_audio_output, - start_stage=start_stage, - wake_word_settings=WakeWordSettings(timeout=5), - wake_word_phrase=wake_word_phrase, - audio_settings=AudioSettings( - noise_suppression_level=audio_settings.noise_suppression_level, - auto_gain_dbfs=audio_settings.auto_gain, - volume_multiplier=audio_settings.volume_multiplier, - is_vad_enabled=bool(flags & VoiceAssistantCommandFlag.USE_VAD), - silence_seconds=VadSensitivity.to_seconds( - pipeline_select.get_vad_sensitivity( - self.hass, DOMAIN, self.device_info.mac_address - ) - ), - ), - ) - - # Block until TTS is done sending - await self._tts_done.wait() - - _LOGGER.debug("Pipeline finished") - except PipelineNotFound as e: - self.handle_event( - VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, - { - "code": e.code, - "message": e.message, - }, - ) - _LOGGER.warning("Pipeline not found") - except WakeWordDetectionAborted: - pass # Wake word detection was aborted and `handle_finished` is enough. - except WakeWordDetectionError as e: - self.handle_event( - VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, - { - "code": e.code, - "message": e.message, - }, - ) - finally: - self.handle_finished() - - async def _send_tts(self, media_id: str) -> None: - """Send TTS audio to device via UDP.""" - # Always send stream start/end events - self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}) - - try: - if not self.is_running: - return - - extension, data = await tts.async_get_media_source_audio( - self.hass, - media_id, - ) - - if extension != "wav": - raise ValueError(f"Only WAV audio can be streamed, got {extension}") - - with io.BytesIO(data) as wav_io: - with wave.open(wav_io, "rb") as wav_file: - sample_rate = wav_file.getframerate() - sample_width = wav_file.getsampwidth() - sample_channels = wav_file.getnchannels() - - if ( - (sample_rate != 16000) - or (sample_width != 2) - or (sample_channels != 1) - ): - raise ValueError( - "Expected rate/width/channels as 16000/2/1," - " got {sample_rate}/{sample_width}/{sample_channels}}" - ) - - audio_bytes = wav_file.readframes(wav_file.getnframes()) - - audio_bytes_size = len(audio_bytes) - - _LOGGER.debug("Sending %d bytes of audio", audio_bytes_size) - - bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8 - sample_offset = 0 - samples_left = audio_bytes_size // bytes_per_sample - - while (samples_left > 0) and self.is_running: - bytes_offset = sample_offset * bytes_per_sample - chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024] - samples_in_chunk = len(chunk) // bytes_per_sample - samples_left -= samples_in_chunk - - self.send_audio_bytes(chunk) - await asyncio.sleep( - samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9 - ) - - sample_offset += samples_in_chunk - finally: - self.handle_event( - VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {} - ) - self._tts_task = None - self._tts_done.set() - - def send_audio_bytes(self, data: bytes) -> None: - """Send bytes to the device.""" - raise NotImplementedError - - def stop(self) -> None: - """Stop the pipeline.""" - self.queue.put_nowait(b"") - - -class VoiceAssistantUDPPipeline(asyncio.DatagramProtocol, VoiceAssistantPipeline): - """Receive UDP packets and forward them to the voice assistant.""" - - transport: asyncio.DatagramTransport | None = None - remote_addr: tuple[str, int] | None = None - - async def start_server(self) -> int: - """Start accepting connections.""" - - def accept_connection() -> VoiceAssistantUDPPipeline: - """Accept connection.""" - if self.started: - raise RuntimeError("Can only start once") - if self.stop_requested: - raise RuntimeError("No longer accepting connections") - - self.started = True - return self - - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.setblocking(False) - - sock.bind(("", UDP_PORT)) - - await asyncio.get_running_loop().create_datagram_endpoint( - accept_connection, sock=sock - ) - - return cast(int, sock.getsockname()[1]) - - @callback - def connection_made(self, transport: asyncio.BaseTransport) -> None: - """Store transport for later use.""" - self.transport = cast(asyncio.DatagramTransport, transport) - - @callback - def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: - """Handle incoming UDP packet.""" - if not self.is_running: - return - if self.remote_addr is None: - self.remote_addr = addr - self.queue.put_nowait(data) - - def error_received(self, exc: Exception) -> None: - """Handle when a send or receive operation raises an OSError. - - (Other than BlockingIOError or InterruptedError.) - """ - _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc) - self.handle_finished() - - @callback - def stop(self) -> None: - """Stop the receiver.""" - super().stop() - self.close() - - def close(self) -> None: - """Close the receiver.""" - self.started = False - self.stop_requested = True - - if self.transport is not None: - self.transport.close() - - def send_audio_bytes(self, data: bytes) -> None: - """Send bytes to the device via UDP.""" - if self.transport is None: - _LOGGER.error("No transport to send audio to") - return - self.transport.sendto(data, self.remote_addr) - - -class VoiceAssistantAPIPipeline(VoiceAssistantPipeline): - """Send audio to the voice assistant via the API.""" - - def __init__( - self, - hass: HomeAssistant, - entry_data: RuntimeEntryData, - handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], - handle_finished: Callable[[], None], - api_client: APIClient, - ) -> None: - """Initialize the pipeline.""" - super().__init__(hass, entry_data, handle_event, handle_finished) - self.api_client = api_client - self.started = True - - def send_audio_bytes(self, data: bytes) -> None: - """Send bytes to the device via the API.""" - self.api_client.send_voice_assistant_audio(data) - - @callback - def receive_audio_bytes(self, data: bytes) -> None: - """Receive audio bytes from the device.""" - if not self.is_running: - return - self.queue.put_nowait(data) - - @callback - def stop(self) -> None: - """Stop the pipeline.""" - super().stop() - - self.started = False - self.stop_requested = True - - -def handle_timer_event( - api_client: APIClient, event_type: TimerEventType, timer_info: TimerInfo -) -> None: - """Handle timer events.""" - try: - native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type) - except KeyError: - _LOGGER.debug("Received unknown timer event type: %s", event_type) - return - - api_client.send_voice_assistant_timer_event( - native_event_type, - timer_info.id, - timer_info.name, - timer_info.created_seconds, - timer_info.seconds_left, - timer_info.is_active, - ) diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index b3966875a31..af68df89360 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -20,7 +20,6 @@ from aioesphomeapi import ( ReconnectLogic, UserService, VoiceAssistantAudioSettings, - VoiceAssistantEventType, VoiceAssistantFeature, ) import pytest @@ -34,11 +33,6 @@ from homeassistant.components.esphome.const import ( DEFAULT_NEW_CONFIG_ALLOW_ALLOW_SERVICE_CALLS, DOMAIN, ) -from homeassistant.components.esphome.entry_data import RuntimeEntryData -from homeassistant.components.esphome.voice_assistant import ( - VoiceAssistantAPIPipeline, - VoiceAssistantUDPPipeline, -) from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component @@ -625,57 +619,3 @@ async def mock_esphome_device( ) return _mock_device - - -@pytest.fixture -def mock_voice_assistant_api_pipeline() -> VoiceAssistantAPIPipeline: - """Return the API Pipeline factory.""" - mock_pipeline = Mock(spec=VoiceAssistantAPIPipeline) - - def mock_constructor( - hass: HomeAssistant, - entry_data: RuntimeEntryData, - handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], - handle_finished: Callable[[], None], - api_client: APIClient, - ): - """Fake the constructor.""" - mock_pipeline.hass = hass - mock_pipeline.entry_data = entry_data - mock_pipeline.handle_event = handle_event - mock_pipeline.handle_finished = handle_finished - mock_pipeline.api_client = api_client - return mock_pipeline - - mock_pipeline.side_effect = mock_constructor - with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantAPIPipeline", - new=mock_pipeline, - ): - yield mock_pipeline - - -@pytest.fixture -def mock_voice_assistant_udp_pipeline() -> VoiceAssistantUDPPipeline: - """Return the API Pipeline factory.""" - mock_pipeline = Mock(spec=VoiceAssistantUDPPipeline) - - def mock_constructor( - hass: HomeAssistant, - entry_data: RuntimeEntryData, - handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], - handle_finished: Callable[[], None], - ): - """Fake the constructor.""" - mock_pipeline.hass = hass - mock_pipeline.entry_data = entry_data - mock_pipeline.handle_event = handle_event - mock_pipeline.handle_finished = handle_finished - return mock_pipeline - - mock_pipeline.side_effect = mock_constructor - with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPPipeline", - new=mock_pipeline, - ): - yield mock_pipeline diff --git a/tests/components/esphome/test_assist_satellite.py b/tests/components/esphome/test_assist_satellite.py new file mode 100644 index 00000000000..f024ca3b078 --- /dev/null +++ b/tests/components/esphome/test_assist_satellite.py @@ -0,0 +1,822 @@ +"""Test ESPHome voice assistant server.""" + +import asyncio +from collections.abc import Awaitable, Callable +import io +import socket +from unittest.mock import ANY, Mock, patch +import wave + +from aioesphomeapi import ( + APIClient, + EntityInfo, + EntityState, + UserService, + VoiceAssistantAudioSettings, + VoiceAssistantCommandFlag, + VoiceAssistantEventType, + VoiceAssistantFeature, + VoiceAssistantTimerEventType, +) +import pytest + +from homeassistant.components import assist_satellite +from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType +from homeassistant.components.assist_satellite.entity import ( + AssistSatelliteEntity, + AssistSatelliteState, +) +from homeassistant.components.esphome import DOMAIN +from homeassistant.components.esphome.assist_satellite import ( + EsphomeAssistSatellite, + VoiceAssistantUDPServer, +) +from homeassistant.const import Platform +from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er, intent as intent_helper +import homeassistant.helpers.device_registry as dr +from homeassistant.helpers.entity_component import EntityComponent + +from .conftest import MockESPHomeDevice + + +def get_satellite_entity( + hass: HomeAssistant, mac_address: str +) -> EsphomeAssistSatellite | None: + """Get the satellite entity for a device.""" + ent_reg = er.async_get(hass) + satellite_entity_id = ent_reg.async_get_entity_id( + Platform.ASSIST_SATELLITE, DOMAIN, f"{mac_address}-assist_satellite" + ) + if satellite_entity_id is None: + return None + + component: EntityComponent[AssistSatelliteEntity] = hass.data[ + assist_satellite.DOMAIN + ] + if (entity := component.get_entity(satellite_entity_id)) is not None: + assert isinstance(entity, EsphomeAssistSatellite) + return entity + + return None + + +@pytest.fixture +def mock_wav() -> bytes: + """Return test WAV audio.""" + with io.BytesIO() as wav_io: + with wave.open(wav_io, "wb") as wav_file: + wav_file.setframerate(16000) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + wav_file.writeframes(b"test-wav") + + return wav_io.getvalue() + + +async def test_no_satellite_without_voice_assistant( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test that an assist satellite entity is not created if a voice assistant is not present.""" + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={}, + ) + await hass.async_block_till_done() + + # No satellite entity should be created + assert get_satellite_entity(hass, mock_device.device_info.mac_address) is None + + +async def test_pipeline_api_audio( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], + mock_wav: bytes, +) -> None: + """Test a complete pipeline run with API audio (over the TCP connection).""" + conversation_id = "test-conversation-id" + media_url = "http://test.url" + media_id = "test-media-id" + + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.SPEAKER + | VoiceAssistantFeature.API_AUDIO + }, + ) + await hass.async_block_till_done() + dev = device_registry.async_get_device( + connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)} + ) + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + + # Block TTS streaming until we're ready. + # This makes it easier to verify the order of pipeline events. + stream_tts_audio_ready = asyncio.Event() + original_stream_tts_audio = satellite._stream_tts_audio + + async def _stream_tts_audio(*args, **kwargs): + await stream_tts_audio_ready.wait() + await original_stream_tts_audio(*args, **kwargs) + + async def async_pipeline_from_audio_stream(*args, device_id, **kwargs): + assert device_id == dev.id + + stt_stream = kwargs["stt_stream"] + + chunks = [chunk async for chunk in stt_stream] + + # Verify test API audio + assert chunks == [b"test-mic"] + + event_callback = kwargs["event_callback"] + + # Test unknown event type + event_callback( + PipelineEvent( + type="unknown-event", + data={}, + ) + ) + + mock_client.send_voice_assistant_event.assert_not_called() + + # Test error event + event_callback( + PipelineEvent( + type=PipelineEventType.ERROR, + data={"code": "test-error-code", "message": "test-error-message"}, + ) + ) + + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, + {"code": "test-error-code", "message": "test-error-message"}, + ) + + # Wake word + assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD + + event_callback( + PipelineEvent( + type=PipelineEventType.WAKE_WORD_START, + data={ + "entity_id": "test-wake-word-entity-id", + "metadata": {}, + "timeout": 0, + }, + ) + ) + + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START, + {}, + ) + + # Test no wake word detected + event_callback( + PipelineEvent( + type=PipelineEventType.WAKE_WORD_END, data={"wake_word_output": {}} + ) + ) + + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, + {"code": "no_wake_word", "message": "No wake word detected"}, + ) + + # Correct wake word detection + event_callback( + PipelineEvent( + type=PipelineEventType.WAKE_WORD_END, + data={"wake_word_output": {"wake_word_phrase": "test-wake-word"}}, + ) + ) + + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END, + {}, + ) + + # STT + event_callback( + PipelineEvent( + type=PipelineEventType.STT_START, + data={"engine": "test-stt-engine", "metadata": {}}, + ) + ) + + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, + {}, + ) + assert satellite.state == AssistSatelliteState.LISTENING_COMMAND + + event_callback( + PipelineEvent( + type=PipelineEventType.STT_END, + data={"stt_output": {"text": "test-stt-text"}}, + ) + ) + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, + {"text": "test-stt-text"}, + ) + + # Intent + event_callback( + PipelineEvent( + type=PipelineEventType.INTENT_START, + data={ + "engine": "test-intent-engine", + "language": hass.config.language, + "intent_input": "test-intent-text", + "conversation_id": conversation_id, + "device_id": device_id, + }, + ) + ) + + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, + {}, + ) + assert satellite.state == AssistSatelliteState.PROCESSING + + event_callback( + PipelineEvent( + type=PipelineEventType.INTENT_END, + data={"intent_output": {"conversation_id": conversation_id}}, + ) + ) + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, + {"conversation_id": conversation_id}, + ) + + # TTS + event_callback( + PipelineEvent( + type=PipelineEventType.TTS_START, + data={ + "engine": "test-stt-engine", + "language": hass.config.language, + "voice": "test-voice", + "tts_input": "test-tts-text", + }, + ) + ) + + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, + {"text": "test-tts-text"}, + ) + assert satellite.state == AssistSatelliteState.RESPONDING + + # Should return mock_wav audio + event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={"tts_output": {"url": media_url, "media_id": media_id}}, + ) + ) + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END, + {"url": media_url}, + ) + + event_callback(PipelineEvent(type=PipelineEventType.RUN_END)) + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END, + {}, + ) + + # Allow TTS streaming to proceed + stream_tts_audio_ready.set() + + pipeline_finished = asyncio.Event() + original_handle_pipeline_finished = satellite.handle_pipeline_finished + + def handle_pipeline_finished(): + original_handle_pipeline_finished() + pipeline_finished.set() + + async def async_get_media_source_audio( + hass: HomeAssistant, + media_source_id: str, + ) -> tuple[str, bytes]: + return ("wav", mock_wav) + + tts_finished = asyncio.Event() + original_tts_response_finished = satellite.tts_response_finished + + def tts_response_finished(): + original_tts_response_finished() + tts_finished.set() + + with ( + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + patch( + "homeassistant.components.tts.async_get_media_source_audio", + new=async_get_media_source_audio, + ), + patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished), + patch.object(satellite, "_stream_tts_audio", _stream_tts_audio), + patch.object(satellite, "tts_response_finished", tts_response_finished), + ): + # Should be cleared at pipeline start + satellite._audio_queue.put_nowait(b"leftover-data") + + # Should be cancelled at pipeline start + mock_tts_streaming_task = Mock() + satellite._tts_streaming_task = mock_tts_streaming_task + + async with asyncio.timeout(1): + await satellite.handle_pipeline_start( + conversation_id=conversation_id, + flags=VoiceAssistantCommandFlag.USE_WAKE_WORD, + audio_settings=VoiceAssistantAudioSettings(), + wake_word_phrase="", + ) + mock_tts_streaming_task.cancel.assert_called_once() + await satellite.handle_audio(b"test-mic") + await satellite.handle_pipeline_stop() + await pipeline_finished.wait() + + await tts_finished.wait() + + # Verify TTS streaming events. + # These are definitely the last two events because we blocked TTS streaming + # until after RUN_END above. + assert mock_client.send_voice_assistant_event.call_args_list[-2].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, + {}, + ) + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, + {}, + ) + + # Verify TTS WAV audio chunk came through + mock_client.send_voice_assistant_audio.assert_called_once_with(b"test-wav") + + +@pytest.mark.usefixtures("socket_enabled") +async def test_pipeline_udp_audio( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], + mock_wav: bytes, +) -> None: + """Test a complete pipeline run with legacy UDP audio. + + This test is not as comprehensive as test_pipeline_api_audio since we're + mainly focused on the UDP server. + """ + conversation_id = "test-conversation-id" + media_url = "http://test.url" + media_id = "test-media-id" + + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.SPEAKER + }, + ) + await hass.async_block_till_done() + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + + mic_audio_event = asyncio.Event() + + async def async_pipeline_from_audio_stream(*args, device_id, **kwargs): + stt_stream = kwargs["stt_stream"] + + chunks = [] + async for chunk in stt_stream: + chunks.append(chunk) + mic_audio_event.set() + + # Verify test UDP audio + assert chunks == [b"test-mic"] + + event_callback = kwargs["event_callback"] + + # STT + event_callback( + PipelineEvent( + type=PipelineEventType.STT_START, + data={"engine": "test-stt-engine", "metadata": {}}, + ) + ) + + event_callback( + PipelineEvent( + type=PipelineEventType.STT_END, + data={"stt_output": {"text": "test-stt-text"}}, + ) + ) + + # Intent + event_callback( + PipelineEvent( + type=PipelineEventType.INTENT_START, + data={ + "engine": "test-intent-engine", + "language": hass.config.language, + "intent_input": "test-intent-text", + "conversation_id": conversation_id, + "device_id": device_id, + }, + ) + ) + + event_callback( + PipelineEvent( + type=PipelineEventType.INTENT_END, + data={"intent_output": {"conversation_id": conversation_id}}, + ) + ) + + # TTS + event_callback( + PipelineEvent( + type=PipelineEventType.TTS_START, + data={ + "engine": "test-stt-engine", + "language": hass.config.language, + "voice": "test-voice", + "tts_input": "test-tts-text", + }, + ) + ) + + # Should return mock_wav audio + event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={"tts_output": {"url": media_url, "media_id": media_id}}, + ) + ) + + event_callback(PipelineEvent(type=PipelineEventType.RUN_END)) + + pipeline_finished = asyncio.Event() + original_handle_pipeline_finished = satellite.handle_pipeline_finished + + def handle_pipeline_finished(): + original_handle_pipeline_finished() + pipeline_finished.set() + + async def async_get_media_source_audio( + hass: HomeAssistant, + media_source_id: str, + ) -> tuple[str, bytes]: + return ("wav", mock_wav) + + tts_finished = asyncio.Event() + original_tts_response_finished = satellite.tts_response_finished + + def tts_response_finished(): + original_tts_response_finished() + tts_finished.set() + + class TestProtocol(asyncio.DatagramProtocol): + def __init__(self) -> None: + self.transport = None + self.data_received: list[bytes] = [] + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data: bytes, addr): + self.data_received.append(data) + + with ( + patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ), + patch( + "homeassistant.components.tts.async_get_media_source_audio", + new=async_get_media_source_audio, + ), + patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished), + patch.object(satellite, "tts_response_finished", tts_response_finished), + ): + async with asyncio.timeout(1): + port = await satellite.handle_pipeline_start( + conversation_id=conversation_id, + flags=VoiceAssistantCommandFlag(0), # stt + audio_settings=VoiceAssistantAudioSettings(), + wake_word_phrase="", + ) + assert (port is not None) and (port > 0) + + ( + transport, + protocol, + ) = await asyncio.get_running_loop().create_datagram_endpoint( + TestProtocol, remote_addr=("127.0.0.1", port) + ) + assert isinstance(protocol, TestProtocol) + + # Send audio over UDP + transport.sendto(b"test-mic") + + # Wait for audio chunk to be delivered + await mic_audio_event.wait() + + await satellite.handle_pipeline_stop() + await pipeline_finished.wait() + + await tts_finished.wait() + + # Verify TTS audio (from UDP) + assert protocol.data_received == [b"test-wav"] + + # Check that UDP server was stopped + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setblocking(False) + sock.bind(("", port)) # will fail if UDP server is still running + sock.close() + + +async def test_udp_errors() -> None: + """Test UDP protocol error conditions.""" + audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() + protocol = VoiceAssistantUDPServer(audio_queue) + + protocol.datagram_received(b"test", ("", 0)) + assert audio_queue.qsize() == 1 + assert (await audio_queue.get()) == b"test" + + # None will stop the pipeline + protocol.error_received(RuntimeError()) + assert audio_queue.qsize() == 1 + assert (await audio_queue.get()) is None + + # No transport + assert protocol.transport is None + protocol.send_audio_bytes(b"test") + + # No remote address + protocol.transport = Mock() + protocol.remote_addr = None + protocol.send_audio_bytes(b"test") + protocol.transport.sendto.assert_not_called() + + +async def test_timer_events( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test that injecting timer events results in the correct api client calls.""" + + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.TIMERS + }, + ) + await hass.async_block_till_done() + dev = device_registry.async_get_device( + connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)} + ) + + total_seconds = (1 * 60 * 60) + (2 * 60) + 3 + await intent_helper.async_handle( + hass, + "test", + intent_helper.INTENT_START_TIMER, + { + "name": {"value": "test timer"}, + "hours": {"value": 1}, + "minutes": {"value": 2}, + "seconds": {"value": 3}, + }, + device_id=dev.id, + ) + + mock_client.send_voice_assistant_timer_event.assert_called_with( + VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED, + ANY, + "test timer", + total_seconds, + total_seconds, + True, + ) + + # Increase timer beyond original time and check total_seconds has increased + mock_client.send_voice_assistant_timer_event.reset_mock() + + total_seconds += 5 * 60 + await intent_helper.async_handle( + hass, + "test", + intent_helper.INTENT_INCREASE_TIMER, + { + "name": {"value": "test timer"}, + "minutes": {"value": 5}, + }, + device_id=dev.id, + ) + + mock_client.send_voice_assistant_timer_event.assert_called_with( + VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED, + ANY, + "test timer", + total_seconds, + ANY, + True, + ) + + +async def test_unknown_timer_event( + hass: HomeAssistant, + device_registry: dr.DeviceRegistry, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test that unknown (new) timer event types do not result in api calls.""" + + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.TIMERS + }, + ) + await hass.async_block_till_done() + assert mock_device.entry.unique_id is not None + dev = device_registry.async_get_device( + connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)} + ) + assert dev is not None + + with patch( + "homeassistant.components.esphome.assist_satellite._TIMER_EVENT_TYPES.from_hass", + side_effect=KeyError, + ): + await intent_helper.async_handle( + hass, + "test", + intent_helper.INTENT_START_TIMER, + { + "name": {"value": "test timer"}, + "hours": {"value": 1}, + "minutes": {"value": 2}, + "seconds": {"value": 3}, + }, + device_id=dev.id, + ) + + mock_client.send_voice_assistant_timer_event.assert_not_called() + + +async def test_streaming_tts_errors( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], + mock_wav: bytes, +) -> None: + """Test error conditions for _stream_tts_audio function.""" + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + }, + ) + await hass.async_block_till_done() + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + + # Should not stream if not running + satellite._is_running = False + await satellite._stream_tts_audio("test-media-id") + mock_client.send_voice_assistant_audio.assert_not_called() + satellite._is_running = True + + # Should only stream WAV + async def get_mp3( + hass: HomeAssistant, + media_source_id: str, + ) -> tuple[str, bytes]: + return ("mp3", b"") + + with patch( + "homeassistant.components.tts.async_get_media_source_audio", new=get_mp3 + ): + await satellite._stream_tts_audio("test-media-id") + mock_client.send_voice_assistant_audio.assert_not_called() + + # Needs to be the correct sample rate, etc. + async def get_bad_wav( + hass: HomeAssistant, + media_source_id: str, + ) -> tuple[str, bytes]: + with io.BytesIO() as wav_io: + with wave.open(wav_io, "wb") as wav_file: + wav_file.setframerate(48000) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + wav_file.writeframes(b"test-wav") + + return ("wav", wav_io.getvalue()) + + with patch( + "homeassistant.components.tts.async_get_media_source_audio", new=get_bad_wav + ): + await satellite._stream_tts_audio("test-media-id") + mock_client.send_voice_assistant_audio.assert_not_called() + + # Check that TTS_STREAM_* events still get sent after cancel + media_fetched = asyncio.Event() + + async def get_slow_wav( + hass: HomeAssistant, + media_source_id: str, + ) -> tuple[str, bytes]: + media_fetched.set() + await asyncio.sleep(1) + return ("wav", mock_wav) + + mock_client.send_voice_assistant_event.reset_mock() + with patch( + "homeassistant.components.tts.async_get_media_source_audio", new=get_slow_wav + ): + task = asyncio.create_task(satellite._stream_tts_audio("test-media-id")) + async with asyncio.timeout(1): + # Wait for media to be fetched + await media_fetched.wait() + + # Cancel task + task.cancel() + await task + + # No audio should have gone out + mock_client.send_voice_assistant_audio.assert_not_called() + assert len(mock_client.send_voice_assistant_event.call_args_list) == 2 + + # The TTS_STREAM_* events should have gone out + assert mock_client.send_voice_assistant_event.call_args_list[-2].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, + {}, + ) + assert mock_client.send_voice_assistant_event.call_args_list[-1].args == ( + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, + {}, + ) diff --git a/tests/components/esphome/test_manager.py b/tests/components/esphome/test_manager.py index a14c83bf265..4b322c8744e 100644 --- a/tests/components/esphome/test_manager.py +++ b/tests/components/esphome/test_manager.py @@ -2,7 +2,7 @@ import asyncio from collections.abc import Awaitable, Callable -from unittest.mock import AsyncMock, call, patch +from unittest.mock import AsyncMock, call from aioesphomeapi import ( APIClient, @@ -17,7 +17,6 @@ from aioesphomeapi import ( UserService, UserServiceArg, UserServiceArgType, - VoiceAssistantFeature, ) import pytest @@ -29,10 +28,6 @@ from homeassistant.components.esphome.const import ( DOMAIN, STABLE_BLE_VERSION_STR, ) -from homeassistant.components.esphome.voice_assistant import ( - VoiceAssistantAPIPipeline, - VoiceAssistantUDPPipeline, -) from homeassistant.const import ( CONF_HOST, CONF_PASSWORD, @@ -44,7 +39,7 @@ from homeassistant.data_entry_flow import FlowResultType from homeassistant.helpers import device_registry as dr, issue_registry as ir from homeassistant.setup import async_setup_component -from .conftest import _ONE_SECOND, MockESPHomeDevice +from .conftest import MockESPHomeDevice from tests.common import MockConfigEntry, async_capture_events, async_mock_service @@ -1214,102 +1209,3 @@ async def test_entry_missing_unique_id( await mock_esphome_device(mock_client=mock_client, mock_storage=True) await hass.async_block_till_done() assert entry.unique_id == "11:22:33:44:55:aa" - - -async def test_manager_voice_assistant_handlers_api( - hass: HomeAssistant, - mock_client: APIClient, - mock_esphome_device: Callable[ - [APIClient, list[EntityInfo], list[UserService], list[EntityState]], - Awaitable[MockESPHomeDevice], - ], - caplog: pytest.LogCaptureFixture, - mock_voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test the handlers are correctly executed in manager.py.""" - - device: MockESPHomeDevice = await mock_esphome_device( - mock_client=mock_client, - entity_info=[], - user_service=[], - states=[], - device_info={ - "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT - | VoiceAssistantFeature.API_AUDIO - }, - ) - - await hass.async_block_till_done() - - with ( - patch( - "homeassistant.components.esphome.manager.VoiceAssistantAPIPipeline", - new=mock_voice_assistant_api_pipeline, - ), - ): - port: int | None = await device.mock_voice_assistant_handle_start( - "", 0, None, None - ) - - assert port == 0 - - port: int | None = await device.mock_voice_assistant_handle_start( - "", 0, None, None - ) - - assert "Previous Voice assistant pipeline was not stopped" in caplog.text - - await device.mock_voice_assistant_handle_audio(bytes(_ONE_SECOND)) - - mock_voice_assistant_api_pipeline.receive_audio_bytes.assert_called_with( - bytes(_ONE_SECOND) - ) - - mock_voice_assistant_api_pipeline.receive_audio_bytes.reset_mock() - - await device.mock_voice_assistant_handle_stop() - mock_voice_assistant_api_pipeline.handle_finished() - - await device.mock_voice_assistant_handle_audio(bytes(_ONE_SECOND)) - - mock_voice_assistant_api_pipeline.receive_audio_bytes.assert_not_called() - - -async def test_manager_voice_assistant_handlers_udp( - hass: HomeAssistant, - mock_client: APIClient, - mock_esphome_device: Callable[ - [APIClient, list[EntityInfo], list[UserService], list[EntityState]], - Awaitable[MockESPHomeDevice], - ], - mock_voice_assistant_udp_pipeline: VoiceAssistantUDPPipeline, -) -> None: - """Test the handlers are correctly executed in manager.py.""" - - device: MockESPHomeDevice = await mock_esphome_device( - mock_client=mock_client, - entity_info=[], - user_service=[], - states=[], - device_info={ - "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT - }, - ) - - await hass.async_block_till_done() - - with ( - patch( - "homeassistant.components.esphome.manager.VoiceAssistantUDPPipeline", - new=mock_voice_assistant_udp_pipeline, - ), - ): - await device.mock_voice_assistant_handle_start("", 0, None, None) - - mock_voice_assistant_udp_pipeline.run_pipeline.assert_called() - - await device.mock_voice_assistant_handle_stop() - mock_voice_assistant_udp_pipeline.handle_finished() - - mock_voice_assistant_udp_pipeline.stop.assert_called() - mock_voice_assistant_udp_pipeline.close.assert_called() diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py deleted file mode 100644 index eafc0243dc6..00000000000 --- a/tests/components/esphome/test_voice_assistant.py +++ /dev/null @@ -1,964 +0,0 @@ -"""Test ESPHome voice assistant server.""" - -import asyncio -from collections.abc import Awaitable, Callable -import io -import socket -from unittest.mock import ANY, Mock, patch -import wave - -from aioesphomeapi import ( - APIClient, - EntityInfo, - EntityState, - UserService, - VoiceAssistantEventType, - VoiceAssistantFeature, - VoiceAssistantTimerEventType, -) -import pytest - -from homeassistant.components.assist_pipeline import ( - PipelineEvent, - PipelineEventType, - PipelineStage, -) -from homeassistant.components.assist_pipeline.error import ( - PipelineNotFound, - WakeWordDetectionAborted, - WakeWordDetectionError, -) -from homeassistant.components.esphome import DomainData -from homeassistant.components.esphome.voice_assistant import ( - VoiceAssistantAPIPipeline, - VoiceAssistantUDPPipeline, -) -from homeassistant.core import HomeAssistant -from homeassistant.helpers import intent as intent_helper -import homeassistant.helpers.device_registry as dr - -from .conftest import _ONE_SECOND, MockESPHomeDevice - -_TEST_INPUT_TEXT = "This is an input test" -_TEST_OUTPUT_TEXT = "This is an output test" -_TEST_OUTPUT_URL = "output.mp3" -_TEST_MEDIA_ID = "12345" - - -@pytest.fixture -def voice_assistant_udp_pipeline( - hass: HomeAssistant, -) -> VoiceAssistantUDPPipeline: - """Return the UDP pipeline factory.""" - - def _voice_assistant_udp_server(entry): - entry_data = DomainData.get(hass).get_entry_data(entry) - - server: VoiceAssistantUDPPipeline = None - - def handle_finished(): - nonlocal server - assert server is not None - server.close() - - server = VoiceAssistantUDPPipeline(hass, entry_data, Mock(), handle_finished) - return server # noqa: RET504 - - return _voice_assistant_udp_server - - -@pytest.fixture -def voice_assistant_api_pipeline( - hass: HomeAssistant, - mock_client, - mock_voice_assistant_api_entry, -) -> VoiceAssistantAPIPipeline: - """Return the API Pipeline factory.""" - entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_api_entry) - return VoiceAssistantAPIPipeline(hass, entry_data, Mock(), Mock(), mock_client) - - -@pytest.fixture -def voice_assistant_udp_pipeline_v1( - voice_assistant_udp_pipeline, - mock_voice_assistant_v1_entry, -) -> VoiceAssistantUDPPipeline: - """Return the UDP pipeline.""" - return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v1_entry) - - -@pytest.fixture -def voice_assistant_udp_pipeline_v2( - voice_assistant_udp_pipeline, - mock_voice_assistant_v2_entry, -) -> VoiceAssistantUDPPipeline: - """Return the UDP pipeline.""" - return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v2_entry) - - -@pytest.fixture -def mock_wav() -> bytes: - """Return one second of empty WAV audio.""" - with io.BytesIO() as wav_io: - with wave.open(wav_io, "wb") as wav_file: - wav_file.setframerate(16000) - wav_file.setsampwidth(2) - wav_file.setnchannels(1) - wav_file.writeframes(bytes(_ONE_SECOND)) - - return wav_io.getvalue() - - -async def test_pipeline_events( - hass: HomeAssistant, - voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, -) -> None: - """Test that the pipeline function is called.""" - - async def async_pipeline_from_audio_stream(*args, device_id, **kwargs): - assert device_id == "mock-device-id" - - event_callback = kwargs["event_callback"] - - event_callback( - PipelineEvent( - type=PipelineEventType.WAKE_WORD_END, - data={"wake_word_output": {}}, - ) - ) - - # Fake events - event_callback( - PipelineEvent( - type=PipelineEventType.STT_START, - data={}, - ) - ) - - event_callback( - PipelineEvent( - type=PipelineEventType.STT_END, - data={"stt_output": {"text": _TEST_INPUT_TEXT}}, - ) - ) - - event_callback( - PipelineEvent( - type=PipelineEventType.TTS_START, - data={"tts_input": _TEST_OUTPUT_TEXT}, - ) - ) - - event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={"tts_output": {"url": _TEST_OUTPUT_URL}}, - ) - ) - - def handle_event( - event_type: VoiceAssistantEventType, data: dict[str, str] | None - ) -> None: - if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: - assert data is not None - assert data["text"] == _TEST_INPUT_TEXT - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: - assert data is not None - assert data["text"] == _TEST_OUTPUT_TEXT - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: - assert data is not None - assert data["url"] == _TEST_OUTPUT_URL - elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: - assert data is None - - voice_assistant_udp_pipeline_v1.handle_event = handle_event - - with patch( - "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ): - voice_assistant_udp_pipeline_v1.transport = Mock() - - await voice_assistant_udp_pipeline_v1.run_pipeline( - device_id="mock-device-id", conversation_id=None - ) - - -@pytest.mark.usefixtures("socket_enabled") -async def test_udp_server( - unused_udp_port_factory: Callable[[], int], - voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, -) -> None: - """Test the UDP server runs and queues incoming data.""" - port_to_use = unused_udp_port_factory() - - with patch( - "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use - ): - port = await voice_assistant_udp_pipeline_v1.start_server() - assert port == port_to_use - - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - - assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0 - sock.sendto(b"test", ("127.0.0.1", port)) - - # Give the socket some time to send/receive the data - async with asyncio.timeout(1): - while voice_assistant_udp_pipeline_v1.queue.qsize() == 0: - await asyncio.sleep(0.1) - - assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 - - voice_assistant_udp_pipeline_v1.stop() - voice_assistant_udp_pipeline_v1.close() - - assert voice_assistant_udp_pipeline_v1.transport.is_closing() - - -async def test_udp_server_queue( - hass: HomeAssistant, - voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, -) -> None: - """Test the UDP server queues incoming data.""" - - voice_assistant_udp_pipeline_v1.started = True - - assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0 - - voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0)) - assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 - - voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0)) - assert voice_assistant_udp_pipeline_v1.queue.qsize() == 2 - - async for data in voice_assistant_udp_pipeline_v1._iterate_packets(): - assert data == bytes(1024) - break - assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 # One message removed - - voice_assistant_udp_pipeline_v1.stop() - assert ( - voice_assistant_udp_pipeline_v1.queue.qsize() == 2 - ) # An empty message added by stop - - voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0)) - assert ( - voice_assistant_udp_pipeline_v1.queue.qsize() == 2 - ) # No new messages added after stop - - voice_assistant_udp_pipeline_v1.close() - - # Stopping the UDP server should cause _iterate_packets to break out - # immediately without yielding any data. - has_data = False - async for _data in voice_assistant_udp_pipeline_v1._iterate_packets(): - has_data = True - - assert not has_data, "Server was stopped" - - -async def test_api_pipeline_queue( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test the API pipeline queues incoming data.""" - - voice_assistant_api_pipeline.started = True - - assert voice_assistant_api_pipeline.queue.qsize() == 0 - - voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024)) - assert voice_assistant_api_pipeline.queue.qsize() == 1 - - voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024)) - assert voice_assistant_api_pipeline.queue.qsize() == 2 - - async for data in voice_assistant_api_pipeline._iterate_packets(): - assert data == bytes(1024) - break - assert voice_assistant_api_pipeline.queue.qsize() == 1 # One message removed - - voice_assistant_api_pipeline.stop() - assert ( - voice_assistant_api_pipeline.queue.qsize() == 2 - ) # An empty message added by stop - - voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024)) - assert ( - voice_assistant_api_pipeline.queue.qsize() == 2 - ) # No new messages added after stop - - # Stopping the API Pipeline should cause _iterate_packets to break out - # immediately without yielding any data. - has_data = False - async for _data in voice_assistant_api_pipeline._iterate_packets(): - has_data = True - - assert not has_data, "Pipeline was stopped" - - -async def test_error_calls_handle_finished( - hass: HomeAssistant, - voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, -) -> None: - """Test that the handle_finished callback is called when an error occurs.""" - voice_assistant_udp_pipeline_v1.handle_finished = Mock() - - voice_assistant_udp_pipeline_v1.error_received(Exception()) - - voice_assistant_udp_pipeline_v1.handle_finished.assert_called() - - -@pytest.mark.usefixtures("socket_enabled") -async def test_udp_server_multiple( - unused_udp_port_factory: Callable[[], int], - voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, -) -> None: - """Test that the UDP server raises an error if started twice.""" - with patch( - "homeassistant.components.esphome.voice_assistant.UDP_PORT", - new=unused_udp_port_factory(), - ): - await voice_assistant_udp_pipeline_v1.start_server() - - with ( - patch( - "homeassistant.components.esphome.voice_assistant.UDP_PORT", - new=unused_udp_port_factory(), - ), - pytest.raises(RuntimeError), - ): - await voice_assistant_udp_pipeline_v1.start_server() - - -@pytest.mark.usefixtures("socket_enabled") -async def test_udp_server_after_stopped( - unused_udp_port_factory: Callable[[], int], - voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, -) -> None: - """Test that the UDP server raises an error if started after stopped.""" - voice_assistant_udp_pipeline_v1.close() - with ( - patch( - "homeassistant.components.esphome.voice_assistant.UDP_PORT", - new=unused_udp_port_factory(), - ), - pytest.raises(RuntimeError), - ): - await voice_assistant_udp_pipeline_v1.start_server() - - -async def test_events_converted_correctly( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test the pipeline events produce the correct data to send to the device.""" - - with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts", - ): - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.STT_START, - data={}, - ) - ) - - voice_assistant_api_pipeline.handle_event.assert_called_with( - VoiceAssistantEventType.VOICE_ASSISTANT_STT_START, None - ) - - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.STT_END, - data={"stt_output": {"text": "text"}}, - ) - ) - - voice_assistant_api_pipeline.handle_event.assert_called_with( - VoiceAssistantEventType.VOICE_ASSISTANT_STT_END, {"text": "text"} - ) - - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.INTENT_START, - data={}, - ) - ) - - voice_assistant_api_pipeline.handle_event.assert_called_with( - VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START, None - ) - - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.INTENT_END, - data={ - "intent_output": { - "conversation_id": "conversation-id", - } - }, - ) - ) - - voice_assistant_api_pipeline.handle_event.assert_called_with( - VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END, - {"conversation_id": "conversation-id"}, - ) - - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_START, - data={"tts_input": "text"}, - ) - ) - - voice_assistant_api_pipeline.handle_event.assert_called_with( - VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START, {"text": "text"} - ) - - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={"tts_output": {"url": "url", "media_id": "media-id"}}, - ) - ) - - voice_assistant_api_pipeline.handle_event.assert_called_with( - VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END, {"url": "url"} - ) - - -async def test_unknown_event_type( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test the API pipeline does not call handle_event for unknown events.""" - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type="unknown-event", - data={}, - ) - ) - - assert not voice_assistant_api_pipeline.handle_event.called - - -async def test_error_event_type( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test the API pipeline calls event handler with error.""" - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.ERROR, - data={"code": "code", "message": "message"}, - ) - ) - - voice_assistant_api_pipeline.handle_event.assert_called_with( - VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, - {"code": "code", "message": "message"}, - ) - - -async def test_send_tts_not_called( - hass: HomeAssistant, - voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, -) -> None: - """Test the UDP server with a v1 device does not call _send_tts.""" - with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts" - ) as mock_send_tts: - voice_assistant_udp_pipeline_v1._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} - }, - ) - ) - - mock_send_tts.assert_not_called() - - -async def test_send_tts_called_udp( - hass: HomeAssistant, - voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, -) -> None: - """Test the UDP server with a v2 device calls _send_tts.""" - with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts" - ) as mock_send_tts: - voice_assistant_udp_pipeline_v2._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} - }, - ) - ) - - mock_send_tts.assert_called_with(_TEST_MEDIA_ID) - - -async def test_send_tts_called_api( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test the API pipeline calls _send_tts.""" - with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts" - ) as mock_send_tts: - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} - }, - ) - ) - - mock_send_tts.assert_called_with(_TEST_MEDIA_ID) - - -async def test_send_tts_not_called_when_empty( - hass: HomeAssistant, - voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, - voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test the pipelines do not call _send_tts when the output is empty.""" - with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts" - ) as mock_send_tts: - voice_assistant_udp_pipeline_v1._event_callback( - PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}}) - ) - - mock_send_tts.assert_not_called() - - voice_assistant_udp_pipeline_v2._event_callback( - PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}}) - ) - - mock_send_tts.assert_not_called() - - voice_assistant_api_pipeline._event_callback( - PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}}) - ) - - mock_send_tts.assert_not_called() - - -async def test_send_tts_udp( - hass: HomeAssistant, - voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, - mock_wav: bytes, -) -> None: - """Test the UDP server calls sendto to transmit audio data to device.""" - with patch( - "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("wav", mock_wav), - ): - voice_assistant_udp_pipeline_v2.started = True - voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport) - with patch.object( - voice_assistant_udp_pipeline_v2.transport, "is_closing", return_value=False - ): - voice_assistant_udp_pipeline_v2._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": { - "media_id": _TEST_MEDIA_ID, - "url": _TEST_OUTPUT_URL, - } - }, - ) - ) - - await voice_assistant_udp_pipeline_v2._tts_done.wait() - - voice_assistant_udp_pipeline_v2.transport.sendto.assert_called() - - -async def test_send_tts_api( - hass: HomeAssistant, - mock_client: APIClient, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, - mock_wav: bytes, -) -> None: - """Test the API pipeline calls cli.send_voice_assistant_audio to transmit audio data to device.""" - with patch( - "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("wav", mock_wav), - ): - voice_assistant_api_pipeline.started = True - - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": { - "media_id": _TEST_MEDIA_ID, - "url": _TEST_OUTPUT_URL, - } - }, - ) - ) - - await voice_assistant_api_pipeline._tts_done.wait() - - mock_client.send_voice_assistant_audio.assert_called() - - -async def test_send_tts_wrong_sample_rate( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test that only 16000Hz audio will be streamed.""" - with io.BytesIO() as wav_io: - with wave.open(wav_io, "wb") as wav_file: - wav_file.setframerate(22050) - wav_file.setsampwidth(2) - wav_file.setnchannels(1) - wav_file.writeframes(bytes(_ONE_SECOND)) - - wav_bytes = wav_io.getvalue() - with patch( - "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("wav", wav_bytes), - ): - voice_assistant_api_pipeline.started = True - voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport) - - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} - }, - ) - ) - - assert voice_assistant_api_pipeline._tts_task is not None - with pytest.raises(ValueError): - await voice_assistant_api_pipeline._tts_task - - -async def test_send_tts_wrong_format( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test that only WAV audio will be streamed.""" - with ( - patch( - "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("raw", bytes(1024)), - ), - ): - voice_assistant_api_pipeline.started = True - voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport) - - voice_assistant_api_pipeline._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} - }, - ) - ) - - assert voice_assistant_api_pipeline._tts_task is not None - with pytest.raises(ValueError): - await voice_assistant_api_pipeline._tts_task - - -async def test_send_tts_not_started( - hass: HomeAssistant, - voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, - mock_wav: bytes, -) -> None: - """Test the UDP server does not call sendto when not started.""" - with patch( - "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("wav", mock_wav), - ): - voice_assistant_udp_pipeline_v2.started = False - voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport) - - voice_assistant_udp_pipeline_v2._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} - }, - ) - ) - - await voice_assistant_udp_pipeline_v2._tts_done.wait() - - voice_assistant_udp_pipeline_v2.transport.sendto.assert_not_called() - - -async def test_send_tts_transport_none( - hass: HomeAssistant, - voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, - mock_wav: bytes, - caplog: pytest.LogCaptureFixture, -) -> None: - """Test the UDP server does not call sendto when transport is None.""" - with patch( - "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", - return_value=("wav", mock_wav), - ): - voice_assistant_udp_pipeline_v2.started = True - voice_assistant_udp_pipeline_v2.transport = None - - voice_assistant_udp_pipeline_v2._event_callback( - PipelineEvent( - type=PipelineEventType.TTS_END, - data={ - "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} - }, - ) - ) - await voice_assistant_udp_pipeline_v2._tts_done.wait() - - assert "No transport to send audio to" in caplog.text - - -async def test_wake_word( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test that the pipeline is set to start with Wake word.""" - - async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs): - assert start_stage == PipelineStage.WAKE_WORD - - with ( - patch( - "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ), - patch("asyncio.Event.wait"), # TTS wait event - ): - await voice_assistant_api_pipeline.run_pipeline( - device_id="mock-device-id", - conversation_id=None, - flags=2, - ) - - -async def test_wake_word_exception( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test that the pipeline is set to start with Wake word.""" - - async def async_pipeline_from_audio_stream(*args, **kwargs): - raise WakeWordDetectionError("pipeline-not-found", "Pipeline not found") - - with patch( - "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ): - - def handle_event( - event_type: VoiceAssistantEventType, data: dict[str, str] | None - ) -> None: - if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: - assert data is not None - assert data["code"] == "pipeline-not-found" - assert data["message"] == "Pipeline not found" - - voice_assistant_api_pipeline.handle_event = handle_event - - await voice_assistant_api_pipeline.run_pipeline( - device_id="mock-device-id", - conversation_id=None, - flags=2, - ) - - -async def test_wake_word_abort_exception( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test that the pipeline is set to start with Wake word.""" - - async def async_pipeline_from_audio_stream(*args, **kwargs): - raise WakeWordDetectionAborted - - with ( - patch( - "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ), - patch.object(voice_assistant_api_pipeline, "handle_event") as mock_handle_event, - ): - await voice_assistant_api_pipeline.run_pipeline( - device_id="mock-device-id", - conversation_id=None, - flags=2, - ) - - mock_handle_event.assert_not_called() - - -async def test_timer_events( - hass: HomeAssistant, - device_registry: dr.DeviceRegistry, - mock_client: APIClient, - mock_esphome_device: Callable[ - [APIClient, list[EntityInfo], list[UserService], list[EntityState]], - Awaitable[MockESPHomeDevice], - ], -) -> None: - """Test that injecting timer events results in the correct api client calls.""" - - mock_device: MockESPHomeDevice = await mock_esphome_device( - mock_client=mock_client, - entity_info=[], - user_service=[], - states=[], - device_info={ - "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT - | VoiceAssistantFeature.TIMERS - }, - ) - await hass.async_block_till_done() - dev = device_registry.async_get_device( - connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)} - ) - - total_seconds = (1 * 60 * 60) + (2 * 60) + 3 - await intent_helper.async_handle( - hass, - "test", - intent_helper.INTENT_START_TIMER, - { - "name": {"value": "test timer"}, - "hours": {"value": 1}, - "minutes": {"value": 2}, - "seconds": {"value": 3}, - }, - device_id=dev.id, - ) - - mock_client.send_voice_assistant_timer_event.assert_called_with( - VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED, - ANY, - "test timer", - total_seconds, - total_seconds, - True, - ) - - # Increase timer beyond original time and check total_seconds has increased - mock_client.send_voice_assistant_timer_event.reset_mock() - - total_seconds += 5 * 60 - await intent_helper.async_handle( - hass, - "test", - intent_helper.INTENT_INCREASE_TIMER, - { - "name": {"value": "test timer"}, - "minutes": {"value": 5}, - }, - device_id=dev.id, - ) - - mock_client.send_voice_assistant_timer_event.assert_called_with( - VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED, - ANY, - "test timer", - total_seconds, - ANY, - True, - ) - - -async def test_unknown_timer_event( - hass: HomeAssistant, - device_registry: dr.DeviceRegistry, - mock_client: APIClient, - mock_esphome_device: Callable[ - [APIClient, list[EntityInfo], list[UserService], list[EntityState]], - Awaitable[MockESPHomeDevice], - ], -) -> None: - """Test that unknown (new) timer event types do not result in api calls.""" - - mock_device: MockESPHomeDevice = await mock_esphome_device( - mock_client=mock_client, - entity_info=[], - user_service=[], - states=[], - device_info={ - "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT - | VoiceAssistantFeature.TIMERS - }, - ) - await hass.async_block_till_done() - dev = device_registry.async_get_device( - connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)} - ) - - with patch( - "homeassistant.components.esphome.voice_assistant._TIMER_EVENT_TYPES.from_hass", - side_effect=KeyError, - ): - await intent_helper.async_handle( - hass, - "test", - intent_helper.INTENT_START_TIMER, - { - "name": {"value": "test timer"}, - "hours": {"value": 1}, - "minutes": {"value": 2}, - "seconds": {"value": 3}, - }, - device_id=dev.id, - ) - - mock_client.send_voice_assistant_timer_event.assert_not_called() - - -async def test_invalid_pipeline_id( - hass: HomeAssistant, - voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, -) -> None: - """Test that the pipeline is set to start with Wake word.""" - - invalid_pipeline_id = "invalid-pipeline-id" - - async def async_pipeline_from_audio_stream(*args, **kwargs): - raise PipelineNotFound( - "pipeline_not_found", f"Pipeline {invalid_pipeline_id} not found" - ) - - with patch( - "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ): - - def handle_event( - event_type: VoiceAssistantEventType, data: dict[str, str] | None - ) -> None: - if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: - assert data is not None - assert data["code"] == "pipeline_not_found" - assert data["message"] == f"Pipeline {invalid_pipeline_id} not found" - - voice_assistant_api_pipeline.handle_event = handle_event - - await voice_assistant_api_pipeline.run_pipeline( - device_id="mock-device-id", - conversation_id=None, - flags=2, - )