From b601fb17d353d8dfc55fec8d2283f158b3b7cca0 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Mon, 24 Apr 2023 20:00:52 +0200 Subject: [PATCH] Create a default assist pipeline on start (#91947) * Create a default assist pipeline on start * Minor adjustments * Address review comments * Remove tts.async_get_agent * Fix bugs, improve test coverage --- .../components/assist_pipeline/__init__.py | 2 +- .../components/assist_pipeline/pipeline.py | 134 +++++++--- .../assist_pipeline/websocket_api.py | 2 +- .../components/conversation/__init__.py | 20 +- tests/components/assist_pipeline/__init__.py | 55 ++++ tests/components/assist_pipeline/conftest.py | 60 +++-- .../assist_pipeline/snapshots/test_init.ambr | 12 +- .../snapshots/test_websocket.ambr | 44 ++-- .../assist_pipeline/test_pipeline.py | 236 +++++++++++++++++- .../components/assist_pipeline/test_select.py | 7 +- .../assist_pipeline/test_websocket.py | 130 ++++------ 11 files changed, 523 insertions(+), 179 deletions(-) diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index 6093d86e717..a56e535cc63 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -51,7 +51,7 @@ async def async_pipeline_from_audio_stream( tts_audio_output: str | None = None, ) -> None: """Create an audio pipeline from an audio stream.""" - pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id) + pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id) if pipeline is None: raise PipelineNotFound( "pipeline_not_found", f"Pipeline {pipeline_id} not found" diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index fd0df4edf92..faa1953cb94 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -24,7 +24,11 @@ from homeassistant.helpers.collection import ( StorageCollectionWebsocket, ) from homeassistant.helpers.storage import Store -from homeassistant.util import dt as dt_util, ulid as ulid_util +from homeassistant.util import ( + dt as dt_util, + language as language_util, + ulid as ulid_util, +) from homeassistant.util.limited_size_dict import LimitedSizeDict from .const import DOMAIN @@ -71,37 +75,109 @@ STORED_PIPELINE_RUNS = 10 SAVE_DELAY = 10 -async def async_get_pipeline( +async def _async_create_default_pipeline( + hass: HomeAssistant, pipeline_store: PipelineStorageCollection +) -> Pipeline: + """Create a default pipeline. + + The default pipeline will use the homeassistant conversation agent and the + default stt / tts engines. + """ + conversation_language = "en" + pipeline_language = "en" + pipeline_name = "Home Assistant" + stt_engine_id = None + stt_language = None + tts_engine_id = None + tts_language = None + tts_voice = None + + # Find a matching language supported by the Home Assistant conversation agent + conversation_languages = language_util.matches( + hass.config.language, + await conversation.async_get_conversation_languages( + hass, conversation.HOME_ASSISTANT_AGENT + ), + country=hass.config.country, + ) + if conversation_languages: + pipeline_language = hass.config.language + conversation_language = conversation_languages[0] + + if (stt_engine_id := stt.async_default_engine(hass)) is not None and ( + stt_engine := stt.async_get_speech_to_text_engine( + hass, + stt_engine_id, + ) + ): + stt_languages = language_util.matches( + pipeline_language, + stt_engine.supported_languages, + country=hass.config.country, + ) + if stt_languages: + stt_language = stt_languages[0] + else: + _LOGGER.debug( + "Speech to text engine '%s' does not support language '%s'", + stt_engine_id, + pipeline_language, + ) + stt_engine_id = None + + if (tts_engine_id := tts.async_default_engine(hass)) is not None and ( + tts_engine := tts.get_engine_instance( + hass, + tts_engine_id, + ) + ): + tts_languages = language_util.matches( + pipeline_language, + tts_engine.supported_languages, + country=hass.config.country, + ) + if tts_languages: + tts_language = tts_languages[0] + tts_voices = tts_engine.async_get_supported_voices(tts_language) + if tts_voices: + tts_voice = tts_voices[0].voice_id + else: + _LOGGER.debug( + "Text to speech engine '%s' does not support language '%s'", + tts_engine_id, + pipeline_language, + ) + tts_engine_id = None + + if stt_engine_id == "cloud" and tts_engine_id == "cloud": + pipeline_name = "Home Assistant Cloud" + + return await pipeline_store.async_create_item( + { + "conversation_engine": conversation.HOME_ASSISTANT_AGENT, + "conversation_language": conversation_language, + "language": hass.config.language, + "name": pipeline_name, + "stt_engine": stt_engine_id, + "stt_language": stt_language, + "tts_engine": tts_engine_id, + "tts_language": tts_language, + "tts_voice": tts_voice, + } + ) + + +@callback +def async_get_pipeline( hass: HomeAssistant, pipeline_id: str | None = None ) -> Pipeline | None: - """Get a pipeline by id or create one for a language.""" + """Get a pipeline by id or the preferred pipeline.""" pipeline_data: PipelineData = hass.data[DOMAIN] if pipeline_id is None: # A pipeline was not specified, use the preferred one pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item() - if pipeline_id is None: - # There's no preferred pipeline, construct a pipeline for the - # configured language - stt_engine = stt.async_default_provider(hass) - stt_language = hass.config.language if stt_engine else None - tts_engine = tts.async_default_engine(hass) - tts_language = hass.config.language if tts_engine else None - return await pipeline_data.pipeline_store.async_create_item( - { - "conversation_engine": None, - "conversation_language": None, - "language": hass.config.language, - "name": hass.config.language, - "stt_engine": stt_engine, - "stt_language": stt_language, - "tts_engine": tts_engine, - "tts_language": tts_language, - "tts_voice": None, - } - ) - return pipeline_data.pipeline_store.data.get(pipeline_id) @@ -635,7 +711,7 @@ class PipelinePreferred(CollectionError): class SerializedPipelineStorageCollection(SerializedStorageCollection): """Serialized pipeline storage collection.""" - preferred_item: str | None + preferred_item: str class PipelineStorageCollection( @@ -643,11 +719,13 @@ class PipelineStorageCollection( ): """Pipeline storage collection.""" - _preferred_item: str | None = None + _preferred_item: str async def _async_load_data(self) -> SerializedPipelineStorageCollection | None: """Load the data.""" if not (data := await super()._async_load_data()): + pipeline = await _async_create_default_pipeline(self.hass, self) + self._preferred_item = pipeline.id return data self._preferred_item = data["preferred_item"] @@ -671,8 +749,6 @@ class PipelineStorageCollection( def _create_item(self, item_id: str, data: dict) -> Pipeline: """Create an item from validated config.""" - if self._preferred_item is None: - self._preferred_item = item_id return Pipeline(id=item_id, **data) def _deserialize_item(self, data: dict) -> Pipeline: @@ -690,7 +766,7 @@ class PipelineStorageCollection( await super().async_delete_item(item_id) @callback - def async_get_preferred_item(self) -> str | None: + def async_get_preferred_item(self) -> str: """Get the id of the preferred item.""" return self._preferred_item diff --git a/homeassistant/components/assist_pipeline/websocket_api.py b/homeassistant/components/assist_pipeline/websocket_api.py index f66b4376482..6c1dbe3dbce 100644 --- a/homeassistant/components/assist_pipeline/websocket_api.py +++ b/homeassistant/components/assist_pipeline/websocket_api.py @@ -85,7 +85,7 @@ async def websocket_run( ) -> None: """Run a pipeline.""" pipeline_id = msg.get("pipeline") - pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id) + pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id) if pipeline is None: connection.send_error( msg["id"], diff --git a/homeassistant/components/conversation/__init__.py b/homeassistant/components/conversation/__init__.py index 9b3d5ef8dff..ac3057769f1 100644 --- a/homeassistant/components/conversation/__init__.py +++ b/homeassistant/components/conversation/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from collections.abc import Iterable from dataclasses import dataclass import logging import re @@ -117,14 +118,25 @@ def async_unset_agent( async def async_get_conversation_languages( - hass: HomeAssistant, + hass: HomeAssistant, agent_id: str | None = None ) -> set[str] | Literal["*"]: - """Return a set with the union of languages supported by conversation agents.""" + """Return languages supported by conversation agents. + + If an agent is specified, returns a set of languages supported by that agent. + If no agent is specified, return a set with the union of languages supported by + all conversation agents. + """ agent_manager = _get_agent_manager(hass) languages = set() - for agent_info in agent_manager.async_get_agent_info(): - agent = await agent_manager.async_get_agent(agent_info.id) + agent_ids: Iterable[str] + if agent_id is None: + agent_ids = iter(info.id for info in agent_manager.async_get_agent_info()) + else: + agent_ids = (agent_id,) + + for _agent_id in agent_ids: + agent = await agent_manager.async_get_agent(_agent_id) if agent.supported_languages == MATCH_ALL: return MATCH_ALL for language_tag in agent.supported_languages: diff --git a/tests/components/assist_pipeline/__init__.py b/tests/components/assist_pipeline/__init__.py index 6838f353c4b..7400fe32d70 100644 --- a/tests/components/assist_pipeline/__init__.py +++ b/tests/components/assist_pipeline/__init__.py @@ -1 +1,56 @@ """Tests for the Voice Assistant integration.""" + +MANY_LANGUAGES = [ + "ar", + "bg", + "bn", + "ca", + "cs", + "da", + "de", + "de-CH", + "el", + "en", + "es", + "fa", + "fi", + "fr", + "fr-CA", + "gl", + "gu", + "he", + "hi", + "hr", + "hu", + "id", + "is", + "it", + "ka", + "kn", + "lb", + "lt", + "lv", + "ml", + "mn", + "ms", + "nb", + "nl", + "pl", + "pt", + "pt-br", + "ro", + "ru", + "sk", + "sl", + "sr", + "sv", + "sw", + "te", + "tr", + "uk", + "ur", + "vi", + "zh-cn", + "zh-hk", + "zh-tw", +] diff --git a/tests/components/assist_pipeline/conftest.py b/tests/components/assist_pipeline/conftest.py index 744254f9954..1df52859ed9 100644 --- a/tests/components/assist_pipeline/conftest.py +++ b/tests/components/assist_pipeline/conftest.py @@ -11,9 +11,8 @@ from homeassistant.components import stt, tts from homeassistant.components.assist_pipeline import DOMAIN from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection from homeassistant.config_entries import ConfigEntry, ConfigFlow -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.setup import async_setup_component from tests.common import ( @@ -36,6 +35,8 @@ _TRANSCRIPT = "test transcript" class BaseProvider: """Mock STT provider.""" + _supported_languages = ["en-US"] + def __init__(self, text: str) -> None: """Init test provider.""" self.text = text @@ -44,7 +45,7 @@ class BaseProvider: @property def supported_languages(self) -> list[str]: """Return a list of supported languages.""" - return ["en-US"] + return self._supported_languages @property def supported_formats(self) -> list[stt.AudioFormats]: @@ -96,6 +97,13 @@ class MockTTSProvider(tts.Provider): """Mock TTS provider.""" name = "Test" + _supported_languages = ["en-US"] + _supported_voices = { + "en-US": [ + tts.Voice("james_earl_jones", "James Earl Jones"), + tts.Voice("fran_drescher", "Fran Drescher"), + ] + } @property def default_language(self) -> str: @@ -105,7 +113,12 @@ class MockTTSProvider(tts.Provider): @property def supported_languages(self) -> list[str]: """Return list of supported languages.""" - return ["en-US"] + return self._supported_languages + + @callback + def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None: + """Return a list of supported voices for a language.""" + return self._supported_voices.get(language) @property def supported_options(self) -> list[str]: @@ -119,19 +132,21 @@ class MockTTSProvider(tts.Provider): return ("mp3", b"") -class MockTTS(MockPlatform): +class MockTTSPlatform(MockPlatform): """A mock TTS platform.""" PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA - async def async_get_engine( - self, - hass: HomeAssistant, - config: ConfigType, - discovery_info: DiscoveryInfoType | None = None, - ) -> tts.Provider: - """Set up a mock speech component.""" - return MockTTSProvider() + def __init__(self, *, async_get_engine, **kwargs): + """Initialize the tts platform.""" + super().__init__(**kwargs) + self.async_get_engine = async_get_engine + + +@pytest.fixture +async def mock_tts_provider(hass) -> MockTTSProvider: + """Mock TTS provider.""" + return MockTTSProvider() @pytest.fixture @@ -169,10 +184,11 @@ def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]: @pytest.fixture -async def init_components( +async def init_supporting_components( hass: HomeAssistant, mock_stt_provider: MockSttProvider, mock_stt_provider_entity: MockSttProviderEntity, + mock_tts_provider: MockTTSProvider, config_flow_fixture, init_cache_dir_side_effect, # noqa: F811 mock_get_cache_files, # noqa: F811 @@ -210,7 +226,13 @@ async def init_components( async_unload_entry=async_unload_entry_init, ), ) - mock_platform(hass, "test.tts", MockTTS()) + mock_platform( + hass, + "test.tts", + MockTTSPlatform( + async_get_engine=AsyncMock(return_value=mock_tts_provider), + ), + ) mock_platform( hass, "test.stt", @@ -224,7 +246,6 @@ async def init_components( assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}}) assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}}) assert await async_setup_component(hass, "media_source", {}) - assert await async_setup_component(hass, "assist_pipeline", {}) config_entry = MockConfigEntry(domain="test") config_entry.add_to_hass(hass) @@ -232,6 +253,13 @@ async def init_components( await hass.async_block_till_done() +@pytest.fixture +async def init_components(hass: HomeAssistant, init_supporting_components): + """Initialize relevant components with empty configs.""" + + assert await async_setup_component(hass, "assist_pipeline", {}) + + @pytest.fixture def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection: """Return pipeline storage collection.""" diff --git a/tests/components/assist_pipeline/snapshots/test_init.ambr b/tests/components/assist_pipeline/snapshots/test_init.ambr index efa6434e784..1321fb97b17 100644 --- a/tests/components/assist_pipeline/snapshots/test_init.ambr +++ b/tests/components/assist_pipeline/snapshots/test_init.ambr @@ -4,7 +4,7 @@ dict({ 'data': dict({ 'language': 'en', - 'pipeline': 'en', + 'pipeline': 'Home Assistant', }), 'type': , }), @@ -34,7 +34,7 @@ 'data': dict({ 'engine': 'homeassistant', 'intent_input': 'test transcript', - 'language': None, + 'language': 'en', }), 'type': , }), @@ -64,18 +64,18 @@ dict({ 'data': dict({ 'engine': 'test', - 'language': 'en', + 'language': 'en-US', 'tts_input': "Sorry, I couldn't understand that", - 'voice': None, + 'voice': 'james_earl_jones', }), 'type': , }), dict({ 'data': dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", 'mime_type': 'audio/mpeg', - 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', }), }), 'type': , diff --git a/tests/components/assist_pipeline/snapshots/test_websocket.ambr b/tests/components/assist_pipeline/snapshots/test_websocket.ambr index 0abb00afdfb..f5a0a6dad92 100644 --- a/tests/components/assist_pipeline/snapshots/test_websocket.ambr +++ b/tests/components/assist_pipeline/snapshots/test_websocket.ambr @@ -2,7 +2,7 @@ # name: test_audio_pipeline dict({ 'language': 'en', - 'pipeline': 'en', + 'pipeline': 'Home Assistant', 'runner_data': dict({ 'stt_binary_handler_id': 1, 'timeout': 30, @@ -33,7 +33,7 @@ dict({ 'engine': 'homeassistant', 'intent_input': 'test transcript', - 'language': None, + 'language': 'en', }) # --- # name: test_audio_pipeline.4 @@ -61,24 +61,24 @@ # name: test_audio_pipeline.5 dict({ 'engine': 'test', - 'language': 'en', + 'language': 'en-US', 'tts_input': "Sorry, I couldn't understand that", - 'voice': None, + 'voice': 'james_earl_jones', }) # --- # name: test_audio_pipeline.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", 'mime_type': 'audio/mpeg', - 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', }), }) # --- # name: test_audio_pipeline_debug dict({ 'language': 'en', - 'pipeline': 'en', + 'pipeline': 'Home Assistant', 'runner_data': dict({ 'stt_binary_handler_id': 1, 'timeout': 30, @@ -109,7 +109,7 @@ dict({ 'engine': 'homeassistant', 'intent_input': 'test transcript', - 'language': None, + 'language': 'en', }) # --- # name: test_audio_pipeline_debug.4 @@ -137,24 +137,24 @@ # name: test_audio_pipeline_debug.5 dict({ 'engine': 'test', - 'language': 'en', + 'language': 'en-US', 'tts_input': "Sorry, I couldn't understand that", - 'voice': None, + 'voice': 'james_earl_jones', }) # --- # name: test_audio_pipeline_debug.6 dict({ 'tts_output': dict({ - 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en", + 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones", 'mime_type': 'audio/mpeg', - 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', + 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', }), }) # --- # name: test_intent_failed dict({ 'language': 'en', - 'pipeline': 'en', + 'pipeline': 'Home Assistant', 'runner_data': dict({ 'stt_binary_handler_id': None, 'timeout': 30, @@ -165,13 +165,13 @@ dict({ 'engine': 'homeassistant', 'intent_input': 'Are the lights on?', - 'language': None, + 'language': 'en', }) # --- # name: test_intent_timeout dict({ 'language': 'en', - 'pipeline': 'en', + 'pipeline': 'Home Assistant', 'runner_data': dict({ 'stt_binary_handler_id': None, 'timeout': 0.1, @@ -182,7 +182,7 @@ dict({ 'engine': 'homeassistant', 'intent_input': 'Are the lights on?', - 'language': None, + 'language': 'en', }) # --- # name: test_intent_timeout.2 @@ -217,7 +217,7 @@ # name: test_stt_stream_failed dict({ 'language': 'en', - 'pipeline': 'en', + 'pipeline': 'Home Assistant', 'runner_data': dict({ 'stt_binary_handler_id': 1, 'timeout': 30, @@ -240,7 +240,7 @@ # name: test_text_only_pipeline dict({ 'language': 'en', - 'pipeline': 'en', + 'pipeline': 'Home Assistant', 'runner_data': dict({ 'stt_binary_handler_id': None, 'timeout': 30, @@ -251,7 +251,7 @@ dict({ 'engine': 'homeassistant', 'intent_input': 'Are the lights on?', - 'language': None, + 'language': 'en', }) # --- # name: test_text_only_pipeline.2 @@ -285,7 +285,7 @@ # name: test_tts_failed dict({ 'language': 'en', - 'pipeline': 'en', + 'pipeline': 'Home Assistant', 'runner_data': dict({ 'stt_binary_handler_id': None, 'timeout': 30, @@ -295,8 +295,8 @@ # name: test_tts_failed.1 dict({ 'engine': 'test', - 'language': 'en', + 'language': 'en-US', 'tts_input': 'Lights are on.', - 'voice': None, + 'voice': 'james_earl_jones', }) # --- diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 47037869af6..6b2fa60102d 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -1,10 +1,14 @@ """Websocket tests for Voice Assistant integration.""" from typing import Any +from unittest.mock import AsyncMock, patch + +import pytest from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.pipeline import ( STORAGE_KEY, STORAGE_VERSION, + Pipeline, PipelineData, PipelineStorageCollection, async_get_pipeline, @@ -13,7 +17,10 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.storage import Store from homeassistant.setup import async_setup_component -from tests.common import flush_store +from . import MANY_LANGUAGES +from .conftest import MockSttPlatform, MockSttProvider, MockTTSPlatform, MockTTSProvider + +from tests.common import MockModule, flush_store, mock_integration, mock_platform async def test_load_datasets(hass: HomeAssistant, init_components) -> None: @@ -60,17 +67,17 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None: store1 = pipeline_data.pipeline_store for pipeline in pipelines: pipeline_ids.append((await store1.async_create_item(pipeline)).id) - assert len(store1.data) == 3 + assert len(store1.data) == 4 # 3 manually created plus a default pipeline assert store1.async_get_preferred_item() == list(store1.data)[0] await store1.async_delete_item(pipeline_ids[1]) - assert len(store1.data) == 2 + assert len(store1.data) == 3 store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY)) await flush_store(store1.store) await store2.async_load() - assert len(store2.data) == 2 + assert len(store2.data) == 3 assert store1.data is not store2.data assert store1.data == store2.data @@ -142,16 +149,221 @@ async def test_get_pipeline(hass: HomeAssistant) -> None: pipeline_data: PipelineData = hass.data[DOMAIN] store = pipeline_data.pipeline_store - assert len(store.data) == 0 - - # Test a pipeline is created - pipeline = await async_get_pipeline(hass, None) assert len(store.data) == 1 - # Test we get the same pipeline again - assert pipeline is await async_get_pipeline(hass, None) - assert len(store.data) == 1 + # Test we get the preferred pipeline if none is specified + pipeline = async_get_pipeline(hass, None) + assert pipeline.id == store.async_get_preferred_item() # Test getting a specific pipeline - assert pipeline is await async_get_pipeline(hass, pipeline.id) + assert pipeline is async_get_pipeline(hass, pipeline.id) + + +@pytest.mark.parametrize( + ("ha_language", "ha_country", "conv_language", "pipeline_language"), + [ + ("en", None, "en", "en"), + ("de", "de", "de", "de"), + ("de", "ch", "de-CH", "de"), + ("en", "us", "en", "en"), + ("en", "uk", "en", "en"), + ("pt", "pt", "pt", "pt"), + ("pt", "br", "pt-br", "pt"), + ], +) +async def test_default_pipeline_no_stt_tts( + hass: HomeAssistant, + ha_language: str, + ha_country: str | None, + conv_language: str, + pipeline_language: str, +) -> None: + """Test async_get_pipeline.""" + hass.config.country = ha_country + hass.config.language = ha_language + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store assert len(store.data) == 1 + + # Check the default pipeline + pipeline = async_get_pipeline(hass, None) + assert pipeline == Pipeline( + conversation_engine="homeassistant", + conversation_language=conv_language, + id=pipeline.id, + language=pipeline_language, + name="Home Assistant", + stt_engine=None, + stt_language=None, + tts_engine=None, + tts_language=None, + tts_voice=None, + ) + + +@pytest.mark.parametrize( + ( + "ha_language", + "ha_country", + "conv_language", + "pipeline_language", + "stt_language", + "tts_language", + ), + [ + ("en", None, "en", "en", "en", "en"), + ("de", "de", "de", "de", "de", "de"), + ("de", "ch", "de-CH", "de", "de-CH", "de-CH"), + ("en", "us", "en", "en", "en", "en"), + ("en", "uk", "en", "en", "en", "en"), + ("pt", "pt", "pt", "pt", "pt", "pt"), + ("pt", "br", "pt-br", "pt", "pt-br", "pt-br"), + ], +) +async def test_default_pipeline( + hass: HomeAssistant, + init_supporting_components, + mock_stt_provider: MockSttProvider, + mock_tts_provider: MockTTSProvider, + ha_language: str, + ha_country: str | None, + conv_language: str, + pipeline_language: str, + stt_language: str, + tts_language: str, +) -> None: + """Test async_get_pipeline.""" + hass.config.country = ha_country + hass.config.language = ha_language + + with patch.object( + mock_stt_provider, "_supported_languages", MANY_LANGUAGES + ), patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES): + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store + assert len(store.data) == 1 + + # Check the default pipeline + pipeline = async_get_pipeline(hass, None) + assert pipeline == Pipeline( + conversation_engine="homeassistant", + conversation_language=conv_language, + id=pipeline.id, + language=pipeline_language, + name="Home Assistant", + stt_engine="test", + stt_language=stt_language, + tts_engine="test", + tts_language=tts_language, + tts_voice=None, + ) + + +async def test_default_pipeline_unsupported_stt_language( + hass: HomeAssistant, + init_supporting_components, + mock_stt_provider: MockSttProvider, +) -> None: + """Test async_get_pipeline.""" + with patch.object(mock_stt_provider, "_supported_languages", ["smurfish"]): + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store + assert len(store.data) == 1 + + # Check the default pipeline + pipeline = async_get_pipeline(hass, None) + assert pipeline == Pipeline( + conversation_engine="homeassistant", + conversation_language="en", + id=pipeline.id, + language="en", + name="Home Assistant", + stt_engine=None, + stt_language=None, + tts_engine="test", + tts_language="en-US", + tts_voice="james_earl_jones", + ) + + +async def test_default_pipeline_unsupported_tts_language( + hass: HomeAssistant, + init_supporting_components, + mock_tts_provider: MockTTSProvider, +) -> None: + """Test async_get_pipeline.""" + with patch.object(mock_tts_provider, "_supported_languages", ["smurfish"]): + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store + assert len(store.data) == 1 + + # Check the default pipeline + pipeline = async_get_pipeline(hass, None) + assert pipeline == Pipeline( + conversation_engine="homeassistant", + conversation_language="en", + id=pipeline.id, + language="en", + name="Home Assistant", + stt_engine="test", + stt_language="en-US", + tts_engine=None, + tts_language=None, + tts_voice=None, + ) + + +async def test_default_pipeline_cloud( + hass: HomeAssistant, + mock_stt_provider: MockSttProvider, + mock_tts_provider: MockTTSProvider, +) -> None: + """Test async_get_pipeline.""" + + mock_integration(hass, MockModule("cloud")) + mock_platform( + hass, + "cloud.tts", + MockTTSPlatform( + async_get_engine=AsyncMock(return_value=mock_tts_provider), + ), + ) + mock_platform( + hass, + "cloud.stt", + MockSttPlatform( + async_get_engine=AsyncMock(return_value=mock_stt_provider), + ), + ) + mock_platform(hass, "test.config_flow") + + assert await async_setup_component(hass, "tts", {"tts": {"platform": "cloud"}}) + assert await async_setup_component(hass, "stt", {"stt": {"platform": "cloud"}}) + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store + assert len(store.data) == 1 + + # Check the default pipeline + pipeline = async_get_pipeline(hass, None) + assert pipeline == Pipeline( + conversation_engine="homeassistant", + conversation_language="en", + id=pipeline.id, + language="en", + name="Home Assistant Cloud", + stt_engine="cloud", + stt_language="en-US", + tts_engine="cloud", + tts_language="en-US", + tts_voice="james_earl_jones", + ) diff --git a/tests/components/assist_pipeline/test_select.py b/tests/components/assist_pipeline/test_select.py index b7fc232494b..30874e7b756 100644 --- a/tests/components/assist_pipeline/test_select.py +++ b/tests/components/assist_pipeline/test_select.py @@ -92,6 +92,7 @@ async def test_select_entity_changing_pipelines( assert state.state == "preferred" assert state.attributes["options"] == [ "preferred", + "Home Assistant", pipeline_1.name, pipeline_2.name, ] @@ -122,4 +123,8 @@ async def test_select_entity_changing_pipelines( state = hass.states.get("select.assist_pipeline_test_pipeline") assert state.state == "preferred" - assert state.attributes["options"] == ["preferred", pipeline_1.name] + assert state.attributes["options"] == [ + "preferred", + "Home Assistant", + pipeline_1.name, + ] diff --git a/tests/components/assist_pipeline/test_websocket.py b/tests/components/assist_pipeline/test_websocket.py index 2fc19a79005..827d7b85113 100644 --- a/tests/components/assist_pipeline/test_websocket.py +++ b/tests/components/assist_pipeline/test_websocket.py @@ -610,7 +610,7 @@ async def test_add_pipeline( "tts_voice": "Arnold Schwarzenegger", } - assert len(pipeline_store.data) == 1 + assert len(pipeline_store.data) == 2 pipeline = pipeline_store.data[msg["result"]["id"]] assert pipeline == Pipeline( conversation_engine="test_conversation_engine", @@ -643,6 +643,7 @@ async def test_add_pipeline_missing_language( client = await hass_ws_client(hass) pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_store = pipeline_data.pipeline_store + assert len(pipeline_store.data) == 1 await client.send_json_auto_id( { @@ -660,7 +661,7 @@ async def test_add_pipeline_missing_language( ) msg = await client.receive_json() assert not msg["success"] - assert len(pipeline_store.data) == 0 + assert len(pipeline_store.data) == 1 await client.send_json_auto_id( { @@ -678,7 +679,7 @@ async def test_add_pipeline_missing_language( ) msg = await client.receive_json() assert not msg["success"] - assert len(pipeline_store.data) == 0 + assert len(pipeline_store.data) == 1 async def test_delete_pipeline( @@ -725,7 +726,16 @@ async def test_delete_pipeline( assert msg["success"] pipeline_id_2 = msg["result"]["id"] - assert len(pipeline_store.data) == 2 + assert len(pipeline_store.data) == 3 + + await client.send_json_auto_id( + { + "type": "assist_pipeline/pipeline/set_preferred", + "pipeline_id": pipeline_id_1, + } + ) + msg = await client.receive_json() + assert msg["success"] await client.send_json_auto_id( { @@ -748,7 +758,7 @@ async def test_delete_pipeline( ) msg = await client.receive_json() assert msg["success"] - assert len(pipeline_store.data) == 1 + assert len(pipeline_store.data) == 2 await client.send_json_auto_id( { @@ -778,10 +788,18 @@ async def test_get_pipeline( } ) msg = await client.receive_json() - assert not msg["success"] - assert msg["error"] == { - "code": "not_found", - "message": "Unable to find pipeline_id None", + assert msg["success"] + assert msg["result"] == { + "conversation_engine": "homeassistant", + "conversation_language": "en", + "id": ANY, + "language": "en", + "name": "Home Assistant", + "stt_engine": "test", + "stt_language": "en-US", + "tts_engine": "test", + "tts_language": "en-US", + "tts_voice": "james_earl_jones", } await client.send_json_auto_id( @@ -814,27 +832,7 @@ async def test_get_pipeline( msg = await client.receive_json() assert msg["success"] pipeline_id = msg["result"]["id"] - assert len(pipeline_store.data) == 1 - - await client.send_json_auto_id( - { - "type": "assist_pipeline/pipeline/get", - } - ) - msg = await client.receive_json() - assert msg["success"] - assert msg["result"] == { - "conversation_engine": "test_conversation_engine", - "conversation_language": "test_language", - "id": pipeline_id, - "language": "test_language", - "name": "test_name", - "stt_engine": "test_stt_engine", - "stt_language": "test_language", - "tts_engine": "test_tts_engine", - "tts_language": "test_language", - "tts_voice": "Arnold Schwarzenegger", - } + assert len(pipeline_store.data) == 2 await client.send_json_auto_id( { @@ -863,31 +861,7 @@ async def test_list_pipelines( ) -> None: """Test we can list pipelines.""" client = await hass_ws_client(hass) - pipeline_data: PipelineData = hass.data[DOMAIN] - pipeline_store = pipeline_data.pipeline_store - - await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) - msg = await client.receive_json() - assert msg["success"] - assert msg["result"] == {"pipelines": [], "preferred_pipeline": None} - - await client.send_json_auto_id( - { - "type": "assist_pipeline/pipeline/create", - "conversation_engine": "test_conversation_engine", - "conversation_language": "test_language", - "language": "test_language", - "name": "test_name", - "stt_engine": "test_stt_engine", - "stt_language": "test_language", - "tts_engine": "test_tts_engine", - "tts_language": "test_language", - "tts_voice": "Arnold Schwarzenegger", - } - ) - msg = await client.receive_json() - assert msg["success"] - assert len(pipeline_store.data) == 1 + hass.data[DOMAIN] await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) msg = await client.receive_json() @@ -895,16 +869,16 @@ async def test_list_pipelines( assert msg["result"] == { "pipelines": [ { - "conversation_engine": "test_conversation_engine", - "conversation_language": "test_language", + "conversation_engine": "homeassistant", + "conversation_language": "en", "id": ANY, - "language": "test_language", - "name": "test_name", - "stt_engine": "test_stt_engine", - "stt_language": "test_language", - "tts_engine": "test_tts_engine", - "tts_language": "test_language", - "tts_voice": "Arnold Schwarzenegger", + "language": "en", + "name": "Home Assistant", + "stt_engine": "test", + "stt_language": "en-US", + "tts_engine": "test", + "tts_language": "en-US", + "tts_voice": "james_earl_jones", } ], "preferred_pipeline": ANY, @@ -958,7 +932,7 @@ async def test_update_pipeline( msg = await client.receive_json() assert msg["success"] pipeline_id = msg["result"]["id"] - assert len(pipeline_store.data) == 1 + assert len(pipeline_store.data) == 2 await client.send_json_auto_id( { @@ -990,7 +964,7 @@ async def test_update_pipeline( "tts_voice": "new_tts_voice", } - assert len(pipeline_store.data) == 1 + assert len(pipeline_store.data) == 2 pipeline = pipeline_store.data[pipeline_id] assert pipeline == Pipeline( conversation_engine="new_conversation_engine", @@ -1076,36 +1050,18 @@ async def test_set_preferred_pipeline( assert msg["success"] pipeline_id_1 = msg["result"]["id"] - await client.send_json_auto_id( - { - "type": "assist_pipeline/pipeline/create", - "conversation_engine": "test_conversation_engine", - "conversation_language": "test_language", - "language": "test_language", - "name": "test_name", - "stt_engine": "test_stt_engine", - "stt_language": "test_language", - "tts_engine": "test_tts_engine", - "tts_language": "test_language", - "tts_voice": "Arnold Schwarzenegger", - } - ) - msg = await client.receive_json() - assert msg["success"] - pipeline_id_2 = msg["result"]["id"] - - assert pipeline_store.async_get_preferred_item() == pipeline_id_1 + assert pipeline_store.async_get_preferred_item() != pipeline_id_1 await client.send_json_auto_id( { "type": "assist_pipeline/pipeline/set_preferred", - "pipeline_id": pipeline_id_2, + "pipeline_id": pipeline_id_1, } ) msg = await client.receive_json() assert msg["success"] - assert pipeline_store.async_get_preferred_item() == pipeline_id_2 + assert pipeline_store.async_get_preferred_item() == pipeline_id_1 async def test_set_preferred_pipeline_wrong_id(