diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index 02a5054dd1d..5f94cf5c291 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -22,6 +22,7 @@ from aioesphomeapi import ( RequiresEncryptionAPIError, UserService, UserServiceArgType, + VoiceAssistantEventType, ) from awesomeversion import AwesomeVersion import voluptuous as vol @@ -64,6 +65,7 @@ from .domain_data import DomainData # Import config flow so that it's added to the registry from .entry_data import RuntimeEntryData from .enum_mapper import EsphomeEnumMapper +from .voice_assistant import VoiceAssistantUDPServer CONF_DEVICE_NAME = "device_name" CONF_NOISE_PSK = "noise_psk" @@ -284,6 +286,39 @@ async def async_setup_entry( # noqa: C901 _send_home_assistant_state(entity_id, attribute, hass.states.get(entity_id)) ) + voice_assistant_udp_server: VoiceAssistantUDPServer | None = None + + def handle_pipeline_event( + event_type: VoiceAssistantEventType, data: dict[str, str] | None + ) -> None: + """Handle a voice assistant pipeline event.""" + cli.send_voice_assistant_event(event_type, data) + + async def handle_pipeline_start() -> int | None: + """Start a voice assistant pipeline.""" + nonlocal voice_assistant_udp_server + + if voice_assistant_udp_server is not None: + return None + + voice_assistant_udp_server = VoiceAssistantUDPServer(hass) + port = await voice_assistant_udp_server.start_server() + + hass.async_create_background_task( + voice_assistant_udp_server.run_pipeline(handle_pipeline_event), + "esphome.voice_assistant_udp_server.run_pipeline", + ) + + return port + + async def handle_pipeline_stop() -> None: + """Stop a voice assistant pipeline.""" + nonlocal voice_assistant_udp_server + + if voice_assistant_udp_server is not None: + voice_assistant_udp_server.stop() + voice_assistant_udp_server = None + async def on_connect() -> None: """Subscribe to states and list entities on successful API login.""" nonlocal device_id @@ -328,6 +363,14 @@ async def async_setup_entry( # noqa: C901 await cli.subscribe_service_calls(async_on_service_call) await cli.subscribe_home_assistant_states(async_on_state_subscription) + if device_info.voice_assistant_version: + entry_data.disconnect_callbacks.append( + await cli.subscribe_voice_assistant( + handle_pipeline_start, + handle_pipeline_stop, + ) + ) + hass.async_create_task(entry_data.async_save_to_store()) except APIConnectionError as err: _LOGGER.warning("Error getting initial data for %s: %s", host, err) diff --git a/homeassistant/components/esphome/manifest.json b/homeassistant/components/esphome/manifest.json index bf3e269221e..06c629e2d44 100644 --- a/homeassistant/components/esphome/manifest.json +++ b/homeassistant/components/esphome/manifest.json @@ -1,7 +1,7 @@ { "domain": "esphome", "name": "ESPHome", - "after_dependencies": ["zeroconf", "tag"], + "after_dependencies": ["zeroconf", "tag", "assist_pipeline"], "codeowners": ["@OttoWinter", "@jesserockz"], "config_flow": true, "dependencies": ["bluetooth"], @@ -14,6 +14,6 @@ "integration_type": "device", "iot_class": "local_push", "loggers": ["aioesphomeapi", "noiseprotocol"], - "requirements": ["aioesphomeapi==13.6.1", "esphome-dashboard-api==1.2.3"], + "requirements": ["aioesphomeapi==13.7.0", "esphome-dashboard-api==1.2.3"], "zeroconf": ["_esphomelib._tcp.local."] } diff --git a/homeassistant/components/esphome/voice_assistant.py b/homeassistant/components/esphome/voice_assistant.py new file mode 100644 index 00000000000..6d3a5a78d65 --- /dev/null +++ b/homeassistant/components/esphome/voice_assistant.py @@ -0,0 +1,164 @@ +"""ESPHome voice assistant support.""" +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterable, Callable +import logging +import socket +from typing import cast + +from aioesphomeapi import VoiceAssistantEventType + +from homeassistant.components import stt +from homeassistant.components.assist_pipeline import ( + PipelineEvent, + PipelineEventType, + async_pipeline_from_audio_stream, +) +from homeassistant.components.media_player import async_process_play_media_url +from homeassistant.core import HomeAssistant, callback + +from .enum_mapper import EsphomeEnumMapper + +_LOGGER = logging.getLogger(__name__) + +UDP_PORT = 0 # Set to 0 to let the OS pick a free random port + +_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[ + VoiceAssistantEventType, PipelineEventType +] = EsphomeEnumMapper( + { + VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR, + VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START, + VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END, + VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START, + VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END, + VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START, + VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END, + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START, + VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END, + } +) + + +class VoiceAssistantUDPServer(asyncio.DatagramProtocol): + """Receive UDP packets and forward them to the voice assistant.""" + + started = False + queue: asyncio.Queue[bytes] | None = None + transport: asyncio.DatagramTransport | None = None + + def __init__(self, hass: HomeAssistant) -> None: + """Initialize UDP receiver.""" + self.hass = hass + self.queue = asyncio.Queue() + + async def start_server(self) -> int: + """Start accepting connections.""" + + def accept_connection() -> VoiceAssistantUDPServer: + """Accept connection.""" + if self.started: + raise RuntimeError("Can only start once") + if self.queue is None: + raise RuntimeError("No longer accepting connections") + + self.started = True + return self + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setblocking(False) + + sock.bind(("", UDP_PORT)) + + await asyncio.get_running_loop().create_datagram_endpoint( + accept_connection, sock=sock + ) + + return cast(int, sock.getsockname()[1]) + + @callback + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """Store transport for later use.""" + self.transport = cast(asyncio.DatagramTransport, transport) + + @callback + def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: + """Handle incoming UDP packet.""" + if self.queue is not None: + self.queue.put_nowait(data) + + def error_received(self, exc: Exception) -> None: + """Handle when a send or receive operation raises an OSError. + + (Other than BlockingIOError or InterruptedError.) + """ + _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc) + + @callback + def stop(self) -> None: + """Stop the receiver.""" + if self.queue is not None: + self.queue.put_nowait(b"") + self.queue = None + if self.transport is not None: + self.transport.close() + + async def _iterate_packets(self) -> AsyncIterable[bytes]: + """Iterate over incoming packets.""" + if self.queue is None: + raise RuntimeError("Already stopped") + + while data := await self.queue.get(): + yield data + + async def run_pipeline( + self, + handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], + ) -> None: + """Run the Voice Assistant pipeline.""" + + @callback + def handle_pipeline_event(event: PipelineEvent) -> None: + """Handle pipeline events.""" + + try: + event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type) + except KeyError: + _LOGGER.warning("Received unknown pipeline event type: %s", event.type) + return + + data_to_send = None + if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: + assert event.data is not None + data_to_send = {"text": event.data["stt_output"]["text"]} + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: + assert event.data is not None + data_to_send = {"text": event.data["tts_input"]} + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: + assert event.data is not None + path = event.data["tts_output"]["url"] + url = async_process_play_media_url(self.hass, path) + data_to_send = {"url": url} + elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: + assert event.data is not None + data_to_send = { + "code": event.data["code"], + "message": event.data["message"], + } + + handle_event(event_type, data_to_send) + + await async_pipeline_from_audio_stream( + self.hass, + event_callback=handle_pipeline_event, + stt_metadata=stt.SpeechMetadata( + language="", + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ), + stt_stream=self._iterate_packets(), + ) diff --git a/requirements_all.txt b/requirements_all.txt index 06e0826835e..7abb024fb4b 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -156,7 +156,7 @@ aioecowitt==2023.01.0 aioemonitor==1.0.5 # homeassistant.components.esphome -aioesphomeapi==13.6.1 +aioesphomeapi==13.7.0 # homeassistant.components.flo aioflo==2021.11.0 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 0dfbd53b1b1..7b649186650 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -146,7 +146,7 @@ aioecowitt==2023.01.0 aioemonitor==1.0.5 # homeassistant.components.esphome -aioesphomeapi==13.6.1 +aioesphomeapi==13.7.0 # homeassistant.components.flo aioflo==2021.11.0 diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py new file mode 100644 index 00000000000..ee6f4f7289f --- /dev/null +++ b/tests/components/esphome/test_voice_assistant.py @@ -0,0 +1,141 @@ +"""Test ESPHome voice assistant server.""" + +import asyncio +import socket +from unittest.mock import Mock, patch + +import async_timeout +import pytest + +from homeassistant.components import assist_pipeline, esphome +from homeassistant.core import HomeAssistant + +_TEST_INPUT_TEXT = "This is an input test" +_TEST_OUTPUT_TEXT = "This is an output test" +_TEST_OUTPUT_URL = "output.mp3" + + +async def test_pipeline_events(hass: HomeAssistant) -> None: + """Test that the pipeline function is called.""" + + async def async_pipeline_from_audio_stream(*args, **kwargs): + event_callback = kwargs["event_callback"] + + # Fake events + event_callback( + assist_pipeline.PipelineEvent( + type=assist_pipeline.PipelineEventType.STT_START, + data={}, + ) + ) + + event_callback( + assist_pipeline.PipelineEvent( + type=assist_pipeline.PipelineEventType.STT_END, + data={"stt_output": {"text": _TEST_INPUT_TEXT}}, + ) + ) + + event_callback( + assist_pipeline.PipelineEvent( + type=assist_pipeline.PipelineEventType.TTS_START, + data={"tts_input": _TEST_OUTPUT_TEXT}, + ) + ) + + event_callback( + assist_pipeline.PipelineEvent( + type=assist_pipeline.PipelineEventType.TTS_END, + data={"tts_output": {"url": _TEST_OUTPUT_URL}}, + ) + ) + + def handle_event( + event_type: esphome.VoiceAssistantEventType, data: dict[str, str] | None + ) -> None: + if event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: + assert data is not None + assert data["text"] == _TEST_INPUT_TEXT + elif event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: + assert data is not None + assert data["text"] == _TEST_OUTPUT_TEXT + elif event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: + assert data is not None + assert data["url"] == _TEST_OUTPUT_URL + + with patch( + "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, + ): + server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) + server.transport = Mock() + + await server.run_pipeline(handle_event) + + +async def test_udp_server( + hass: HomeAssistant, + socket_enabled, + unused_udp_port_factory, +) -> None: + """Test the UDP server runs and queues incoming data.""" + port_to_use = unused_udp_port_factory() + + server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) + with patch( + "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use + ): + port = await server.start_server() + assert port == port_to_use + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + assert server.queue.qsize() == 0 + sock.sendto(b"test", ("127.0.0.1", port)) + + # Give the socket some time to send/receive the data + async with async_timeout.timeout(1): + while server.queue.qsize() == 0: + await asyncio.sleep(0.1) + + assert server.queue.qsize() == 1 + + server.stop() + + assert server.transport.is_closing() + + +async def test_udp_server_multiple( + hass: HomeAssistant, + socket_enabled, + unused_udp_port_factory, +) -> None: + """Test that the UDP server raises an error if started twice.""" + server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) + with patch( + "homeassistant.components.esphome.voice_assistant.UDP_PORT", + new=unused_udp_port_factory(), + ): + await server.start_server() + + with patch( + "homeassistant.components.esphome.voice_assistant.UDP_PORT", + new=unused_udp_port_factory(), + ), pytest.raises(RuntimeError): + pass + await server.start_server() + + +async def test_udp_server_after_stopped( + hass: HomeAssistant, + socket_enabled, + unused_udp_port_factory, +) -> None: + """Test that the UDP server raises an error if started after stopped.""" + server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) + server.stop() + with patch( + "homeassistant.components.esphome.voice_assistant.UDP_PORT", + new=unused_udp_port_factory(), + ), pytest.raises(RuntimeError): + await server.start_server()