Create a default assist pipeline on start (#91947)

* Create a default assist pipeline on start

* Minor adjustments

* Address review comments

* Remove tts.async_get_agent

* Fix bugs, improve test coverage
This commit is contained in:
Erik Montnemery 2023-04-24 20:00:52 +02:00 committed by GitHub
parent 4b619f7251
commit b601fb17d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 523 additions and 179 deletions

View File

@ -51,7 +51,7 @@ async def async_pipeline_from_audio_stream(
tts_audio_output: str | None = None, tts_audio_output: str | None = None,
) -> None: ) -> None:
"""Create an audio pipeline from an audio stream.""" """Create an audio pipeline from an audio stream."""
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id) pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None: if pipeline is None:
raise PipelineNotFound( raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found" "pipeline_not_found", f"Pipeline {pipeline_id} not found"

View File

@ -24,7 +24,11 @@ from homeassistant.helpers.collection import (
StorageCollectionWebsocket, StorageCollectionWebsocket,
) )
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util from homeassistant.util import (
dt as dt_util,
language as language_util,
ulid as ulid_util,
)
from homeassistant.util.limited_size_dict import LimitedSizeDict from homeassistant.util.limited_size_dict import LimitedSizeDict
from .const import DOMAIN from .const import DOMAIN
@ -71,37 +75,109 @@ STORED_PIPELINE_RUNS = 10
SAVE_DELAY = 10 SAVE_DELAY = 10
async def async_get_pipeline( async def _async_create_default_pipeline(
hass: HomeAssistant, pipeline_store: PipelineStorageCollection
) -> Pipeline:
"""Create a default pipeline.
The default pipeline will use the homeassistant conversation agent and the
default stt / tts engines.
"""
conversation_language = "en"
pipeline_language = "en"
pipeline_name = "Home Assistant"
stt_engine_id = None
stt_language = None
tts_engine_id = None
tts_language = None
tts_voice = None
# Find a matching language supported by the Home Assistant conversation agent
conversation_languages = language_util.matches(
hass.config.language,
await conversation.async_get_conversation_languages(
hass, conversation.HOME_ASSISTANT_AGENT
),
country=hass.config.country,
)
if conversation_languages:
pipeline_language = hass.config.language
conversation_language = conversation_languages[0]
if (stt_engine_id := stt.async_default_engine(hass)) is not None and (
stt_engine := stt.async_get_speech_to_text_engine(
hass,
stt_engine_id,
)
):
stt_languages = language_util.matches(
pipeline_language,
stt_engine.supported_languages,
country=hass.config.country,
)
if stt_languages:
stt_language = stt_languages[0]
else:
_LOGGER.debug(
"Speech to text engine '%s' does not support language '%s'",
stt_engine_id,
pipeline_language,
)
stt_engine_id = None
if (tts_engine_id := tts.async_default_engine(hass)) is not None and (
tts_engine := tts.get_engine_instance(
hass,
tts_engine_id,
)
):
tts_languages = language_util.matches(
pipeline_language,
tts_engine.supported_languages,
country=hass.config.country,
)
if tts_languages:
tts_language = tts_languages[0]
tts_voices = tts_engine.async_get_supported_voices(tts_language)
if tts_voices:
tts_voice = tts_voices[0].voice_id
else:
_LOGGER.debug(
"Text to speech engine '%s' does not support language '%s'",
tts_engine_id,
pipeline_language,
)
tts_engine_id = None
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
pipeline_name = "Home Assistant Cloud"
return await pipeline_store.async_create_item(
{
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
"conversation_language": conversation_language,
"language": hass.config.language,
"name": pipeline_name,
"stt_engine": stt_engine_id,
"stt_language": stt_language,
"tts_engine": tts_engine_id,
"tts_language": tts_language,
"tts_voice": tts_voice,
}
)
@callback
def async_get_pipeline(
hass: HomeAssistant, pipeline_id: str | None = None hass: HomeAssistant, pipeline_id: str | None = None
) -> Pipeline | None: ) -> Pipeline | None:
"""Get a pipeline by id or create one for a language.""" """Get a pipeline by id or the preferred pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
if pipeline_id is None: if pipeline_id is None:
# A pipeline was not specified, use the preferred one # A pipeline was not specified, use the preferred one
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item() pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
if pipeline_id is None:
# There's no preferred pipeline, construct a pipeline for the
# configured language
stt_engine = stt.async_default_provider(hass)
stt_language = hass.config.language if stt_engine else None
tts_engine = tts.async_default_engine(hass)
tts_language = hass.config.language if tts_engine else None
return await pipeline_data.pipeline_store.async_create_item(
{
"conversation_engine": None,
"conversation_language": None,
"language": hass.config.language,
"name": hass.config.language,
"stt_engine": stt_engine,
"stt_language": stt_language,
"tts_engine": tts_engine,
"tts_language": tts_language,
"tts_voice": None,
}
)
return pipeline_data.pipeline_store.data.get(pipeline_id) return pipeline_data.pipeline_store.data.get(pipeline_id)
@ -635,7 +711,7 @@ class PipelinePreferred(CollectionError):
class SerializedPipelineStorageCollection(SerializedStorageCollection): class SerializedPipelineStorageCollection(SerializedStorageCollection):
"""Serialized pipeline storage collection.""" """Serialized pipeline storage collection."""
preferred_item: str | None preferred_item: str
class PipelineStorageCollection( class PipelineStorageCollection(
@ -643,11 +719,13 @@ class PipelineStorageCollection(
): ):
"""Pipeline storage collection.""" """Pipeline storage collection."""
_preferred_item: str | None = None _preferred_item: str
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None: async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
"""Load the data.""" """Load the data."""
if not (data := await super()._async_load_data()): if not (data := await super()._async_load_data()):
pipeline = await _async_create_default_pipeline(self.hass, self)
self._preferred_item = pipeline.id
return data return data
self._preferred_item = data["preferred_item"] self._preferred_item = data["preferred_item"]
@ -671,8 +749,6 @@ class PipelineStorageCollection(
def _create_item(self, item_id: str, data: dict) -> Pipeline: def _create_item(self, item_id: str, data: dict) -> Pipeline:
"""Create an item from validated config.""" """Create an item from validated config."""
if self._preferred_item is None:
self._preferred_item = item_id
return Pipeline(id=item_id, **data) return Pipeline(id=item_id, **data)
def _deserialize_item(self, data: dict) -> Pipeline: def _deserialize_item(self, data: dict) -> Pipeline:
@ -690,7 +766,7 @@ class PipelineStorageCollection(
await super().async_delete_item(item_id) await super().async_delete_item(item_id)
@callback @callback
def async_get_preferred_item(self) -> str | None: def async_get_preferred_item(self) -> str:
"""Get the id of the preferred item.""" """Get the id of the preferred item."""
return self._preferred_item return self._preferred_item

View File

@ -85,7 +85,7 @@ async def websocket_run(
) -> None: ) -> None:
"""Run a pipeline.""" """Run a pipeline."""
pipeline_id = msg.get("pipeline") pipeline_id = msg.get("pipeline")
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id) pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None: if pipeline is None:
connection.send_error( connection.send_error(
msg["id"], msg["id"],

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
import re import re
@ -117,14 +118,25 @@ def async_unset_agent(
async def async_get_conversation_languages( async def async_get_conversation_languages(
hass: HomeAssistant, hass: HomeAssistant, agent_id: str | None = None
) -> set[str] | Literal["*"]: ) -> set[str] | Literal["*"]:
"""Return a set with the union of languages supported by conversation agents.""" """Return languages supported by conversation agents.
If an agent is specified, returns a set of languages supported by that agent.
If no agent is specified, return a set with the union of languages supported by
all conversation agents.
"""
agent_manager = _get_agent_manager(hass) agent_manager = _get_agent_manager(hass)
languages = set() languages = set()
for agent_info in agent_manager.async_get_agent_info(): agent_ids: Iterable[str]
agent = await agent_manager.async_get_agent(agent_info.id) if agent_id is None:
agent_ids = iter(info.id for info in agent_manager.async_get_agent_info())
else:
agent_ids = (agent_id,)
for _agent_id in agent_ids:
agent = await agent_manager.async_get_agent(_agent_id)
if agent.supported_languages == MATCH_ALL: if agent.supported_languages == MATCH_ALL:
return MATCH_ALL return MATCH_ALL
for language_tag in agent.supported_languages: for language_tag in agent.supported_languages:

View File

@ -1 +1,56 @@
"""Tests for the Voice Assistant integration.""" """Tests for the Voice Assistant integration."""
MANY_LANGUAGES = [
"ar",
"bg",
"bn",
"ca",
"cs",
"da",
"de",
"de-CH",
"el",
"en",
"es",
"fa",
"fi",
"fr",
"fr-CA",
"gl",
"gu",
"he",
"hi",
"hr",
"hu",
"id",
"is",
"it",
"ka",
"kn",
"lb",
"lt",
"lv",
"ml",
"mn",
"ms",
"nb",
"nl",
"pl",
"pt",
"pt-br",
"ro",
"ru",
"sk",
"sl",
"sr",
"sv",
"sw",
"te",
"tr",
"uk",
"ur",
"vi",
"zh-cn",
"zh-hk",
"zh-tw",
]

View File

@ -11,9 +11,8 @@ from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.config_entries import ConfigEntry, ConfigFlow from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import ( from tests.common import (
@ -36,6 +35,8 @@ _TRANSCRIPT = "test transcript"
class BaseProvider: class BaseProvider:
"""Mock STT provider.""" """Mock STT provider."""
_supported_languages = ["en-US"]
def __init__(self, text: str) -> None: def __init__(self, text: str) -> None:
"""Init test provider.""" """Init test provider."""
self.text = text self.text = text
@ -44,7 +45,7 @@ class BaseProvider:
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
"""Return a list of supported languages.""" """Return a list of supported languages."""
return ["en-US"] return self._supported_languages
@property @property
def supported_formats(self) -> list[stt.AudioFormats]: def supported_formats(self) -> list[stt.AudioFormats]:
@ -96,6 +97,13 @@ class MockTTSProvider(tts.Provider):
"""Mock TTS provider.""" """Mock TTS provider."""
name = "Test" name = "Test"
_supported_languages = ["en-US"]
_supported_voices = {
"en-US": [
tts.Voice("james_earl_jones", "James Earl Jones"),
tts.Voice("fran_drescher", "Fran Drescher"),
]
}
@property @property
def default_language(self) -> str: def default_language(self) -> str:
@ -105,7 +113,12 @@ class MockTTSProvider(tts.Provider):
@property @property
def supported_languages(self) -> list[str]: def supported_languages(self) -> list[str]:
"""Return list of supported languages.""" """Return list of supported languages."""
return ["en-US"] return self._supported_languages
@callback
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
"""Return a list of supported voices for a language."""
return self._supported_voices.get(language)
@property @property
def supported_options(self) -> list[str]: def supported_options(self) -> list[str]:
@ -119,19 +132,21 @@ class MockTTSProvider(tts.Provider):
return ("mp3", b"") return ("mp3", b"")
class MockTTS(MockPlatform): class MockTTSPlatform(MockPlatform):
"""A mock TTS platform.""" """A mock TTS platform."""
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
async def async_get_engine( def __init__(self, *, async_get_engine, **kwargs):
self, """Initialize the tts platform."""
hass: HomeAssistant, super().__init__(**kwargs)
config: ConfigType, self.async_get_engine = async_get_engine
discovery_info: DiscoveryInfoType | None = None,
) -> tts.Provider:
"""Set up a mock speech component.""" @pytest.fixture
return MockTTSProvider() async def mock_tts_provider(hass) -> MockTTSProvider:
"""Mock TTS provider."""
return MockTTSProvider()
@pytest.fixture @pytest.fixture
@ -169,10 +184,11 @@ def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
@pytest.fixture @pytest.fixture
async def init_components( async def init_supporting_components(
hass: HomeAssistant, hass: HomeAssistant,
mock_stt_provider: MockSttProvider, mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity, mock_stt_provider_entity: MockSttProviderEntity,
mock_tts_provider: MockTTSProvider,
config_flow_fixture, config_flow_fixture,
init_cache_dir_side_effect, # noqa: F811 init_cache_dir_side_effect, # noqa: F811
mock_get_cache_files, # noqa: F811 mock_get_cache_files, # noqa: F811
@ -210,7 +226,13 @@ async def init_components(
async_unload_entry=async_unload_entry_init, async_unload_entry=async_unload_entry_init,
), ),
) )
mock_platform(hass, "test.tts", MockTTS()) mock_platform(
hass,
"test.tts",
MockTTSPlatform(
async_get_engine=AsyncMock(return_value=mock_tts_provider),
),
)
mock_platform( mock_platform(
hass, hass,
"test.stt", "test.stt",
@ -224,7 +246,6 @@ async def init_components(
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}}) assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}}) assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
assert await async_setup_component(hass, "media_source", {}) assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(hass, "assist_pipeline", {})
config_entry = MockConfigEntry(domain="test") config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
@ -232,6 +253,13 @@ async def init_components(
await hass.async_block_till_done() await hass.async_block_till_done()
@pytest.fixture
async def init_components(hass: HomeAssistant, init_supporting_components):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "assist_pipeline", {})
@pytest.fixture @pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection: def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
"""Return pipeline storage collection.""" """Return pipeline storage collection."""

View File

@ -4,7 +4,7 @@
dict({ dict({
'data': dict({ 'data': dict({
'language': 'en', 'language': 'en',
'pipeline': 'en', 'pipeline': 'Home Assistant',
}), }),
'type': <PipelineEventType.RUN_START: 'run-start'>, 'type': <PipelineEventType.RUN_START: 'run-start'>,
}), }),
@ -34,7 +34,7 @@
'data': dict({ 'data': dict({
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': None, 'language': 'en',
}), }),
'type': <PipelineEventType.INTENT_START: 'intent-start'>, 'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}), }),
@ -64,18 +64,18 @@
dict({ dict({
'data': dict({ 'data': dict({
'engine': 'test', 'engine': 'test',
'language': 'en', 'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that", 'tts_input': "Sorry, I couldn't understand that",
'voice': None, 'voice': 'james_earl_jones',
}), }),
'type': <PipelineEventType.TTS_START: 'tts-start'>, 'type': <PipelineEventType.TTS_START: 'tts-start'>,
}), }),
dict({ dict({
'data': dict({ 'data': dict({
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}), }),
}), }),
'type': <PipelineEventType.TTS_END: 'tts-end'>, 'type': <PipelineEventType.TTS_END: 'tts-end'>,

View File

@ -2,7 +2,7 @@
# name: test_audio_pipeline # name: test_audio_pipeline
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'en', 'pipeline': 'Home Assistant',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -33,7 +33,7 @@
dict({ dict({
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': None, 'language': 'en',
}) })
# --- # ---
# name: test_audio_pipeline.4 # name: test_audio_pipeline.4
@ -61,24 +61,24 @@
# name: test_audio_pipeline.5 # name: test_audio_pipeline.5
dict({ dict({
'engine': 'test', 'engine': 'test',
'language': 'en', 'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that", 'tts_input': "Sorry, I couldn't understand that",
'voice': None, 'voice': 'james_earl_jones',
}) })
# --- # ---
# name: test_audio_pipeline.6 # name: test_audio_pipeline.6
dict({ dict({
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}), }),
}) })
# --- # ---
# name: test_audio_pipeline_debug # name: test_audio_pipeline_debug
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'en', 'pipeline': 'Home Assistant',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -109,7 +109,7 @@
dict({ dict({
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'test transcript', 'intent_input': 'test transcript',
'language': None, 'language': 'en',
}) })
# --- # ---
# name: test_audio_pipeline_debug.4 # name: test_audio_pipeline_debug.4
@ -137,24 +137,24 @@
# name: test_audio_pipeline_debug.5 # name: test_audio_pipeline_debug.5
dict({ dict({
'engine': 'test', 'engine': 'test',
'language': 'en', 'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that", 'tts_input': "Sorry, I couldn't understand that",
'voice': None, 'voice': 'james_earl_jones',
}) })
# --- # ---
# name: test_audio_pipeline_debug.6 # name: test_audio_pipeline_debug.6
dict({ dict({
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3', 'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}), }),
}) })
# --- # ---
# name: test_intent_failed # name: test_intent_failed
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'en', 'pipeline': 'Home Assistant',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 30, 'timeout': 30,
@ -165,13 +165,13 @@
dict({ dict({
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
'language': None, 'language': 'en',
}) })
# --- # ---
# name: test_intent_timeout # name: test_intent_timeout
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'en', 'pipeline': 'Home Assistant',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 0.1, 'timeout': 0.1,
@ -182,7 +182,7 @@
dict({ dict({
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
'language': None, 'language': 'en',
}) })
# --- # ---
# name: test_intent_timeout.2 # name: test_intent_timeout.2
@ -217,7 +217,7 @@
# name: test_stt_stream_failed # name: test_stt_stream_failed
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'en', 'pipeline': 'Home Assistant',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': 1, 'stt_binary_handler_id': 1,
'timeout': 30, 'timeout': 30,
@ -240,7 +240,7 @@
# name: test_text_only_pipeline # name: test_text_only_pipeline
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'en', 'pipeline': 'Home Assistant',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 30, 'timeout': 30,
@ -251,7 +251,7 @@
dict({ dict({
'engine': 'homeassistant', 'engine': 'homeassistant',
'intent_input': 'Are the lights on?', 'intent_input': 'Are the lights on?',
'language': None, 'language': 'en',
}) })
# --- # ---
# name: test_text_only_pipeline.2 # name: test_text_only_pipeline.2
@ -285,7 +285,7 @@
# name: test_tts_failed # name: test_tts_failed
dict({ dict({
'language': 'en', 'language': 'en',
'pipeline': 'en', 'pipeline': 'Home Assistant',
'runner_data': dict({ 'runner_data': dict({
'stt_binary_handler_id': None, 'stt_binary_handler_id': None,
'timeout': 30, 'timeout': 30,
@ -295,8 +295,8 @@
# name: test_tts_failed.1 # name: test_tts_failed.1
dict({ dict({
'engine': 'test', 'engine': 'test',
'language': 'en', 'language': 'en-US',
'tts_input': 'Lights are on.', 'tts_input': 'Lights are on.',
'voice': None, 'voice': 'james_earl_jones',
}) })
# --- # ---

View File

@ -1,10 +1,14 @@
"""Websocket tests for Voice Assistant integration.""" """Websocket tests for Voice Assistant integration."""
from typing import Any from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.components.assist_pipeline.const import DOMAIN from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import ( from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_KEY, STORAGE_KEY,
STORAGE_VERSION, STORAGE_VERSION,
Pipeline,
PipelineData, PipelineData,
PipelineStorageCollection, PipelineStorageCollection,
async_get_pipeline, async_get_pipeline,
@ -13,7 +17,10 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import flush_store from . import MANY_LANGUAGES
from .conftest import MockSttPlatform, MockSttProvider, MockTTSPlatform, MockTTSProvider
from tests.common import MockModule, flush_store, mock_integration, mock_platform
async def test_load_datasets(hass: HomeAssistant, init_components) -> None: async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
@ -60,17 +67,17 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
store1 = pipeline_data.pipeline_store store1 = pipeline_data.pipeline_store
for pipeline in pipelines: for pipeline in pipelines:
pipeline_ids.append((await store1.async_create_item(pipeline)).id) pipeline_ids.append((await store1.async_create_item(pipeline)).id)
assert len(store1.data) == 3 assert len(store1.data) == 4 # 3 manually created plus a default pipeline
assert store1.async_get_preferred_item() == list(store1.data)[0] assert store1.async_get_preferred_item() == list(store1.data)[0]
await store1.async_delete_item(pipeline_ids[1]) await store1.async_delete_item(pipeline_ids[1])
assert len(store1.data) == 2 assert len(store1.data) == 3
store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY)) store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY))
await flush_store(store1.store) await flush_store(store1.store)
await store2.async_load() await store2.async_load()
assert len(store2.data) == 2 assert len(store2.data) == 3
assert store1.data is not store2.data assert store1.data is not store2.data
assert store1.data == store2.data assert store1.data == store2.data
@ -142,16 +149,221 @@ async def test_get_pipeline(hass: HomeAssistant) -> None:
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store store = pipeline_data.pipeline_store
assert len(store.data) == 0
# Test a pipeline is created
pipeline = await async_get_pipeline(hass, None)
assert len(store.data) == 1 assert len(store.data) == 1
# Test we get the same pipeline again # Test we get the preferred pipeline if none is specified
assert pipeline is await async_get_pipeline(hass, None) pipeline = async_get_pipeline(hass, None)
assert len(store.data) == 1 assert pipeline.id == store.async_get_preferred_item()
# Test getting a specific pipeline # Test getting a specific pipeline
assert pipeline is await async_get_pipeline(hass, pipeline.id) assert pipeline is async_get_pipeline(hass, pipeline.id)
@pytest.mark.parametrize(
("ha_language", "ha_country", "conv_language", "pipeline_language"),
[
("en", None, "en", "en"),
("de", "de", "de", "de"),
("de", "ch", "de-CH", "de"),
("en", "us", "en", "en"),
("en", "uk", "en", "en"),
("pt", "pt", "pt", "pt"),
("pt", "br", "pt-br", "pt"),
],
)
async def test_default_pipeline_no_stt_tts(
hass: HomeAssistant,
ha_language: str,
ha_country: str | None,
conv_language: str,
pipeline_language: str,
) -> None:
"""Test async_get_pipeline."""
hass.config.country = ha_country
hass.config.language = ha_language
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1 assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language=conv_language,
id=pipeline.id,
language=pipeline_language,
name="Home Assistant",
stt_engine=None,
stt_language=None,
tts_engine=None,
tts_language=None,
tts_voice=None,
)
@pytest.mark.parametrize(
(
"ha_language",
"ha_country",
"conv_language",
"pipeline_language",
"stt_language",
"tts_language",
),
[
("en", None, "en", "en", "en", "en"),
("de", "de", "de", "de", "de", "de"),
("de", "ch", "de-CH", "de", "de-CH", "de-CH"),
("en", "us", "en", "en", "en", "en"),
("en", "uk", "en", "en", "en", "en"),
("pt", "pt", "pt", "pt", "pt", "pt"),
("pt", "br", "pt-br", "pt", "pt-br", "pt-br"),
],
)
async def test_default_pipeline(
hass: HomeAssistant,
init_supporting_components,
mock_stt_provider: MockSttProvider,
mock_tts_provider: MockTTSProvider,
ha_language: str,
ha_country: str | None,
conv_language: str,
pipeline_language: str,
stt_language: str,
tts_language: str,
) -> None:
"""Test async_get_pipeline."""
hass.config.country = ha_country
hass.config.language = ha_language
with patch.object(
mock_stt_provider, "_supported_languages", MANY_LANGUAGES
), patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES):
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language=conv_language,
id=pipeline.id,
language=pipeline_language,
name="Home Assistant",
stt_engine="test",
stt_language=stt_language,
tts_engine="test",
tts_language=tts_language,
tts_voice=None,
)
async def test_default_pipeline_unsupported_stt_language(
hass: HomeAssistant,
init_supporting_components,
mock_stt_provider: MockSttProvider,
) -> None:
"""Test async_get_pipeline."""
with patch.object(mock_stt_provider, "_supported_languages", ["smurfish"]):
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=pipeline.id,
language="en",
name="Home Assistant",
stt_engine=None,
stt_language=None,
tts_engine="test",
tts_language="en-US",
tts_voice="james_earl_jones",
)
async def test_default_pipeline_unsupported_tts_language(
hass: HomeAssistant,
init_supporting_components,
mock_tts_provider: MockTTSProvider,
) -> None:
"""Test async_get_pipeline."""
with patch.object(mock_tts_provider, "_supported_languages", ["smurfish"]):
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=pipeline.id,
language="en",
name="Home Assistant",
stt_engine="test",
stt_language="en-US",
tts_engine=None,
tts_language=None,
tts_voice=None,
)
async def test_default_pipeline_cloud(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_tts_provider: MockTTSProvider,
) -> None:
"""Test async_get_pipeline."""
mock_integration(hass, MockModule("cloud"))
mock_platform(
hass,
"cloud.tts",
MockTTSPlatform(
async_get_engine=AsyncMock(return_value=mock_tts_provider),
),
)
mock_platform(
hass,
"cloud.stt",
MockSttPlatform(
async_get_engine=AsyncMock(return_value=mock_stt_provider),
),
)
mock_platform(hass, "test.config_flow")
assert await async_setup_component(hass, "tts", {"tts": {"platform": "cloud"}})
assert await async_setup_component(hass, "stt", {"stt": {"platform": "cloud"}})
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=pipeline.id,
language="en",
name="Home Assistant Cloud",
stt_engine="cloud",
stt_language="en-US",
tts_engine="cloud",
tts_language="en-US",
tts_voice="james_earl_jones",
)

View File

@ -92,6 +92,7 @@ async def test_select_entity_changing_pipelines(
assert state.state == "preferred" assert state.state == "preferred"
assert state.attributes["options"] == [ assert state.attributes["options"] == [
"preferred", "preferred",
"Home Assistant",
pipeline_1.name, pipeline_1.name,
pipeline_2.name, pipeline_2.name,
] ]
@ -122,4 +123,8 @@ async def test_select_entity_changing_pipelines(
state = hass.states.get("select.assist_pipeline_test_pipeline") state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state.state == "preferred" assert state.state == "preferred"
assert state.attributes["options"] == ["preferred", pipeline_1.name] assert state.attributes["options"] == [
"preferred",
"Home Assistant",
pipeline_1.name,
]

View File

@ -610,7 +610,7 @@ async def test_add_pipeline(
"tts_voice": "Arnold Schwarzenegger", "tts_voice": "Arnold Schwarzenegger",
} }
assert len(pipeline_store.data) == 1 assert len(pipeline_store.data) == 2
pipeline = pipeline_store.data[msg["result"]["id"]] pipeline = pipeline_store.data[msg["result"]["id"]]
assert pipeline == Pipeline( assert pipeline == Pipeline(
conversation_engine="test_conversation_engine", conversation_engine="test_conversation_engine",
@ -643,6 +643,7 @@ async def test_add_pipeline_missing_language(
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store pipeline_store = pipeline_data.pipeline_store
assert len(pipeline_store.data) == 1
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
@ -660,7 +661,7 @@ async def test_add_pipeline_missing_language(
) )
msg = await client.receive_json() msg = await client.receive_json()
assert not msg["success"] assert not msg["success"]
assert len(pipeline_store.data) == 0 assert len(pipeline_store.data) == 1
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
@ -678,7 +679,7 @@ async def test_add_pipeline_missing_language(
) )
msg = await client.receive_json() msg = await client.receive_json()
assert not msg["success"] assert not msg["success"]
assert len(pipeline_store.data) == 0 assert len(pipeline_store.data) == 1
async def test_delete_pipeline( async def test_delete_pipeline(
@ -725,7 +726,16 @@ async def test_delete_pipeline(
assert msg["success"] assert msg["success"]
pipeline_id_2 = msg["result"]["id"] pipeline_id_2 = msg["result"]["id"]
assert len(pipeline_store.data) == 2 assert len(pipeline_store.data) == 3
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/set_preferred",
"pipeline_id": pipeline_id_1,
}
)
msg = await client.receive_json()
assert msg["success"]
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
@ -748,7 +758,7 @@ async def test_delete_pipeline(
) )
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert len(pipeline_store.data) == 1 assert len(pipeline_store.data) == 2
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
@ -778,10 +788,18 @@ async def test_get_pipeline(
} }
) )
msg = await client.receive_json() msg = await client.receive_json()
assert not msg["success"] assert msg["success"]
assert msg["error"] == { assert msg["result"] == {
"code": "not_found", "conversation_engine": "homeassistant",
"message": "Unable to find pipeline_id None", "conversation_language": "en",
"id": ANY,
"language": "en",
"name": "Home Assistant",
"stt_engine": "test",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "james_earl_jones",
} }
await client.send_json_auto_id( await client.send_json_auto_id(
@ -814,27 +832,7 @@ async def test_get_pipeline(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
pipeline_id = msg["result"]["id"] pipeline_id = msg["result"]["id"]
assert len(pipeline_store.data) == 1 assert len(pipeline_store.data) == 2
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/get",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "test_conversation_engine",
"conversation_language": "test_language",
"id": pipeline_id,
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"stt_language": "test_language",
"tts_engine": "test_tts_engine",
"tts_language": "test_language",
"tts_voice": "Arnold Schwarzenegger",
}
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
@ -863,31 +861,7 @@ async def test_list_pipelines(
) -> None: ) -> None:
"""Test we can list pipelines.""" """Test we can list pipelines."""
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
pipeline_data: PipelineData = hass.data[DOMAIN] hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"pipelines": [], "preferred_pipeline": None}
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "test_conversation_engine",
"conversation_language": "test_language",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"stt_language": "test_language",
"tts_engine": "test_tts_engine",
"tts_language": "test_language",
"tts_voice": "Arnold Schwarzenegger",
}
)
msg = await client.receive_json()
assert msg["success"]
assert len(pipeline_store.data) == 1
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json() msg = await client.receive_json()
@ -895,16 +869,16 @@ async def test_list_pipelines(
assert msg["result"] == { assert msg["result"] == {
"pipelines": [ "pipelines": [
{ {
"conversation_engine": "test_conversation_engine", "conversation_engine": "homeassistant",
"conversation_language": "test_language", "conversation_language": "en",
"id": ANY, "id": ANY,
"language": "test_language", "language": "en",
"name": "test_name", "name": "Home Assistant",
"stt_engine": "test_stt_engine", "stt_engine": "test",
"stt_language": "test_language", "stt_language": "en-US",
"tts_engine": "test_tts_engine", "tts_engine": "test",
"tts_language": "test_language", "tts_language": "en-US",
"tts_voice": "Arnold Schwarzenegger", "tts_voice": "james_earl_jones",
} }
], ],
"preferred_pipeline": ANY, "preferred_pipeline": ANY,
@ -958,7 +932,7 @@ async def test_update_pipeline(
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
pipeline_id = msg["result"]["id"] pipeline_id = msg["result"]["id"]
assert len(pipeline_store.data) == 1 assert len(pipeline_store.data) == 2
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
@ -990,7 +964,7 @@ async def test_update_pipeline(
"tts_voice": "new_tts_voice", "tts_voice": "new_tts_voice",
} }
assert len(pipeline_store.data) == 1 assert len(pipeline_store.data) == 2
pipeline = pipeline_store.data[pipeline_id] pipeline = pipeline_store.data[pipeline_id]
assert pipeline == Pipeline( assert pipeline == Pipeline(
conversation_engine="new_conversation_engine", conversation_engine="new_conversation_engine",
@ -1076,36 +1050,18 @@ async def test_set_preferred_pipeline(
assert msg["success"] assert msg["success"]
pipeline_id_1 = msg["result"]["id"] pipeline_id_1 = msg["result"]["id"]
await client.send_json_auto_id( assert pipeline_store.async_get_preferred_item() != pipeline_id_1
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "test_conversation_engine",
"conversation_language": "test_language",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"stt_language": "test_language",
"tts_engine": "test_tts_engine",
"tts_language": "test_language",
"tts_voice": "Arnold Schwarzenegger",
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id_2 = msg["result"]["id"]
assert pipeline_store.async_get_preferred_item() == pipeline_id_1
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
"type": "assist_pipeline/pipeline/set_preferred", "type": "assist_pipeline/pipeline/set_preferred",
"pipeline_id": pipeline_id_2, "pipeline_id": pipeline_id_1,
} }
) )
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert pipeline_store.async_get_preferred_item() == pipeline_id_2 assert pipeline_store.async_get_preferred_item() == pipeline_id_1
async def test_set_preferred_pipeline_wrong_id( async def test_set_preferred_pipeline_wrong_id(