mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 09:47:52 +00:00
ESPHome voice assistant (#90691)
* Add ESPHome push-to-talk * Send pipeline events to device * Bump aioesphomeapi to 13.7.0 * Log error instead of print * Rename variable * lint * Rename * Fix type and cast * Move event data manipulation into voice_assistant callback Process full url * Add a test? * Remove import * More tests * Update import * Update manifest * fix tests * Ugh --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
1c0b2630da
commit
0ddccb26fa
@ -22,6 +22,7 @@ from aioesphomeapi import (
|
||||
RequiresEncryptionAPIError,
|
||||
UserService,
|
||||
UserServiceArgType,
|
||||
VoiceAssistantEventType,
|
||||
)
|
||||
from awesomeversion import AwesomeVersion
|
||||
import voluptuous as vol
|
||||
@ -64,6 +65,7 @@ from .domain_data import DomainData
|
||||
# Import config flow so that it's added to the registry
|
||||
from .entry_data import RuntimeEntryData
|
||||
from .enum_mapper import EsphomeEnumMapper
|
||||
from .voice_assistant import VoiceAssistantUDPServer
|
||||
|
||||
CONF_DEVICE_NAME = "device_name"
|
||||
CONF_NOISE_PSK = "noise_psk"
|
||||
@ -284,6 +286,39 @@ async def async_setup_entry( # noqa: C901
|
||||
_send_home_assistant_state(entity_id, attribute, hass.states.get(entity_id))
|
||||
)
|
||||
|
||||
voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
|
||||
|
||||
def handle_pipeline_event(
    event_type: VoiceAssistantEventType, data: dict[str, str] | None
) -> None:
    """Forward a voice assistant pipeline event to the device.

    The event type and its optional payload are passed unchanged to the
    API client.
    """
    cli.send_voice_assistant_event(event_type, data)
|
||||
|
||||
async def handle_pipeline_start() -> int | None:
|
||||
"""Start a voice assistant pipeline."""
|
||||
nonlocal voice_assistant_udp_server
|
||||
|
||||
if voice_assistant_udp_server is not None:
|
||||
return None
|
||||
|
||||
voice_assistant_udp_server = VoiceAssistantUDPServer(hass)
|
||||
port = await voice_assistant_udp_server.start_server()
|
||||
|
||||
hass.async_create_background_task(
|
||||
voice_assistant_udp_server.run_pipeline(handle_pipeline_event),
|
||||
"esphome.voice_assistant_udp_server.run_pipeline",
|
||||
)
|
||||
|
||||
return port
|
||||
|
||||
async def handle_pipeline_stop() -> None:
|
||||
"""Stop a voice assistant pipeline."""
|
||||
nonlocal voice_assistant_udp_server
|
||||
|
||||
if voice_assistant_udp_server is not None:
|
||||
voice_assistant_udp_server.stop()
|
||||
voice_assistant_udp_server = None
|
||||
|
||||
async def on_connect() -> None:
|
||||
"""Subscribe to states and list entities on successful API login."""
|
||||
nonlocal device_id
|
||||
@ -328,6 +363,14 @@ async def async_setup_entry( # noqa: C901
|
||||
await cli.subscribe_service_calls(async_on_service_call)
|
||||
await cli.subscribe_home_assistant_states(async_on_state_subscription)
|
||||
|
||||
if device_info.voice_assistant_version:
|
||||
entry_data.disconnect_callbacks.append(
|
||||
await cli.subscribe_voice_assistant(
|
||||
handle_pipeline_start,
|
||||
handle_pipeline_stop,
|
||||
)
|
||||
)
|
||||
|
||||
hass.async_create_task(entry_data.async_save_to_store())
|
||||
except APIConnectionError as err:
|
||||
_LOGGER.warning("Error getting initial data for %s: %s", host, err)
|
||||
|
@ -1,7 +1,7 @@
|
||||
{
|
||||
"domain": "esphome",
|
||||
"name": "ESPHome",
|
||||
"after_dependencies": ["zeroconf", "tag"],
|
||||
"after_dependencies": ["zeroconf", "tag", "assist_pipeline"],
|
||||
"codeowners": ["@OttoWinter", "@jesserockz"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["bluetooth"],
|
||||
@ -14,6 +14,6 @@
|
||||
"integration_type": "device",
|
||||
"iot_class": "local_push",
|
||||
"loggers": ["aioesphomeapi", "noiseprotocol"],
|
||||
"requirements": ["aioesphomeapi==13.6.1", "esphome-dashboard-api==1.2.3"],
|
||||
"requirements": ["aioesphomeapi==13.7.0", "esphome-dashboard-api==1.2.3"],
|
||||
"zeroconf": ["_esphomelib._tcp.local."]
|
||||
}
|
||||
|
164
homeassistant/components/esphome/voice_assistant.py
Normal file
164
homeassistant/components/esphome/voice_assistant.py
Normal file
@ -0,0 +1,164 @@
|
||||
"""ESPHome voice assistant support."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable, Callable
|
||||
import logging
|
||||
import socket
|
||||
from typing import cast
|
||||
|
||||
from aioesphomeapi import VoiceAssistantEventType
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.assist_pipeline import (
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
async_pipeline_from_audio_stream,
|
||||
)
|
||||
from homeassistant.components.media_player import async_process_play_media_url
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
|
||||
from .enum_mapper import EsphomeEnumMapper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
UDP_PORT = 0 # Set to 0 to let the OS pick a free random port
|
||||
|
||||
# Two-way lookup between the ESPHome API event enum and the assist pipeline
# event enum. Pipeline event types missing from this table have no ESPHome
# counterpart.
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
    VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
    {
        # Failure reporting
        VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
        # Whole-run lifecycle
        VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
        # Speech-to-text stage
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
        # Intent recognition stage
        VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
        # Text-to-speech stage
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
        VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
    }
)
|
||||
|
||||
|
||||
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
    """Receive UDP packets and forward them to the voice assistant.

    Incoming datagrams are buffered in an asyncio queue and handed to the
    assist pipeline as the speech-to-text audio stream. An empty bytes
    object in the queue is the end-of-stream marker; ``stop`` enqueues it.
    """

    # True once the instance has been bound to a datagram endpoint.
    started = False
    # Buffered payloads; set to None once the server has been stopped.
    queue: asyncio.Queue[bytes] | None = None
    # Set by connection_made; None until the endpoint exists.
    transport: asyncio.DatagramTransport | None = None

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize UDP receiver."""
        self.hass = hass
        self.queue = asyncio.Queue()

    async def start_server(self) -> int:
        """Start accepting connections.

        Binds a non-blocking UDP socket on all IPv4 interfaces at
        ``UDP_PORT`` and attaches this instance to it as the protocol.

        Returns the port number the socket is bound to.

        Raises RuntimeError if the server was already started or has
        already been stopped.
        """

        def accept_connection() -> VoiceAssistantUDPServer:
            """Accept connection."""
            if self.started:
                raise RuntimeError("Can only start once")
            if self.queue is None:
                raise RuntimeError("No longer accepting connections")

            self.started = True
            return self

        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            sock.setblocking(False)
            sock.bind(("", UDP_PORT))

            await asyncio.get_running_loop().create_datagram_endpoint(
                accept_connection, sock=sock
            )
        except BaseException:
            # We created the socket, so we must release it when binding or
            # endpoint creation fails (for example when accept_connection
            # rejects the start); otherwise the descriptor leaks.
            sock.close()
            raise

        return cast(int, sock.getsockname()[1])

    @callback
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Store transport for later use."""
        self.transport = cast(asyncio.DatagramTransport, transport)

    @callback
    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        """Handle incoming UDP packet.

        Packets that arrive after ``stop`` are dropped.
        """
        if not data:
            # An empty payload is the end-of-stream marker used by stop();
            # queueing a zero-length datagram would end the stream early.
            return
        if self.queue is not None:
            self.queue.put_nowait(data)

    def error_received(self, exc: Exception) -> None:
        """Handle when a send or receive operation raises an OSError.

        (Other than BlockingIOError or InterruptedError.)
        """
        _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)

    @callback
    def stop(self) -> None:
        """Stop the receiver.

        Enqueues the end-of-stream marker, drops the queue so later packets
        are ignored, and closes the transport if one exists. Calling it more
        than once is harmless.
        """
        if self.queue is not None:
            self.queue.put_nowait(b"")
            self.queue = None
        if self.transport is not None:
            self.transport.close()

    async def _iterate_packets(self) -> AsyncIterable[bytes]:
        """Iterate over incoming packets.

        Yields payloads until the empty end-of-stream marker is read.

        Raises RuntimeError if the server was stopped before iteration began.
        """
        if self.queue is None:
            raise RuntimeError("Already stopped")

        while data := await self.queue.get():
            yield data

    async def run_pipeline(
        self,
        handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
    ) -> None:
        """Run the Voice Assistant pipeline.

        Feeds the received packets to the assist pipeline and reports each
        pipeline event through ``handle_event`` as an ESPHome event type plus
        an optional payload.
        """

        @callback
        def handle_pipeline_event(event: PipelineEvent) -> None:
            """Handle pipeline events."""

            try:
                event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
            except KeyError:
                _LOGGER.warning("Received unknown pipeline event type: %s", event.type)
                return

            # Only some events carry a payload; the rest are sent bare.
            data_to_send = None
            if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
                assert event.data is not None
                data_to_send = {"text": event.data["stt_output"]["text"]}
            elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
                assert event.data is not None
                data_to_send = {"text": event.data["tts_input"]}
            elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
                assert event.data is not None
                path = event.data["tts_output"]["url"]
                # Pass the pipeline's URL through the media URL helper
                # before sending it to the device.
                url = async_process_play_media_url(self.hass, path)
                data_to_send = {"url": url}
            elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
                assert event.data is not None
                data_to_send = {
                    "code": event.data["code"],
                    "message": event.data["message"],
                }

            handle_event(event_type, data_to_send)

        await async_pipeline_from_audio_stream(
            self.hass,
            event_callback=handle_pipeline_event,
            stt_metadata=stt.SpeechMetadata(
                language="",
                format=stt.AudioFormats.WAV,
                codec=stt.AudioCodecs.PCM,
                bit_rate=stt.AudioBitRates.BITRATE_16,
                sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                channel=stt.AudioChannels.CHANNEL_MONO,
            ),
            stt_stream=self._iterate_packets(),
        )
|
@ -156,7 +156,7 @@ aioecowitt==2023.01.0
|
||||
aioemonitor==1.0.5
|
||||
|
||||
# homeassistant.components.esphome
|
||||
aioesphomeapi==13.6.1
|
||||
aioesphomeapi==13.7.0
|
||||
|
||||
# homeassistant.components.flo
|
||||
aioflo==2021.11.0
|
||||
|
@ -146,7 +146,7 @@ aioecowitt==2023.01.0
|
||||
aioemonitor==1.0.5
|
||||
|
||||
# homeassistant.components.esphome
|
||||
aioesphomeapi==13.6.1
|
||||
aioesphomeapi==13.7.0
|
||||
|
||||
# homeassistant.components.flo
|
||||
aioflo==2021.11.0
|
||||
|
141
tests/components/esphome/test_voice_assistant.py
Normal file
141
tests/components/esphome/test_voice_assistant.py
Normal file
@ -0,0 +1,141 @@
|
||||
"""Test ESPHome voice assistant server."""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import async_timeout
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import assist_pipeline, esphome
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
_TEST_INPUT_TEXT = "This is an input test"
|
||||
_TEST_OUTPUT_TEXT = "This is an output test"
|
||||
_TEST_OUTPUT_URL = "output.mp3"
|
||||
|
||||
|
||||
async def test_pipeline_events(hass: HomeAssistant) -> None:
    """Test that pipeline events are translated and passed to the handler."""

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        """Stand in for the real pipeline and emit a fixed event sequence."""
        event_callback = kwargs["event_callback"]

        # Fake events
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.STT_START,
                data={},
            )
        )

        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.STT_END,
                data={"stt_output": {"text": _TEST_INPUT_TEXT}},
            )
        )

        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_START,
                data={"tts_input": _TEST_OUTPUT_TEXT},
            )
        )

        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_END,
                data={"tts_output": {"url": _TEST_OUTPUT_URL}},
            )
        )

    # Every event type that reaches the handler, in arrival order. Without
    # this record the payload assertions below would pass vacuously if the
    # handler were never invoked.
    received = []

    def handle_event(
        event_type: esphome.VoiceAssistantEventType, data: dict[str, str] | None
    ) -> None:
        received.append(event_type)
        if event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
            assert data is not None
            assert data["text"] == _TEST_INPUT_TEXT
        elif event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
            assert data is not None
            assert data["text"] == _TEST_OUTPUT_TEXT
        elif event_type == esphome.VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
            assert data is not None
            assert data["url"] == _TEST_OUTPUT_URL

    with patch(
        "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ):
        server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
        server.transport = Mock()

        await server.run_pipeline(handle_event)

    assert received == [
        esphome.VoiceAssistantEventType.VOICE_ASSISTANT_STT_START,
        esphome.VoiceAssistantEventType.VOICE_ASSISTANT_STT_END,
        esphome.VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START,
        esphome.VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
    ]
|
||||
|
||||
|
||||
async def test_udp_server(
    hass: HomeAssistant,
    socket_enabled,
    unused_udp_port_factory,
) -> None:
    """Test the UDP server runs and queues incoming data."""
    port_to_use = unused_udp_port_factory()

    server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
    with patch(
        "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
    ):
        port = await server.start_server()
        assert port == port_to_use

    try:
        assert server.queue.qsize() == 0

        # The context manager closes the client socket even if sending
        # fails, so the test does not leak a descriptor.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.sendto(b"test", ("127.0.0.1", port))

        # Give the socket some time to send/receive the data
        async with async_timeout.timeout(1):
            while server.queue.qsize() == 0:
                await asyncio.sleep(0.1)

        assert server.queue.qsize() == 1
    finally:
        # Always release the server endpoint, also when an assertion fails.
        server.stop()

    assert server.transport.is_closing()
|
||||
|
||||
|
||||
async def test_udp_server_multiple(
    hass: HomeAssistant,
    socket_enabled,
    unused_udp_port_factory,
) -> None:
    """Test that the UDP server raises an error if started twice."""
    server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
    with patch(
        "homeassistant.components.esphome.voice_assistant.UDP_PORT",
        new=unused_udp_port_factory(),
    ):
        await server.start_server()

    try:
        with patch(
            "homeassistant.components.esphome.voice_assistant.UDP_PORT",
            new=unused_udp_port_factory(),
        ), pytest.raises(RuntimeError):
            await server.start_server()
    finally:
        # Close the endpoint opened by the first, successful start.
        server.stop()
|
||||
|
||||
|
||||
async def test_udp_server_after_stopped(
    hass: HomeAssistant,
    socket_enabled,
    unused_udp_port_factory,
) -> None:
    """Test that a server which was already stopped refuses to start."""
    stopped_server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
    stopped_server.stop()

    port_patch = patch(
        "homeassistant.components.esphome.voice_assistant.UDP_PORT",
        new=unused_udp_port_factory(),
    )
    with port_patch:
        with pytest.raises(RuntimeError):
            await stopped_server.start_server()
|
Loading…
x
Reference in New Issue
Block a user