mirror of
https://github.com/home-assistant/core.git
synced 2025-07-16 17:57:11 +00:00
Add additional parameters to assist pipelines (#91619)
* Add additional parameters to assist pipelines * Improve WS schema validation * Tweak * Add test * Address review comments
This commit is contained in:
parent
b4e0a1f1fc
commit
0525ce59d7
@ -39,12 +39,30 @@ _LOGGER = logging.getLogger(__name__)
|
|||||||
STORAGE_KEY = f"{DOMAIN}.pipelines"
|
STORAGE_KEY = f"{DOMAIN}.pipelines"
|
||||||
STORAGE_VERSION = 1
|
STORAGE_VERSION = 1
|
||||||
|
|
||||||
STORAGE_FIELDS = {
|
ENGINE_LANGUAGE_PAIRS = (
|
||||||
vol.Optional("conversation_engine", default=None): vol.Any(str, None),
|
("stt_engine", "stt_language"),
|
||||||
|
("tts_engine", "tts_language"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_language(data: dict[str, Any]) -> Any:
|
||||||
|
"""Validate language settings."""
|
||||||
|
for engine, language in ENGINE_LANGUAGE_PAIRS:
|
||||||
|
if data[engine] is not None and data[language] is None:
|
||||||
|
raise vol.Invalid(f"Need language {language} for {engine} {data[engine]}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
PIPELINE_FIELDS = {
|
||||||
|
vol.Required("conversation_engine"): str,
|
||||||
|
vol.Required("conversation_language"): str,
|
||||||
vol.Required("language"): str,
|
vol.Required("language"): str,
|
||||||
vol.Required("name"): str,
|
vol.Required("name"): str,
|
||||||
vol.Optional("stt_engine", default=None): vol.Any(str, None),
|
vol.Required("stt_engine"): vol.Any(str, None),
|
||||||
vol.Optional("tts_engine", default=None): vol.Any(str, None),
|
vol.Required("stt_language"): vol.Any(str, None),
|
||||||
|
vol.Required("tts_engine"): vol.Any(str, None),
|
||||||
|
vol.Required("tts_language"): vol.Any(str, None),
|
||||||
|
vol.Required("tts_voice"): vol.Any(str, None),
|
||||||
}
|
}
|
||||||
|
|
||||||
STORED_PIPELINE_RUNS = 10
|
STORED_PIPELINE_RUNS = 10
|
||||||
@ -67,11 +85,15 @@ async def async_get_pipeline(
|
|||||||
# configured language
|
# configured language
|
||||||
return await pipeline_data.pipeline_store.async_create_item(
|
return await pipeline_data.pipeline_store.async_create_item(
|
||||||
{
|
{
|
||||||
"name": hass.config.language,
|
"conversation_engine": None,
|
||||||
|
"conversation_language": None,
|
||||||
"language": hass.config.language,
|
"language": hass.config.language,
|
||||||
"stt_engine": None, # first engine
|
"name": hass.config.language,
|
||||||
"conversation_engine": None, # first agent
|
"stt_engine": None,
|
||||||
"tts_engine": None, # first engine
|
"stt_language": None,
|
||||||
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -108,11 +130,15 @@ PipelineEventCallback = Callable[[PipelineEvent], None]
|
|||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""A voice assistant pipeline."""
|
"""A voice assistant pipeline."""
|
||||||
|
|
||||||
conversation_engine: str | None
|
conversation_engine: str
|
||||||
|
conversation_language: str
|
||||||
language: str
|
language: str
|
||||||
name: str
|
name: str
|
||||||
stt_engine: str | None
|
stt_engine: str | None
|
||||||
|
stt_language: str | None
|
||||||
tts_engine: str | None
|
tts_engine: str | None
|
||||||
|
tts_language: str | None
|
||||||
|
tts_voice: str | None
|
||||||
|
|
||||||
id: str = field(default_factory=ulid_util.ulid)
|
id: str = field(default_factory=ulid_util.ulid)
|
||||||
|
|
||||||
@ -120,11 +146,15 @@ class Pipeline:
|
|||||||
"""Return a JSON serializable representation for storage."""
|
"""Return a JSON serializable representation for storage."""
|
||||||
return {
|
return {
|
||||||
"conversation_engine": self.conversation_engine,
|
"conversation_engine": self.conversation_engine,
|
||||||
|
"conversation_language": self.conversation_language,
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"language": self.language,
|
"language": self.language,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"stt_engine": self.stt_engine,
|
"stt_engine": self.stt_engine,
|
||||||
|
"stt_language": self.stt_language,
|
||||||
"tts_engine": self.tts_engine,
|
"tts_engine": self.tts_engine,
|
||||||
|
"tts_language": self.tts_language,
|
||||||
|
"tts_voice": self.tts_voice,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -591,8 +621,6 @@ class PipelineStorageCollection(
|
|||||||
):
|
):
|
||||||
"""Pipeline storage collection."""
|
"""Pipeline storage collection."""
|
||||||
|
|
||||||
CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS)
|
|
||||||
|
|
||||||
_preferred_item: str | None = None
|
_preferred_item: str | None = None
|
||||||
|
|
||||||
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
|
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
|
||||||
@ -606,8 +634,8 @@ class PipelineStorageCollection(
|
|||||||
|
|
||||||
async def _process_create_data(self, data: dict) -> dict:
|
async def _process_create_data(self, data: dict) -> dict:
|
||||||
"""Validate the config is valid."""
|
"""Validate the config is valid."""
|
||||||
# We don't need to validate, the WS API has already validated
|
validated_data: dict = validate_language(data)
|
||||||
return data
|
return validated_data
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _get_suggested_id(self, info: dict) -> str:
|
def _get_suggested_id(self, info: dict) -> str:
|
||||||
@ -616,6 +644,7 @@ class PipelineStorageCollection(
|
|||||||
|
|
||||||
async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline:
|
async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline:
|
||||||
"""Return a new updated item."""
|
"""Return a new updated item."""
|
||||||
|
update_data = validate_language(update_data)
|
||||||
return Pipeline(id=item.id, **update_data)
|
return Pipeline(id=item.id, **update_data)
|
||||||
|
|
||||||
def _create_item(self, item_id: str, data: dict) -> Pipeline:
|
def _create_item(self, item_id: str, data: dict) -> Pipeline:
|
||||||
@ -789,6 +818,10 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> None:
|
|||||||
)
|
)
|
||||||
await pipeline_store.async_load()
|
await pipeline_store.async_load()
|
||||||
PipelineStorageCollectionWebsocket(
|
PipelineStorageCollectionWebsocket(
|
||||||
pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS
|
pipeline_store,
|
||||||
|
f"{DOMAIN}/pipeline",
|
||||||
|
"pipeline",
|
||||||
|
PIPELINE_FIELDS,
|
||||||
|
PIPELINE_FIELDS,
|
||||||
).async_setup(hass)
|
).async_setup(hass)
|
||||||
hass.data[DOMAIN] = PipelineData({}, pipeline_store)
|
hass.data[DOMAIN] = PipelineData({}, pipeline_store)
|
||||||
|
@ -22,24 +22,36 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
|||||||
pipelines = [
|
pipelines = [
|
||||||
{
|
{
|
||||||
"conversation_engine": "conversation_engine_1",
|
"conversation_engine": "conversation_engine_1",
|
||||||
|
"conversation_language": "language_1",
|
||||||
"language": "language_1",
|
"language": "language_1",
|
||||||
"name": "name_1",
|
"name": "name_1",
|
||||||
"stt_engine": "stt_engine_1",
|
"stt_engine": "stt_engine_1",
|
||||||
|
"stt_language": "language_1",
|
||||||
"tts_engine": "tts_engine_1",
|
"tts_engine": "tts_engine_1",
|
||||||
|
"tts_language": "language_1",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"conversation_engine": "conversation_engine_2",
|
"conversation_engine": "conversation_engine_2",
|
||||||
|
"conversation_language": "language_2",
|
||||||
"language": "language_2",
|
"language": "language_2",
|
||||||
"name": "name_2",
|
"name": "name_2",
|
||||||
"stt_engine": "stt_engine_2",
|
"stt_engine": "stt_engine_2",
|
||||||
|
"stt_language": "language_1",
|
||||||
"tts_engine": "tts_engine_2",
|
"tts_engine": "tts_engine_2",
|
||||||
|
"tts_language": "language_2",
|
||||||
|
"tts_voice": "The Voice",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"conversation_engine": None,
|
"conversation_engine": "conversation_engine_3",
|
||||||
|
"conversation_language": "language_3",
|
||||||
"language": "language_3",
|
"language": "language_3",
|
||||||
"name": "name_3",
|
"name": "name_3",
|
||||||
"stt_engine": None,
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
"tts_engine": None,
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
pipeline_ids = []
|
pipeline_ids = []
|
||||||
@ -77,27 +89,39 @@ async def test_loading_datasets_from_storage(
|
|||||||
"items": [
|
"items": [
|
||||||
{
|
{
|
||||||
"conversation_engine": "conversation_engine_1",
|
"conversation_engine": "conversation_engine_1",
|
||||||
|
"conversation_language": "language_1",
|
||||||
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
"id": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||||
"language": "language_1",
|
"language": "language_1",
|
||||||
"name": "name_1",
|
"name": "name_1",
|
||||||
"stt_engine": "stt_engine_1",
|
"stt_engine": "stt_engine_1",
|
||||||
|
"stt_language": "language_1",
|
||||||
"tts_engine": "tts_engine_1",
|
"tts_engine": "tts_engine_1",
|
||||||
|
"tts_language": "language_1",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"conversation_engine": "conversation_engine_2",
|
"conversation_engine": "conversation_engine_2",
|
||||||
|
"conversation_language": "language_2",
|
||||||
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
|
"id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX",
|
||||||
"language": "language_2",
|
"language": "language_2",
|
||||||
"name": "name_2",
|
"name": "name_2",
|
||||||
"stt_engine": "stt_engine_2",
|
"stt_engine": "stt_engine_2",
|
||||||
|
"stt_language": "language_2",
|
||||||
"tts_engine": "tts_engine_2",
|
"tts_engine": "tts_engine_2",
|
||||||
|
"tts_language": "language_2",
|
||||||
|
"tts_voice": "The Voice",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"conversation_engine": None,
|
"conversation_engine": "conversation_engine_3",
|
||||||
|
"conversation_language": "language_3",
|
||||||
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
|
"id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J",
|
||||||
"language": "language_3",
|
"language": "language_3",
|
||||||
"name": "name_3",
|
"name": "name_3",
|
||||||
"stt_engine": None,
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
"tts_engine": None,
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
"preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY",
|
||||||
|
@ -47,8 +47,12 @@ async def pipeline_1(
|
|||||||
"name": "Test 1",
|
"name": "Test 1",
|
||||||
"language": "en-US",
|
"language": "en-US",
|
||||||
"conversation_engine": None,
|
"conversation_engine": None,
|
||||||
|
"conversation_language": "en-US",
|
||||||
"tts_engine": None,
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
"stt_engine": None,
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -63,8 +67,12 @@ async def pipeline_2(
|
|||||||
"name": "Test 2",
|
"name": "Test 2",
|
||||||
"language": "en-US",
|
"language": "en-US",
|
||||||
"conversation_engine": None,
|
"conversation_engine": None,
|
||||||
|
"conversation_language": "en-US",
|
||||||
"tts_engine": None,
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
"stt_engine": None,
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -585,32 +585,44 @@ async def test_add_pipeline(
|
|||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/create",
|
"type": "assist_pipeline/pipeline/create",
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
"language": "test_language",
|
"language": "test_language",
|
||||||
"name": "test_name",
|
"name": "test_name",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
"id": ANY,
|
"id": ANY,
|
||||||
"language": "test_language",
|
"language": "test_language",
|
||||||
"name": "test_name",
|
"name": "test_name",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert len(pipeline_store.data) == 1
|
assert len(pipeline_store.data) == 1
|
||||||
pipeline = pipeline_store.data[msg["result"]["id"]]
|
pipeline = pipeline_store.data[msg["result"]["id"]]
|
||||||
assert pipeline == Pipeline(
|
assert pipeline == Pipeline(
|
||||||
conversation_engine="test_conversation_engine",
|
conversation_engine="test_conversation_engine",
|
||||||
|
conversation_language="test_language",
|
||||||
id=msg["result"]["id"],
|
id=msg["result"]["id"],
|
||||||
language="test_language",
|
language="test_language",
|
||||||
name="test_name",
|
name="test_name",
|
||||||
stt_engine="test_stt_engine",
|
stt_engine="test_stt_engine",
|
||||||
|
stt_language="test_language",
|
||||||
tts_engine="test_tts_engine",
|
tts_engine="test_tts_engine",
|
||||||
|
tts_language="test_language",
|
||||||
|
tts_voice="Arnold Schwarzenegger",
|
||||||
)
|
)
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
@ -621,26 +633,52 @@ async def test_add_pipeline(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert not msg["success"]
|
||||||
assert msg["result"] == {
|
|
||||||
"conversation_engine": None,
|
|
||||||
"id": ANY,
|
|
||||||
"language": "test_language",
|
|
||||||
"name": "test_name",
|
|
||||||
"stt_engine": None,
|
|
||||||
"tts_engine": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
assert len(pipeline_store.data) == 2
|
|
||||||
pipeline = pipeline_store.data[msg["result"]["id"]]
|
async def test_add_pipeline_missing_language(
|
||||||
assert pipeline == Pipeline(
|
hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components
|
||||||
conversation_engine=None,
|
) -> None:
|
||||||
id=msg["result"]["id"],
|
"""Test we can't add a pipeline without specifying stt or tts language."""
|
||||||
language="test_language",
|
client = await hass_ws_client(hass)
|
||||||
name="test_name",
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
stt_engine=None,
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
tts_engine=None,
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": None,
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert len(pipeline_store.data) == 0
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
|
"language": "test_language",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert not msg["success"]
|
||||||
|
assert len(pipeline_store.data) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_delete_pipeline(
|
async def test_delete_pipeline(
|
||||||
@ -655,10 +693,14 @@ async def test_delete_pipeline(
|
|||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/create",
|
"type": "assist_pipeline/pipeline/create",
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
"language": "test_language",
|
"language": "test_language",
|
||||||
"name": "test_name",
|
"name": "test_name",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -669,10 +711,14 @@ async def test_delete_pipeline(
|
|||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/create",
|
"type": "assist_pipeline/pipeline/create",
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
"language": "test_language",
|
"language": "test_language",
|
||||||
"name": "test_name",
|
"name": "test_name",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -817,10 +863,14 @@ async def test_list_pipelines(
|
|||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/create",
|
"type": "assist_pipeline/pipeline/create",
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
"language": "test_language",
|
"language": "test_language",
|
||||||
"name": "test_name",
|
"name": "test_name",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -834,11 +884,15 @@ async def test_list_pipelines(
|
|||||||
"pipelines": [
|
"pipelines": [
|
||||||
{
|
{
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
"id": ANY,
|
"id": ANY,
|
||||||
"language": "test_language",
|
"language": "test_language",
|
||||||
"name": "test_name",
|
"name": "test_name",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"preferred_pipeline": ANY,
|
"preferred_pipeline": ANY,
|
||||||
@ -857,11 +911,15 @@ async def test_update_pipeline(
|
|||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/update",
|
"type": "assist_pipeline/pipeline/update",
|
||||||
"conversation_engine": "new_conversation_engine",
|
"conversation_engine": "new_conversation_engine",
|
||||||
|
"conversation_language": "new_conversation_language",
|
||||||
"language": "new_language",
|
"language": "new_language",
|
||||||
"name": "new_name",
|
"name": "new_name",
|
||||||
"pipeline_id": "no_such_pipeline",
|
"pipeline_id": "no_such_pipeline",
|
||||||
"stt_engine": "new_stt_engine",
|
"stt_engine": "new_stt_engine",
|
||||||
|
"stt_language": "new_stt_language",
|
||||||
"tts_engine": "new_tts_engine",
|
"tts_engine": "new_tts_engine",
|
||||||
|
"tts_language": "new_tts_language",
|
||||||
|
"tts_voice": "new_tts_voice",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -875,10 +933,14 @@ async def test_update_pipeline(
|
|||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/create",
|
"type": "assist_pipeline/pipeline/create",
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
"language": "test_language",
|
"language": "test_language",
|
||||||
"name": "test_name",
|
"name": "test_name",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -890,65 +952,89 @@ async def test_update_pipeline(
|
|||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/update",
|
"type": "assist_pipeline/pipeline/update",
|
||||||
"conversation_engine": "new_conversation_engine",
|
"conversation_engine": "new_conversation_engine",
|
||||||
|
"conversation_language": "new_conversation_language",
|
||||||
"language": "new_language",
|
"language": "new_language",
|
||||||
"name": "new_name",
|
"name": "new_name",
|
||||||
"pipeline_id": pipeline_id,
|
"pipeline_id": pipeline_id,
|
||||||
"stt_engine": "new_stt_engine",
|
"stt_engine": "new_stt_engine",
|
||||||
|
"stt_language": "new_stt_language",
|
||||||
"tts_engine": "new_tts_engine",
|
"tts_engine": "new_tts_engine",
|
||||||
|
"tts_language": "new_tts_language",
|
||||||
|
"tts_voice": "new_tts_voice",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"conversation_engine": "new_conversation_engine",
|
"conversation_engine": "new_conversation_engine",
|
||||||
|
"conversation_language": "new_conversation_language",
|
||||||
"id": pipeline_id,
|
"id": pipeline_id,
|
||||||
"language": "new_language",
|
"language": "new_language",
|
||||||
"name": "new_name",
|
"name": "new_name",
|
||||||
"stt_engine": "new_stt_engine",
|
"stt_engine": "new_stt_engine",
|
||||||
|
"stt_language": "new_stt_language",
|
||||||
"tts_engine": "new_tts_engine",
|
"tts_engine": "new_tts_engine",
|
||||||
|
"tts_language": "new_tts_language",
|
||||||
|
"tts_voice": "new_tts_voice",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert len(pipeline_store.data) == 1
|
assert len(pipeline_store.data) == 1
|
||||||
pipeline = pipeline_store.data[pipeline_id]
|
pipeline = pipeline_store.data[pipeline_id]
|
||||||
assert pipeline == Pipeline(
|
assert pipeline == Pipeline(
|
||||||
conversation_engine="new_conversation_engine",
|
conversation_engine="new_conversation_engine",
|
||||||
|
conversation_language="new_conversation_language",
|
||||||
id=pipeline_id,
|
id=pipeline_id,
|
||||||
language="new_language",
|
language="new_language",
|
||||||
name="new_name",
|
name="new_name",
|
||||||
stt_engine="new_stt_engine",
|
stt_engine="new_stt_engine",
|
||||||
|
stt_language="new_stt_language",
|
||||||
tts_engine="new_tts_engine",
|
tts_engine="new_tts_engine",
|
||||||
|
tts_language="new_tts_language",
|
||||||
|
tts_voice="new_tts_voice",
|
||||||
)
|
)
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/update",
|
"type": "assist_pipeline/pipeline/update",
|
||||||
"conversation_engine": None,
|
"conversation_engine": "new_conversation_engine",
|
||||||
|
"conversation_language": "new_conversation_language",
|
||||||
"language": "new_language",
|
"language": "new_language",
|
||||||
"name": "new_name",
|
"name": "new_name",
|
||||||
"pipeline_id": pipeline_id,
|
"pipeline_id": pipeline_id,
|
||||||
"stt_engine": None,
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
"tts_engine": None,
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"conversation_engine": None,
|
"conversation_engine": "new_conversation_engine",
|
||||||
|
"conversation_language": "new_conversation_language",
|
||||||
"id": pipeline_id,
|
"id": pipeline_id,
|
||||||
"language": "new_language",
|
"language": "new_language",
|
||||||
"name": "new_name",
|
"name": "new_name",
|
||||||
"stt_engine": None,
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
"tts_engine": None,
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
pipeline = pipeline_store.data[pipeline_id]
|
pipeline = pipeline_store.data[pipeline_id]
|
||||||
assert pipeline == Pipeline(
|
assert pipeline == Pipeline(
|
||||||
conversation_engine=None,
|
conversation_engine="new_conversation_engine",
|
||||||
|
conversation_language="new_conversation_language",
|
||||||
id=pipeline_id,
|
id=pipeline_id,
|
||||||
language="new_language",
|
language="new_language",
|
||||||
name="new_name",
|
name="new_name",
|
||||||
stt_engine=None,
|
stt_engine=None,
|
||||||
|
stt_language=None,
|
||||||
tts_engine=None,
|
tts_engine=None,
|
||||||
|
tts_language=None,
|
||||||
|
tts_voice=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -964,10 +1050,14 @@ async def test_set_preferred_pipeline(
|
|||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/create",
|
"type": "assist_pipeline/pipeline/create",
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
"language": "test_language",
|
"language": "test_language",
|
||||||
"name": "test_name",
|
"name": "test_name",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -978,10 +1068,14 @@ async def test_set_preferred_pipeline(
|
|||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/create",
|
"type": "assist_pipeline/pipeline/create",
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "test_conversation_engine",
|
||||||
|
"conversation_language": "test_language",
|
||||||
"language": "test_language",
|
"language": "test_language",
|
||||||
"name": "test_name",
|
"name": "test_name",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test_stt_engine",
|
||||||
|
"stt_language": "test_language",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test_tts_engine",
|
||||||
|
"tts_language": "test_language",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user