mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Voice assistant integration with pipelines (#89822)
* Initial commit * Add websocket test tool * Small tweak * Tiny cleanup * Make pipeline work with frontend branch * Add some more info to start event * Fixes * First voice assistant tests * Remove run_task * Clean up for PR * Add config_flow.py * Remove CLI tool * Simplify by removing stt/tts for now * Clean up and fix tests * More clean up and API changes * Add quality_scale * Remove data from run-finish * Use StrEnum backport --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
81c0382e4b
commit
e16f17f5a8
@ -1309,6 +1309,8 @@ build.json @home-assistant/supervisor
|
|||||||
/tests/components/vizio/ @raman325
|
/tests/components/vizio/ @raman325
|
||||||
/homeassistant/components/vlc_telnet/ @rodripf @MartinHjelmare
|
/homeassistant/components/vlc_telnet/ @rodripf @MartinHjelmare
|
||||||
/tests/components/vlc_telnet/ @rodripf @MartinHjelmare
|
/tests/components/vlc_telnet/ @rodripf @MartinHjelmare
|
||||||
|
/homeassistant/components/voice_assistant/ @balloob @synesthesiam
|
||||||
|
/tests/components/voice_assistant/ @balloob @synesthesiam
|
||||||
/homeassistant/components/volumio/ @OnFreund
|
/homeassistant/components/volumio/ @OnFreund
|
||||||
/tests/components/volumio/ @OnFreund
|
/tests/components/volumio/ @OnFreund
|
||||||
/homeassistant/components/volvooncall/ @molobrakos
|
/homeassistant/components/volvooncall/ @molobrakos
|
||||||
|
23
homeassistant/components/voice_assistant/__init__.py
Normal file
23
homeassistant/components/voice_assistant/__init__.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
"""The Voice Assistant integration."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
|
from .const import DEFAULT_PIPELINE, DOMAIN
|
||||||
|
from .pipeline import Pipeline
|
||||||
|
from .websocket_api import async_register_websocket_api
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
|
"""Set up Voice Assistant integration."""
|
||||||
|
hass.data[DOMAIN] = {
|
||||||
|
DEFAULT_PIPELINE: Pipeline(
|
||||||
|
name=DEFAULT_PIPELINE,
|
||||||
|
language=None,
|
||||||
|
conversation_engine=None,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
async_register_websocket_api(hass)
|
||||||
|
|
||||||
|
return True
|
3
homeassistant/components/voice_assistant/const.py
Normal file
3
homeassistant/components/voice_assistant/const.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
"""Constants for the Voice Assistant integration."""
|
||||||
|
DOMAIN = "voice_assistant"
|
||||||
|
DEFAULT_PIPELINE = "default"
|
9
homeassistant/components/voice_assistant/manifest.json
Normal file
9
homeassistant/components/voice_assistant/manifest.json
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"domain": "voice_assistant",
|
||||||
|
"name": "Voice Assistant",
|
||||||
|
"codeowners": ["@balloob", "@synesthesiam"],
|
||||||
|
"dependencies": ["conversation"],
|
||||||
|
"documentation": "https://www.home-assistant.io/integrations/voice_assistant",
|
||||||
|
"iot_class": "local_push",
|
||||||
|
"quality_scale": "internal"
|
||||||
|
}
|
124
homeassistant/components/voice_assistant/pipeline.py
Normal file
124
homeassistant/components/voice_assistant/pipeline.py
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
"""Classes for voice assistant pipelines."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.backports.enum import StrEnum
|
||||||
|
from homeassistant.components import conversation
|
||||||
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
|
DEFAULT_TIMEOUT = 30 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineRequest:
|
||||||
|
"""Request to start a pipeline run."""
|
||||||
|
|
||||||
|
intent_input: str
|
||||||
|
conversation_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineEventType(StrEnum):
|
||||||
|
"""Event types emitted during a pipeline run."""
|
||||||
|
|
||||||
|
RUN_START = "run-start"
|
||||||
|
RUN_FINISH = "run-finish"
|
||||||
|
INTENT_START = "intent-start"
|
||||||
|
INTENT_FINISH = "intent-finish"
|
||||||
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineEvent:
|
||||||
|
"""Events emitted during a pipeline run."""
|
||||||
|
|
||||||
|
type: PipelineEventType
|
||||||
|
data: dict[str, Any] | None = None
|
||||||
|
timestamp: str = field(default_factory=lambda: utcnow().isoformat())
|
||||||
|
|
||||||
|
def as_dict(self) -> dict[str, Any]:
|
||||||
|
"""Return a dict representation of the event."""
|
||||||
|
return {
|
||||||
|
"type": self.type,
|
||||||
|
"timestamp": self.timestamp,
|
||||||
|
"data": self.data or {},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Pipeline:
|
||||||
|
"""A voice assistant pipeline."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
language: str | None
|
||||||
|
conversation_engine: str | None
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
context: Context,
|
||||||
|
request: PipelineRequest,
|
||||||
|
event_callback: Callable[[PipelineEvent], None],
|
||||||
|
timeout: int | float | None = DEFAULT_TIMEOUT,
|
||||||
|
) -> None:
|
||||||
|
"""Run a pipeline with an optional timeout."""
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._run(hass, context, request, event_callback), timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _run(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
context: Context,
|
||||||
|
request: PipelineRequest,
|
||||||
|
event_callback: Callable[[PipelineEvent], None],
|
||||||
|
) -> None:
|
||||||
|
"""Run a pipeline."""
|
||||||
|
language = self.language or hass.config.language
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.RUN_START,
|
||||||
|
{
|
||||||
|
"pipeline": self.name,
|
||||||
|
"language": language,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
intent_input = request.intent_input
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.INTENT_START,
|
||||||
|
{
|
||||||
|
"engine": self.conversation_engine or "default",
|
||||||
|
"intent_input": intent_input,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_result = await conversation.async_converse(
|
||||||
|
hass=hass,
|
||||||
|
text=intent_input,
|
||||||
|
conversation_id=request.conversation_id,
|
||||||
|
context=context,
|
||||||
|
language=language,
|
||||||
|
agent_id=self.conversation_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.INTENT_FINISH,
|
||||||
|
{"intent_output": conversation_result.as_dict()},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
PipelineEventType.RUN_FINISH,
|
||||||
|
)
|
||||||
|
)
|
67
homeassistant/components/voice_assistant/websocket_api.py
Normal file
67
homeassistant/components/voice_assistant/websocket_api.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
"""Voice Assistant Websocket API."""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import websocket_api
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .pipeline import DEFAULT_TIMEOUT, PipelineRequest
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_register_websocket_api(hass: HomeAssistant) -> None:
|
||||||
|
"""Register the websocket API."""
|
||||||
|
websocket_api.async_register_command(hass, websocket_run)
|
||||||
|
|
||||||
|
|
||||||
|
@websocket_api.websocket_command(
|
||||||
|
{
|
||||||
|
vol.Required("type"): "voice_assistant/run",
|
||||||
|
vol.Optional("pipeline", default="default"): str,
|
||||||
|
vol.Required("intent_input"): str,
|
||||||
|
vol.Optional("conversation_id"): vol.Any(str, None),
|
||||||
|
vol.Optional("timeout"): vol.Any(float, int),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@websocket_api.async_response
|
||||||
|
async def websocket_run(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
connection: websocket_api.ActiveConnection,
|
||||||
|
msg: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Run a pipeline."""
|
||||||
|
pipeline_id = msg["pipeline"]
|
||||||
|
pipeline = hass.data[DOMAIN].get(pipeline_id)
|
||||||
|
if pipeline is None:
|
||||||
|
connection.send_error(
|
||||||
|
msg["id"], "pipeline_not_found", f"Pipeline not found: {pipeline_id}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Run pipeline with a timeout.
|
||||||
|
# Events are sent over the websocket connection.
|
||||||
|
timeout = msg.get("timeout", DEFAULT_TIMEOUT)
|
||||||
|
run_task = hass.async_create_task(
|
||||||
|
pipeline.run(
|
||||||
|
hass,
|
||||||
|
connection.context(msg),
|
||||||
|
request=PipelineRequest(
|
||||||
|
intent_input=msg["intent_input"],
|
||||||
|
conversation_id=msg.get("conversation_id"),
|
||||||
|
),
|
||||||
|
event_callback=lambda event: connection.send_event(
|
||||||
|
msg["id"], event.as_dict()
|
||||||
|
),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cancel pipeline if user unsubscribes
|
||||||
|
connection.subscriptions[msg["id"]] = run_task.cancel
|
||||||
|
|
||||||
|
connection.send_result(msg["id"])
|
||||||
|
|
||||||
|
# Task contains a timeout
|
||||||
|
await run_task
|
@ -65,6 +65,11 @@ class ActiveConnection:
|
|||||||
"""Send a result message."""
|
"""Send a result message."""
|
||||||
self.send_message(messages.result_message(msg_id, result))
|
self.send_message(messages.result_message(msg_id, result))
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def send_event(self, msg_id: int, event: Any | None = None) -> None:
|
||||||
|
"""Send a event message."""
|
||||||
|
self.send_message(messages.event_message(msg_id, event))
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def send_error(self, msg_id: int, code: str, message: str) -> None:
|
def send_error(self, msg_id: int, code: str, message: str) -> None:
|
||||||
"""Send a error message."""
|
"""Send a error message."""
|
||||||
|
@ -6068,6 +6068,12 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"voice_assistant": {
|
||||||
|
"name": "Voice Assistant",
|
||||||
|
"integration_type": "hub",
|
||||||
|
"config_flow": false,
|
||||||
|
"iot_class": "local_push"
|
||||||
|
},
|
||||||
"voicerss": {
|
"voicerss": {
|
||||||
"name": "VoiceRSS",
|
"name": "VoiceRSS",
|
||||||
"integration_type": "hub",
|
"integration_type": "hub",
|
||||||
|
1
tests/components/voice_assistant/__init__.py
Normal file
1
tests/components/voice_assistant/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
"""Tests for the Voice Assistant integration."""
|
152
tests/components/voice_assistant/test_websocket.py
Normal file
152
tests/components/voice_assistant/test_websocket.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def init_components(hass):
|
||||||
|
"""Initialize relevant components with empty configs."""
|
||||||
|
assert await async_setup_component(hass, "voice_assistant", {})
|
||||||
|
|
||||||
|
|
||||||
|
async def test_text_only_pipeline(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
) -> None:
|
||||||
|
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
await client.send_json(
|
||||||
|
{"id": 5, "type": "voice_assistant/run", "intent_input": "Are the lights on?"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"pipeline": "default",
|
||||||
|
"language": hass.config.language,
|
||||||
|
}
|
||||||
|
|
||||||
|
# intent
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"engine": "default",
|
||||||
|
"intent_input": "Are the lights on?",
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-finish"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"intent_output": {
|
||||||
|
"response": {
|
||||||
|
"speech": {
|
||||||
|
"plain": {
|
||||||
|
"speech": "Sorry, I couldn't understand that",
|
||||||
|
"extra_data": None,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"card": {},
|
||||||
|
"language": "en",
|
||||||
|
"response_type": "error",
|
||||||
|
"data": {"code": "no_intent_match"},
|
||||||
|
},
|
||||||
|
"conversation_id": None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# run finish
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-finish"
|
||||||
|
assert msg["event"]["data"] == {}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_conversation_timeout(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test partial pipeline run with conversation agent timeout."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
async def sleepy_converse(*args, **kwargs):
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.conversation.async_converse", new=sleepy_converse
|
||||||
|
):
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"intent_input": "Are the lights on?",
|
||||||
|
"timeout": 0.00001,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# run start
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "run-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"pipeline": "default",
|
||||||
|
"language": hass.config.language,
|
||||||
|
}
|
||||||
|
|
||||||
|
# intent
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["event"]["type"] == "intent-start"
|
||||||
|
assert msg["event"]["data"] == {
|
||||||
|
"engine": "default",
|
||||||
|
"intent_input": "Are the lights on?",
|
||||||
|
}
|
||||||
|
|
||||||
|
# timeout error
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"]["code"] == "timeout"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_timeout(
|
||||||
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
|
) -> None:
|
||||||
|
"""Test pipeline run with immediate timeout."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
async def sleepy_run(*args, **kwargs):
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.voice_assistant.pipeline.Pipeline._run",
|
||||||
|
new=sleepy_run,
|
||||||
|
):
|
||||||
|
await client.send_json(
|
||||||
|
{
|
||||||
|
"id": 5,
|
||||||
|
"type": "voice_assistant/run",
|
||||||
|
"intent_input": "Are the lights on?",
|
||||||
|
"timeout": 0.0001,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# result
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
|
# timeout error
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert msg["error"]["code"] == "timeout"
|
Loading…
x
Reference in New Issue
Block a user