Allow targeting conversation agent as pipeline (#119556)

* Allow targetting conversation agent as pipeline

* Test that we can use a conversation entity as an assist pipeline

* Add test for WS get

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
This commit is contained in:
Paulus Schoutsen 2024-07-09 17:56:53 +02:00 committed by GitHub
parent 69ed730101
commit 154da1b18b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 111 additions and 4 deletions

View File

@ -259,6 +259,22 @@ async def async_create_default_pipeline(
return await pipeline_store.async_create_item(pipeline_settings) return await pipeline_store.async_create_item(pipeline_settings)
@callback
def _async_get_pipeline_from_conversation_entity(
hass: HomeAssistant, entity_id: str
) -> Pipeline:
"""Get a pipeline by conversation entity ID."""
entity = hass.states.get(entity_id)
settings = _async_resolve_default_pipeline_settings(
hass,
pipeline_name=entity.name if entity else entity_id,
conversation_engine_id=entity_id,
)
settings["id"] = entity_id
return Pipeline.from_json(settings)
@callback @callback
def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline: def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> Pipeline:
"""Get a pipeline by id or the preferred pipeline.""" """Get a pipeline by id or the preferred pipeline."""
@ -268,6 +284,9 @@ def async_get_pipeline(hass: HomeAssistant, pipeline_id: str | None = None) -> P
# A pipeline was not specified, use the preferred one # A pipeline was not specified, use the preferred one
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item() pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
if pipeline_id.startswith("conversation."):
return _async_get_pipeline_from_conversation_entity(hass, pipeline_id)
pipeline = pipeline_data.pipeline_store.data.get(pipeline_id) pipeline = pipeline_data.pipeline_store.data.get(pipeline_id)
# If invalid pipeline ID was specified # If invalid pipeline ID was specified
@ -1670,6 +1689,12 @@ class PipelineStorageCollectionWebsocket(
if item_id is None: if item_id is None:
item_id = self.storage_collection.async_get_preferred_item() item_id = self.storage_collection.async_get_preferred_item()
if item_id.startswith("conversation.") and hass.states.get(item_id):
connection.send_result(
msg["id"], _async_get_pipeline_from_conversation_entity(hass, item_id)
)
return
if item_id not in self.storage_collection.data: if item_id not in self.storage_collection.data:
connection.send_error( connection.send_error(
msg["id"], msg["id"],

View File

@ -663,7 +663,10 @@
# name: test_stt_stream_failed.2 # name: test_stt_stream_failed.2
None None
# --- # ---
# name: test_text_only_pipeline # name: test_text_only_pipeline.3
None
# ---
# name: test_text_only_pipeline[extra_msg0]
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': <ANY>, 'pipeline': <ANY>,
@ -673,7 +676,7 @@
}), }),
}) })
# --- # ---
# name: test_text_only_pipeline.1 # name: test_text_only_pipeline[extra_msg0].1
dict({ dict({
'conversation_id': 'mock-conversation-id', 'conversation_id': 'mock-conversation-id',
'device_id': 'mock-device-id', 'device_id': 'mock-device-id',
@ -682,7 +685,7 @@
'language': 'en', 'language': 'en',
}) })
# --- # ---
# name: test_text_only_pipeline.2 # name: test_text_only_pipeline[extra_msg0].2
dict({ dict({
'intent_output': dict({ 'intent_output': dict({
'conversation_id': None, 'conversation_id': None,
@ -704,7 +707,51 @@
}), }),
}) })
# --- # ---
# name: test_text_only_pipeline.3 # name: test_text_only_pipeline[extra_msg0].3
None
# ---
# name: test_text_only_pipeline[extra_msg1]
dict({
'language': 'en',
'pipeline': <ANY>,
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 300,
}),
})
# ---
# name: test_text_only_pipeline[extra_msg1].1
dict({
'conversation_id': 'mock-conversation-id',
'device_id': 'mock-device-id',
'engine': 'conversation.home_assistant',
'intent_input': 'Are the lights on?',
'language': 'en',
})
# ---
# name: test_text_only_pipeline[extra_msg1].2
dict({
'intent_output': dict({
'conversation_id': None,
'response': dict({
'card': dict({
}),
'data': dict({
'code': 'no_valid_targets',
}),
'language': 'en',
'response_type': 'error',
'speech': dict({
'plain': dict({
'extra_data': None,
'speech': 'Sorry, I am not aware of any area called are',
}),
}),
}),
}),
})
# ---
# name: test_text_only_pipeline[extra_msg1].3
None None
# --- # ---
# name: test_text_pipeline_timeout # name: test_text_pipeline_timeout

View File

@ -5,6 +5,7 @@ import base64
from typing import Any from typing import Any
from unittest.mock import ANY, patch from unittest.mock import ANY, patch
import pytest
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.const import DOMAIN
@ -23,11 +24,19 @@ from tests.common import MockConfigEntry
from tests.typing import WebSocketGenerator from tests.typing import WebSocketGenerator
@pytest.mark.parametrize(
"extra_msg",
[
{},
{"pipeline": "conversation.home_assistant"},
],
)
async def test_text_only_pipeline( async def test_text_only_pipeline(
hass: HomeAssistant, hass: HomeAssistant,
hass_ws_client: WebSocketGenerator, hass_ws_client: WebSocketGenerator,
init_components, init_components,
snapshot: SnapshotAssertion, snapshot: SnapshotAssertion,
extra_msg: dict[str, Any],
) -> None: ) -> None:
"""Test events from a pipeline run with text input (no STT/TTS).""" """Test events from a pipeline run with text input (no STT/TTS)."""
events = [] events = []
@ -42,6 +51,7 @@ async def test_text_only_pipeline(
"conversation_id": "mock-conversation-id", "conversation_id": "mock-conversation-id",
"device_id": "mock-device-id", "device_id": "mock-device-id",
} }
| extra_msg
) )
# result # result
@ -1180,6 +1190,31 @@ async def test_get_pipeline(
"wake_word_id": None, "wake_word_id": None,
} }
# Get conversation agent as pipeline
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/get",
"pipeline_id": "conversation.home_assistant",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "conversation.home_assistant",
"conversation_language": "en",
"id": ANY,
"language": "en",
"name": "Home Assistant",
# It found these defaults
"stt_engine": "test",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "james_earl_jones",
"wake_word_entity": None,
"wake_word_id": None,
}
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
"type": "assist_pipeline/pipeline/get", "type": "assist_pipeline/pipeline/get",