mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Allow targeting conversation agent as pipeline (#119556)
* Allow targetting conversation agent as pipeline * Test that we can use a conversation entity as an assist pipeline * Add test for WS get --------- Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
parent
69ed730101
commit
154da1b18b
@ -259,6 +259,22 @@ async def async_create_default_pipeline(
|
|||||||
return await pipeline_store.async_create_item(pipeline_settings)
|
return await pipeline_store.async_create_item(pipeline_settings)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_get_pipeline_from_conversation_entity(
|
||||||
|
hass: HomeAssistant, entity_id: str
|
||||||
|
) -> Pipeline:
|
||||||
|
"""Get a pipeline by conversation entity ID."""
|
||||||
|
entity = hass.states.get(entity_id)
|
||||||
|
settings = _async_resolve_default_pipeline_settings(
|
||||||
|
hass,
|
||||||
|
pipeline_name=entity.name if entity else entity_id,
|
||||||
|
conversation_engine_id=entity_id,
|
||||||
|
)
|
||||||
|
settings["id"] = entity_id
|
||||||
|
|
||||||
|
return Pipeline.from_json(settings)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
|
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
|
||||||
"""Get a pipeline by id or the preferred pipeline."""
|
"""Get a pipeline by id or the preferred pipeline."""
|
||||||
@ -268,6 +284,9 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
|
|||||||
# A pipeline was not specified, use the preferred one
|
# A pipeline was not specified, use the preferred one
|
||||||
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
|
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
|
||||||
|
|
||||||
|
if pipeline_id.startswith("conversation."):
|
||||||
|
return _async_get_pipeline_from_conversation_entity(hass, pipeline_id)
|
||||||
|
|
||||||
pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
|
pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||||
|
|
||||||
# If invalid pipeline ID was specified
|
# If invalid pipeline ID was specified
|
||||||
@ -1670,6 +1689,12 @@ class PipelineStorageCollectionWebsocket(
|
|||||||
if item_id is None:
|
if item_id is None:
|
||||||
item_id = self.storage_collection.async_get_preferred_item()
|
item_id = self.storage_collection.async_get_preferred_item()
|
||||||
|
|
||||||
|
if item_id.startswith("conversation.") and hass.states.get(item_id):
|
||||||
|
connection.send_result(
|
||||||
|
msg["id"], _async_get_pipeline_from_conversation_entity(hass, item_id)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if item_id not in self.storage_collection.data:
|
if item_id not in self.storage_collection.data:
|
||||||
connection.send_error(
|
connection.send_error(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
|
@ -663,7 +663,10 @@
|
|||||||
# name: test_stt_stream_failed.2
|
# name: test_stt_stream_failed.2
|
||||||
None
|
None
|
||||||
# ---
|
# ---
|
||||||
# name: test_text_only_pipeline
|
# name: test_text_only_pipeline.3
|
||||||
|
None
|
||||||
|
# ---
|
||||||
|
# name: test_text_only_pipeline[extra_msg0]
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
@ -673,7 +676,7 @@
|
|||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_text_only_pipeline.1
|
# name: test_text_only_pipeline[extra_msg0].1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': 'mock-conversation-id',
|
'conversation_id': 'mock-conversation-id',
|
||||||
'device_id': 'mock-device-id',
|
'device_id': 'mock-device-id',
|
||||||
@ -682,7 +685,7 @@
|
|||||||
'language': 'en',
|
'language': 'en',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_text_only_pipeline.2
|
# name: test_text_only_pipeline[extra_msg0].2
|
||||||
dict({
|
dict({
|
||||||
'intent_output': dict({
|
'intent_output': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': None,
|
||||||
@ -704,7 +707,51 @@
|
|||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_text_only_pipeline.3
|
# name: test_text_only_pipeline[extra_msg0].3
|
||||||
|
None
|
||||||
|
# ---
|
||||||
|
# name: test_text_only_pipeline[extra_msg1]
|
||||||
|
dict({
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'runner_data': dict({
|
||||||
|
'stt_binary_handler_id': None,
|
||||||
|
'timeout': 300,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_text_only_pipeline[extra_msg1].1
|
||||||
|
dict({
|
||||||
|
'conversation_id': 'mock-conversation-id',
|
||||||
|
'device_id': 'mock-device-id',
|
||||||
|
'engine': 'conversation.home_assistant',
|
||||||
|
'intent_input': 'Are the lights on?',
|
||||||
|
'language': 'en',
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_text_only_pipeline[extra_msg1].2
|
||||||
|
dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'conversation_id': None,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'code': 'no_valid_targets',
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'error',
|
||||||
|
'speech': dict({
|
||||||
|
'plain': dict({
|
||||||
|
'extra_data': None,
|
||||||
|
'speech': 'Sorry, I am not aware of any area called are',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
# name: test_text_only_pipeline[extra_msg1].3
|
||||||
None
|
None
|
||||||
# ---
|
# ---
|
||||||
# name: test_text_pipeline_timeout
|
# name: test_text_pipeline_timeout
|
||||||
|
@ -5,6 +5,7 @@ import base64
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import ANY, patch
|
from unittest.mock import ANY, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||||
@ -23,11 +24,19 @@ from tests.common import MockConfigEntry
|
|||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"extra_msg",
|
||||||
|
[
|
||||||
|
{},
|
||||||
|
{"pipeline": "conversation.home_assistant"},
|
||||||
|
],
|
||||||
|
)
|
||||||
async def test_text_only_pipeline(
|
async def test_text_only_pipeline(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
hass_ws_client: WebSocketGenerator,
|
hass_ws_client: WebSocketGenerator,
|
||||||
init_components,
|
init_components,
|
||||||
snapshot: SnapshotAssertion,
|
snapshot: SnapshotAssertion,
|
||||||
|
extra_msg: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
"""Test events from a pipeline run with text input (no STT/TTS)."""
|
||||||
events = []
|
events = []
|
||||||
@ -42,6 +51,7 @@ async def test_text_only_pipeline(
|
|||||||
"conversation_id": "mock-conversation-id",
|
"conversation_id": "mock-conversation-id",
|
||||||
"device_id": "mock-device-id",
|
"device_id": "mock-device-id",
|
||||||
}
|
}
|
||||||
|
| extra_msg
|
||||||
)
|
)
|
||||||
|
|
||||||
# result
|
# result
|
||||||
@ -1180,6 +1190,31 @@ async def test_get_pipeline(
|
|||||||
"wake_word_id": None,
|
"wake_word_id": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Get conversation agent as pipeline
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/get",
|
||||||
|
"pipeline_id": "conversation.home_assistant",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
assert msg["result"] == {
|
||||||
|
"conversation_engine": "conversation.home_assistant",
|
||||||
|
"conversation_language": "en",
|
||||||
|
"id": ANY,
|
||||||
|
"language": "en",
|
||||||
|
"name": "Home Assistant",
|
||||||
|
# It found these defaults
|
||||||
|
"stt_engine": "test",
|
||||||
|
"stt_language": "en-US",
|
||||||
|
"tts_engine": "test",
|
||||||
|
"tts_language": "en-US",
|
||||||
|
"tts_voice": "james_earl_jones",
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/get",
|
"type": "assist_pipeline/pipeline/get",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user