diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index 677d2a56664..72933a51167 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -39,12 +39,30 @@ _LOGGER = logging.getLogger(__name__) STORAGE_KEY = f"{DOMAIN}.pipelines" STORAGE_VERSION = 1 -STORAGE_FIELDS = { - vol.Optional("conversation_engine", default=None): vol.Any(str, None), +ENGINE_LANGUAGE_PAIRS = ( + ("stt_engine", "stt_language"), + ("tts_engine", "tts_language"), +) + + +def validate_language(data: dict[str, Any]) -> Any: + """Validate language settings.""" + for engine, language in ENGINE_LANGUAGE_PAIRS: + if data[engine] is not None and data[language] is None: + raise vol.Invalid(f"Need language {language} for {engine} {data[engine]}") + return data + + +PIPELINE_FIELDS = { + vol.Required("conversation_engine"): str, + vol.Required("conversation_language"): str, vol.Required("language"): str, vol.Required("name"): str, - vol.Optional("stt_engine", default=None): vol.Any(str, None), - vol.Optional("tts_engine", default=None): vol.Any(str, None), + vol.Required("stt_engine"): vol.Any(str, None), + vol.Required("stt_language"): vol.Any(str, None), + vol.Required("tts_engine"): vol.Any(str, None), + vol.Required("tts_language"): vol.Any(str, None), + vol.Required("tts_voice"): vol.Any(str, None), } STORED_PIPELINE_RUNS = 10 @@ -67,11 +85,15 @@ async def async_get_pipeline( # configured language return await pipeline_data.pipeline_store.async_create_item( { - "name": hass.config.language, + "conversation_engine": None, + "conversation_language": None, "language": hass.config.language, - "stt_engine": None, # first engine - "conversation_engine": None, # first agent - "tts_engine": None, # first engine + "name": hass.config.language, + "stt_engine": None, + "stt_language": None, + "tts_engine": None, + "tts_language": None, + "tts_voice": None, } ) @@ -108,11 +130,15 @@ PipelineEventCallback = Callable[[PipelineEvent], None] class Pipeline: """A voice assistant pipeline.""" - conversation_engine: str | None + conversation_engine: str + conversation_language: str language: str name: str stt_engine: str | None + stt_language: str | None tts_engine: str | None + tts_language: str | None + tts_voice: str | None id: str = field(default_factory=ulid_util.ulid) @@ -120,11 +146,15 @@ class Pipeline: """Return a JSON serializable representation for storage.""" return { "conversation_engine": self.conversation_engine, + "conversation_language": self.conversation_language, "id": self.id, "language": self.language, "name": self.name, "stt_engine": self.stt_engine, + "stt_language": self.stt_language, "tts_engine": self.tts_engine, + "tts_language": self.tts_language, + "tts_voice": self.tts_voice, } @@ -591,8 +621,6 @@ class PipelineStorageCollection( ): """Pipeline storage collection.""" - CREATE_UPDATE_SCHEMA = vol.Schema(STORAGE_FIELDS) - _preferred_item: str | None = None async def _async_load_data(self) -> SerializedPipelineStorageCollection | None: @@ -606,8 +634,8 @@ class PipelineStorageCollection( async def _process_create_data(self, data: dict) -> dict: """Validate the config is valid.""" - # We don't need to validate, the WS API has already validated - return data + validated_data: dict = validate_language(data) + return validated_data @callback def _get_suggested_id(self, info: dict) -> str: @@ -616,6 +644,7 @@ class PipelineStorageCollection( async def _update_data(self, item: Pipeline, update_data: dict) -> Pipeline: """Return a new updated item.""" + update_data = validate_language(update_data) return Pipeline(id=item.id, **update_data) def _create_item(self, item_id: str, data: dict) -> Pipeline: @@ -789,6 +818,10 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> None: ) await pipeline_store.async_load() PipelineStorageCollectionWebsocket( - pipeline_store, f"{DOMAIN}/pipeline", "pipeline", STORAGE_FIELDS, STORAGE_FIELDS + pipeline_store, + f"{DOMAIN}/pipeline", + "pipeline", + PIPELINE_FIELDS, + PIPELINE_FIELDS, ).async_setup(hass) hass.data[DOMAIN] = PipelineData({}, pipeline_store) diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 1517e5f53a3..47037869af6 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -22,24 +22,36 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None: pipelines = [ { "conversation_engine": "conversation_engine_1", + "conversation_language": "language_1", "language": "language_1", "name": "name_1", "stt_engine": "stt_engine_1", + "stt_language": "language_1", "tts_engine": "tts_engine_1", + "tts_language": "language_1", + "tts_voice": "Arnold Schwarzenegger", }, { "conversation_engine": "conversation_engine_2", + "conversation_language": "language_2", "language": "language_2", "name": "name_2", "stt_engine": "stt_engine_2", + "stt_language": "language_1", "tts_engine": "tts_engine_2", + "tts_language": "language_2", + "tts_voice": "The Voice", }, { - "conversation_engine": None, + "conversation_engine": "conversation_engine_3", + "conversation_language": "language_3", "language": "language_3", "name": "name_3", "stt_engine": None, + "stt_language": None, "tts_engine": None, + "tts_language": None, + "tts_voice": None, }, ] pipeline_ids = [] @@ -77,27 +89,39 @@ async def test_loading_datasets_from_storage( "items": [ { "conversation_engine": "conversation_engine_1", + "conversation_language": "language_1", "id": "01GX8ZWBAQYWNB1XV3EXEZ75DY", "language": "language_1", "name": "name_1", "stt_engine": "stt_engine_1", + "stt_language": "language_1", "tts_engine": "tts_engine_1", + "tts_language": "language_1", + "tts_voice": "Arnold Schwarzenegger", }, { "conversation_engine": "conversation_engine_2", + "conversation_language": "language_2", "id": "01GX8ZWBAQTKFQNK4W7Q4CTRCX", "language": "language_2", "name": "name_2", "stt_engine": "stt_engine_2", + "stt_language": "language_2", "tts_engine": "tts_engine_2", + "tts_language": "language_2", + "tts_voice": "The Voice", }, { - "conversation_engine": None, + "conversation_engine": "conversation_engine_3", + "conversation_language": "language_3", "id": "01GX8ZWBAQSV1HP3WGJPFWEJ8J", "language": "language_3", "name": "name_3", "stt_engine": None, + "stt_language": None, "tts_engine": None, + "tts_language": None, + "tts_voice": None, }, ], "preferred_item": "01GX8ZWBAQYWNB1XV3EXEZ75DY", diff --git a/tests/components/assist_pipeline/test_select.py b/tests/components/assist_pipeline/test_select.py index 540ed98da13..b7fc232494b 100644 --- a/tests/components/assist_pipeline/test_select.py +++ b/tests/components/assist_pipeline/test_select.py @@ -47,8 +47,12 @@ async def pipeline_1( "name": "Test 1", "language": "en-US", "conversation_engine": None, + "conversation_language": "en-US", "tts_engine": None, + "tts_language": None, + "tts_voice": None, "stt_engine": None, + "stt_language": None, } ) @@ -63,8 +67,12 @@ async def pipeline_2( "name": "Test 2", "language": "en-US", "conversation_engine": None, + "conversation_language": "en-US", "tts_engine": None, + "tts_language": None, + "tts_voice": None, "stt_engine": None, + "stt_language": None, } ) diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 6de01e74ea9..5a696c50f63 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -585,32 +585,44 @@ async def test_add_pipeline( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", "language": "test_language", "name": "test_name", "stt_engine": "test_stt_engine", + "stt_language": "test_language", "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", } ) msg = await client.receive_json() assert msg["success"] assert msg["result"] == { "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", "id": ANY, "language": "test_language", "name": "test_name", "stt_engine": "test_stt_engine", + "stt_language": "test_language", "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", } assert len(pipeline_store.data) == 1 pipeline = pipeline_store.data[msg["result"]["id"]] assert pipeline == Pipeline( conversation_engine="test_conversation_engine", + conversation_language="test_language", id=msg["result"]["id"], language="test_language", name="test_name", stt_engine="test_stt_engine", + stt_language="test_language", tts_engine="test_tts_engine", + tts_language="test_language", + tts_voice="Arnold Schwarzenegger", ) await client.send_json_auto_id( @@ -621,26 +633,52 @@ async def test_add_pipeline( } ) msg = await client.receive_json() - assert msg["success"] - assert msg["result"] == { - "conversation_engine": None, - "id": ANY, - "language": "test_language", - "name": "test_name", - "stt_engine": None, - "tts_engine": None, - } + assert not msg["success"] - assert len(pipeline_store.data) == 2 - pipeline = pipeline_store.data[msg["result"]["id"]] - assert pipeline == Pipeline( - conversation_engine=None, - id=msg["result"]["id"], - language="test_language", - name="test_name", - stt_engine=None, - tts_engine=None, + +async def test_add_pipeline_missing_language( + hass: HomeAssistant, hass_ws_client: WebSocketGenerator, init_components +) -> None: + """Test we can't add a pipeline without specifying stt or tts language.""" + client = await hass_ws_client(hass) + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_store = pipeline_data.pipeline_store + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "stt_language": None, + "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", + } ) + msg = await client.receive_json() + assert not msg["success"] + assert len(pipeline_store.data) == 0 + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/create", + "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", + "language": "test_language", + "name": "test_name", + "stt_engine": "test_stt_engine", + "stt_language": "test_language", + "tts_engine": "test_tts_engine", + "tts_language": None, + "tts_voice": "Arnold Schwarzenegger", + } + ) + msg = await client.receive_json() + assert not msg["success"] + assert len(pipeline_store.data) == 0 async def test_delete_pipeline( @@ -655,10 +693,14 @@ async def test_delete_pipeline( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", "language": "test_language", "name": "test_name", "stt_engine": "test_stt_engine", + "stt_language": "test_language", "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", } ) msg = await client.receive_json() @@ -669,10 +711,14 @@ async def test_delete_pipeline( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", "language": "test_language", "name": "test_name", "stt_engine": "test_stt_engine", + "stt_language": "test_language", "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", } ) msg = await client.receive_json() @@ -817,10 +863,14 @@ async def test_list_pipelines( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", "language": "test_language", "name": "test_name", "stt_engine": "test_stt_engine", + "stt_language": "test_language", "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", } ) msg = await client.receive_json() @@ -834,11 +884,15 @@ async def test_list_pipelines( "pipelines": [ { "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", "id": ANY, "language": "test_language", "name": "test_name", "stt_engine": "test_stt_engine", + "stt_language": "test_language", "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", } ], "preferred_pipeline": ANY, @@ -857,11 +911,15 @@ async def test_update_pipeline( { "type": "assist_pipeline/pipeline/update", "conversation_engine": "new_conversation_engine", + "conversation_language": "new_conversation_language", "language": "new_language", "name": "new_name", "pipeline_id": "no_such_pipeline", "stt_engine": "new_stt_engine", + "stt_language": "new_stt_language", "tts_engine": "new_tts_engine", + "tts_language": "new_tts_language", + "tts_voice": "new_tts_voice", } ) msg = await client.receive_json() @@ -875,10 +933,14 @@ async def test_update_pipeline( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", "language": "test_language", "name": "test_name", "stt_engine": "test_stt_engine", + "stt_language": "test_language", "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", } ) msg = await client.receive_json() @@ -890,65 +952,89 @@ async def test_update_pipeline( { "type": "assist_pipeline/pipeline/update", "conversation_engine": "new_conversation_engine", + "conversation_language": "new_conversation_language", "language": "new_language", "name": "new_name", "pipeline_id": pipeline_id, "stt_engine": "new_stt_engine", + "stt_language": "new_stt_language", "tts_engine": "new_tts_engine", + "tts_language": "new_tts_language", + "tts_voice": "new_tts_voice", } ) msg = await client.receive_json() assert msg["success"] assert msg["result"] == { "conversation_engine": "new_conversation_engine", + "conversation_language": "new_conversation_language", "id": pipeline_id, "language": "new_language", "name": "new_name", "stt_engine": "new_stt_engine", + "stt_language": "new_stt_language", "tts_engine": "new_tts_engine", + "tts_language": "new_tts_language", + "tts_voice": "new_tts_voice", } assert len(pipeline_store.data) == 1 pipeline = pipeline_store.data[pipeline_id] assert pipeline == Pipeline( conversation_engine="new_conversation_engine", + conversation_language="new_conversation_language", id=pipeline_id, language="new_language", name="new_name", stt_engine="new_stt_engine", + stt_language="new_stt_language", tts_engine="new_tts_engine", + tts_language="new_tts_language", + tts_voice="new_tts_voice", ) await client.send_json_auto_id( { "type": "assist_pipeline/pipeline/update", - "conversation_engine": None, + "conversation_engine": "new_conversation_engine", + "conversation_language": "new_conversation_language", "language": "new_language", "name": "new_name", "pipeline_id": pipeline_id, "stt_engine": None, + "stt_language": None, "tts_engine": None, + "tts_language": None, + "tts_voice": None, } ) msg = await client.receive_json() assert msg["success"] assert msg["result"] == { - "conversation_engine": None, + "conversation_engine": "new_conversation_engine", + "conversation_language": "new_conversation_language", "id": pipeline_id, "language": "new_language", "name": "new_name", "stt_engine": None, + "stt_language": None, "tts_engine": None, + "tts_language": None, + "tts_voice": None, } pipeline = pipeline_store.data[pipeline_id] assert pipeline == Pipeline( - conversation_engine=None, + conversation_engine="new_conversation_engine", + conversation_language="new_conversation_language", id=pipeline_id, language="new_language", name="new_name", stt_engine=None, + stt_language=None, tts_engine=None, + tts_language=None, + tts_voice=None, ) @@ -964,10 +1050,14 @@ async def test_set_preferred_pipeline( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", "language": "test_language", "name": "test_name", "stt_engine": "test_stt_engine", + "stt_language": "test_language", "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", } ) msg = await client.receive_json() @@ -978,10 +1068,14 @@ async def test_set_preferred_pipeline( { "type": "assist_pipeline/pipeline/create", "conversation_engine": "test_conversation_engine", + "conversation_language": "test_language", "language": "test_language", "name": "test_name", "stt_engine": "test_stt_engine", + "stt_language": "test_language", "tts_engine": "test_tts_engine", + "tts_language": "test_language", + "tts_voice": "Arnold Schwarzenegger", } ) msg = await client.receive_json()