Create a default assist pipeline on start (#91947)

* Create a default assist pipeline on start

* Minor adjustments

* Address review comments

* Remove tts.async_get_agent

* Fix bugs, improve test coverage
This commit is contained in:
Erik Montnemery 2023-04-24 20:00:52 +02:00 committed by GitHub
parent 4b619f7251
commit b601fb17d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 523 additions and 179 deletions

View File

@ -51,7 +51,7 @@ async def async_pipeline_from_audio_stream(
tts_audio_output: str | None = None,
) -> None:
"""Create an audio pipeline from an audio stream."""
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
raise PipelineNotFound(
"pipeline_not_found", f"Pipeline {pipeline_id} not found"

View File

@ -24,7 +24,11 @@ from homeassistant.helpers.collection import (
StorageCollectionWebsocket,
)
from homeassistant.helpers.storage import Store
from homeassistant.util import dt as dt_util, ulid as ulid_util
from homeassistant.util import (
dt as dt_util,
language as language_util,
ulid as ulid_util,
)
from homeassistant.util.limited_size_dict import LimitedSizeDict
from .const import DOMAIN
@ -71,37 +75,109 @@ STORED_PIPELINE_RUNS = 10
SAVE_DELAY = 10
async def async_get_pipeline(
async def _async_create_default_pipeline(
hass: HomeAssistant, pipeline_store: PipelineStorageCollection
) -> Pipeline:
"""Create a default pipeline.
The default pipeline will use the homeassistant conversation agent and the
default stt / tts engines.
"""
conversation_language = "en"
pipeline_language = "en"
pipeline_name = "Home Assistant"
stt_engine_id = None
stt_language = None
tts_engine_id = None
tts_language = None
tts_voice = None
# Find a matching language supported by the Home Assistant conversation agent
conversation_languages = language_util.matches(
hass.config.language,
await conversation.async_get_conversation_languages(
hass, conversation.HOME_ASSISTANT_AGENT
),
country=hass.config.country,
)
if conversation_languages:
pipeline_language = hass.config.language
conversation_language = conversation_languages[0]
if (stt_engine_id := stt.async_default_engine(hass)) is not None and (
stt_engine := stt.async_get_speech_to_text_engine(
hass,
stt_engine_id,
)
):
stt_languages = language_util.matches(
pipeline_language,
stt_engine.supported_languages,
country=hass.config.country,
)
if stt_languages:
stt_language = stt_languages[0]
else:
_LOGGER.debug(
"Speech to text engine '%s' does not support language '%s'",
stt_engine_id,
pipeline_language,
)
stt_engine_id = None
if (tts_engine_id := tts.async_default_engine(hass)) is not None and (
tts_engine := tts.get_engine_instance(
hass,
tts_engine_id,
)
):
tts_languages = language_util.matches(
pipeline_language,
tts_engine.supported_languages,
country=hass.config.country,
)
if tts_languages:
tts_language = tts_languages[0]
tts_voices = tts_engine.async_get_supported_voices(tts_language)
if tts_voices:
tts_voice = tts_voices[0].voice_id
else:
_LOGGER.debug(
"Text to speech engine '%s' does not support language '%s'",
tts_engine_id,
pipeline_language,
)
tts_engine_id = None
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
pipeline_name = "Home Assistant Cloud"
return await pipeline_store.async_create_item(
{
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
"conversation_language": conversation_language,
"language": hass.config.language,
"name": pipeline_name,
"stt_engine": stt_engine_id,
"stt_language": stt_language,
"tts_engine": tts_engine_id,
"tts_language": tts_language,
"tts_voice": tts_voice,
}
)
@callback
def async_get_pipeline(
hass: HomeAssistant, pipeline_id: str | None = None
) -> Pipeline | None:
"""Get a pipeline by id or create one for a language."""
"""Get a pipeline by id or the preferred pipeline."""
pipeline_data: PipelineData = hass.data[DOMAIN]
if pipeline_id is None:
# A pipeline was not specified, use the preferred one
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
if pipeline_id is None:
# There's no preferred pipeline, construct a pipeline for the
# configured language
stt_engine = stt.async_default_provider(hass)
stt_language = hass.config.language if stt_engine else None
tts_engine = tts.async_default_engine(hass)
tts_language = hass.config.language if tts_engine else None
return await pipeline_data.pipeline_store.async_create_item(
{
"conversation_engine": None,
"conversation_language": None,
"language": hass.config.language,
"name": hass.config.language,
"stt_engine": stt_engine,
"stt_language": stt_language,
"tts_engine": tts_engine,
"tts_language": tts_language,
"tts_voice": None,
}
)
return pipeline_data.pipeline_store.data.get(pipeline_id)
@ -635,7 +711,7 @@ class PipelinePreferred(CollectionError):
class SerializedPipelineStorageCollection(SerializedStorageCollection):
"""Serialized pipeline storage collection."""
preferred_item: str | None
preferred_item: str
class PipelineStorageCollection(
@ -643,11 +719,13 @@ class PipelineStorageCollection(
):
"""Pipeline storage collection."""
_preferred_item: str | None = None
_preferred_item: str
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
"""Load the data."""
if not (data := await super()._async_load_data()):
pipeline = await _async_create_default_pipeline(self.hass, self)
self._preferred_item = pipeline.id
return data
self._preferred_item = data["preferred_item"]
@ -671,8 +749,6 @@ class PipelineStorageCollection(
def _create_item(self, item_id: str, data: dict) -> Pipeline:
"""Create an item from validated config."""
if self._preferred_item is None:
self._preferred_item = item_id
return Pipeline(id=item_id, **data)
def _deserialize_item(self, data: dict) -> Pipeline:
@ -690,7 +766,7 @@ class PipelineStorageCollection(
await super().async_delete_item(item_id)
@callback
def async_get_preferred_item(self) -> str | None:
def async_get_preferred_item(self) -> str:
"""Get the id of the preferred item."""
return self._preferred_item

View File

@ -85,7 +85,7 @@ async def websocket_run(
) -> None:
"""Run a pipeline."""
pipeline_id = msg.get("pipeline")
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
if pipeline is None:
connection.send_error(
msg["id"],

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
import logging
import re
@ -117,14 +118,25 @@ def async_unset_agent(
async def async_get_conversation_languages(
hass: HomeAssistant,
hass: HomeAssistant, agent_id: str | None = None
) -> set[str] | Literal["*"]:
"""Return a set with the union of languages supported by conversation agents."""
"""Return languages supported by conversation agents.
If an agent is specified, returns a set of languages supported by that agent.
If no agent is specified, return a set with the union of languages supported by
all conversation agents.
"""
agent_manager = _get_agent_manager(hass)
languages = set()
for agent_info in agent_manager.async_get_agent_info():
agent = await agent_manager.async_get_agent(agent_info.id)
agent_ids: Iterable[str]
if agent_id is None:
agent_ids = iter(info.id for info in agent_manager.async_get_agent_info())
else:
agent_ids = (agent_id,)
for _agent_id in agent_ids:
agent = await agent_manager.async_get_agent(_agent_id)
if agent.supported_languages == MATCH_ALL:
return MATCH_ALL
for language_tag in agent.supported_languages:

View File

@ -1 +1,56 @@
"""Tests for the Voice Assistant integration."""
MANY_LANGUAGES = [
"ar",
"bg",
"bn",
"ca",
"cs",
"da",
"de",
"de-CH",
"el",
"en",
"es",
"fa",
"fi",
"fr",
"fr-CA",
"gl",
"gu",
"he",
"hi",
"hr",
"hu",
"id",
"is",
"it",
"ka",
"kn",
"lb",
"lt",
"lv",
"ml",
"mn",
"ms",
"nb",
"nl",
"pl",
"pt",
"pt-br",
"ro",
"ru",
"sk",
"sl",
"sr",
"sv",
"sw",
"te",
"tr",
"uk",
"ur",
"vi",
"zh-cn",
"zh-hk",
"zh-tw",
]

View File

@ -11,9 +11,8 @@ from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
from homeassistant.config_entries import ConfigEntry, ConfigFlow
from homeassistant.core import HomeAssistant
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import (
@ -36,6 +35,8 @@ _TRANSCRIPT = "test transcript"
class BaseProvider:
"""Mock STT provider."""
_supported_languages = ["en-US"]
def __init__(self, text: str) -> None:
"""Init test provider."""
self.text = text
@ -44,7 +45,7 @@ class BaseProvider:
@property
def supported_languages(self) -> list[str]:
"""Return a list of supported languages."""
return ["en-US"]
return self._supported_languages
@property
def supported_formats(self) -> list[stt.AudioFormats]:
@ -96,6 +97,13 @@ class MockTTSProvider(tts.Provider):
"""Mock TTS provider."""
name = "Test"
_supported_languages = ["en-US"]
_supported_voices = {
"en-US": [
tts.Voice("james_earl_jones", "James Earl Jones"),
tts.Voice("fran_drescher", "Fran Drescher"),
]
}
@property
def default_language(self) -> str:
@ -105,7 +113,12 @@ class MockTTSProvider(tts.Provider):
@property
def supported_languages(self) -> list[str]:
"""Return list of supported languages."""
return ["en-US"]
return self._supported_languages
@callback
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
"""Return a list of supported voices for a language."""
return self._supported_voices.get(language)
@property
def supported_options(self) -> list[str]:
@ -119,19 +132,21 @@ class MockTTSProvider(tts.Provider):
return ("mp3", b"")
class MockTTS(MockPlatform):
class MockTTSPlatform(MockPlatform):
"""A mock TTS platform."""
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
async def async_get_engine(
self,
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> tts.Provider:
"""Set up a mock speech component."""
return MockTTSProvider()
def __init__(self, *, async_get_engine, **kwargs):
"""Initialize the tts platform."""
super().__init__(**kwargs)
self.async_get_engine = async_get_engine
@pytest.fixture
async def mock_tts_provider(hass) -> MockTTSProvider:
"""Mock TTS provider."""
return MockTTSProvider()
@pytest.fixture
@ -169,10 +184,11 @@ def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
@pytest.fixture
async def init_components(
async def init_supporting_components(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_stt_provider_entity: MockSttProviderEntity,
mock_tts_provider: MockTTSProvider,
config_flow_fixture,
init_cache_dir_side_effect, # noqa: F811
mock_get_cache_files, # noqa: F811
@ -210,7 +226,13 @@ async def init_components(
async_unload_entry=async_unload_entry_init,
),
)
mock_platform(hass, "test.tts", MockTTS())
mock_platform(
hass,
"test.tts",
MockTTSPlatform(
async_get_engine=AsyncMock(return_value=mock_tts_provider),
),
)
mock_platform(
hass,
"test.stt",
@ -224,7 +246,6 @@ async def init_components(
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
assert await async_setup_component(hass, "media_source", {})
assert await async_setup_component(hass, "assist_pipeline", {})
config_entry = MockConfigEntry(domain="test")
config_entry.add_to_hass(hass)
@ -232,6 +253,13 @@ async def init_components(
await hass.async_block_till_done()
@pytest.fixture
async def init_components(hass: HomeAssistant, init_supporting_components):
"""Initialize relevant components with empty configs."""
assert await async_setup_component(hass, "assist_pipeline", {})
@pytest.fixture
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
"""Return pipeline storage collection."""

View File

@ -4,7 +4,7 @@
dict({
'data': dict({
'language': 'en',
'pipeline': 'en',
'pipeline': 'Home Assistant',
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
@ -34,7 +34,7 @@
'data': dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': None,
'language': 'en',
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
@ -64,18 +64,18 @@
dict({
'data': dict({
'engine': 'test',
'language': 'en',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': None,
'voice': 'james_earl_jones',
}),
'type': <PipelineEventType.TTS_START: 'tts-start'>,
}),
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
}),
'type': <PipelineEventType.TTS_END: 'tts-end'>,

View File

@ -2,7 +2,7 @@
# name: test_audio_pipeline
dict({
'language': 'en',
'pipeline': 'en',
'pipeline': 'Home Assistant',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
@ -33,7 +33,7 @@
dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': None,
'language': 'en',
})
# ---
# name: test_audio_pipeline.4
@ -61,24 +61,24 @@
# name: test_audio_pipeline.5
dict({
'engine': 'test',
'language': 'en',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': None,
'voice': 'james_earl_jones',
})
# ---
# name: test_audio_pipeline.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
})
# ---
# name: test_audio_pipeline_debug
dict({
'language': 'en',
'pipeline': 'en',
'pipeline': 'Home Assistant',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
@ -109,7 +109,7 @@
dict({
'engine': 'homeassistant',
'intent_input': 'test transcript',
'language': None,
'language': 'en',
})
# ---
# name: test_audio_pipeline_debug.4
@ -137,24 +137,24 @@
# name: test_audio_pipeline_debug.5
dict({
'engine': 'test',
'language': 'en',
'language': 'en-US',
'tts_input': "Sorry, I couldn't understand that",
'voice': None,
'voice': 'james_earl_jones',
})
# ---
# name: test_audio_pipeline_debug.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
}),
})
# ---
# name: test_intent_failed
dict({
'language': 'en',
'pipeline': 'en',
'pipeline': 'Home Assistant',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 30,
@ -165,13 +165,13 @@
dict({
'engine': 'homeassistant',
'intent_input': 'Are the lights on?',
'language': None,
'language': 'en',
})
# ---
# name: test_intent_timeout
dict({
'language': 'en',
'pipeline': 'en',
'pipeline': 'Home Assistant',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 0.1,
@ -182,7 +182,7 @@
dict({
'engine': 'homeassistant',
'intent_input': 'Are the lights on?',
'language': None,
'language': 'en',
})
# ---
# name: test_intent_timeout.2
@ -217,7 +217,7 @@
# name: test_stt_stream_failed
dict({
'language': 'en',
'pipeline': 'en',
'pipeline': 'Home Assistant',
'runner_data': dict({
'stt_binary_handler_id': 1,
'timeout': 30,
@ -240,7 +240,7 @@
# name: test_text_only_pipeline
dict({
'language': 'en',
'pipeline': 'en',
'pipeline': 'Home Assistant',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 30,
@ -251,7 +251,7 @@
dict({
'engine': 'homeassistant',
'intent_input': 'Are the lights on?',
'language': None,
'language': 'en',
})
# ---
# name: test_text_only_pipeline.2
@ -285,7 +285,7 @@
# name: test_tts_failed
dict({
'language': 'en',
'pipeline': 'en',
'pipeline': 'Home Assistant',
'runner_data': dict({
'stt_binary_handler_id': None,
'timeout': 30,
@ -295,8 +295,8 @@
# name: test_tts_failed.1
dict({
'engine': 'test',
'language': 'en',
'language': 'en-US',
'tts_input': 'Lights are on.',
'voice': None,
'voice': 'james_earl_jones',
})
# ---

View File

@ -1,10 +1,14 @@
"""Websocket tests for Voice Assistant integration."""
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_KEY,
STORAGE_VERSION,
Pipeline,
PipelineData,
PipelineStorageCollection,
async_get_pipeline,
@ -13,7 +17,10 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.storage import Store
from homeassistant.setup import async_setup_component
from tests.common import flush_store
from . import MANY_LANGUAGES
from .conftest import MockSttPlatform, MockSttProvider, MockTTSPlatform, MockTTSProvider
from tests.common import MockModule, flush_store, mock_integration, mock_platform
async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
@ -60,17 +67,17 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
store1 = pipeline_data.pipeline_store
for pipeline in pipelines:
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
assert len(store1.data) == 3
assert len(store1.data) == 4 # 3 manually created plus a default pipeline
assert store1.async_get_preferred_item() == list(store1.data)[0]
await store1.async_delete_item(pipeline_ids[1])
assert len(store1.data) == 2
assert len(store1.data) == 3
store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY))
await flush_store(store1.store)
await store2.async_load()
assert len(store2.data) == 2
assert len(store2.data) == 3
assert store1.data is not store2.data
assert store1.data == store2.data
@ -142,16 +149,221 @@ async def test_get_pipeline(hass: HomeAssistant) -> None:
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 0
# Test a pipeline is created
pipeline = await async_get_pipeline(hass, None)
assert len(store.data) == 1
# Test we get the same pipeline again
assert pipeline is await async_get_pipeline(hass, None)
assert len(store.data) == 1
# Test we get the preferred pipeline if none is specified
pipeline = async_get_pipeline(hass, None)
assert pipeline.id == store.async_get_preferred_item()
# Test getting a specific pipeline
assert pipeline is await async_get_pipeline(hass, pipeline.id)
assert pipeline is async_get_pipeline(hass, pipeline.id)
@pytest.mark.parametrize(
("ha_language", "ha_country", "conv_language", "pipeline_language"),
[
("en", None, "en", "en"),
("de", "de", "de", "de"),
("de", "ch", "de-CH", "de"),
("en", "us", "en", "en"),
("en", "uk", "en", "en"),
("pt", "pt", "pt", "pt"),
("pt", "br", "pt-br", "pt"),
],
)
async def test_default_pipeline_no_stt_tts(
hass: HomeAssistant,
ha_language: str,
ha_country: str | None,
conv_language: str,
pipeline_language: str,
) -> None:
"""Test async_get_pipeline."""
hass.config.country = ha_country
hass.config.language = ha_language
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language=conv_language,
id=pipeline.id,
language=pipeline_language,
name="Home Assistant",
stt_engine=None,
stt_language=None,
tts_engine=None,
tts_language=None,
tts_voice=None,
)
@pytest.mark.parametrize(
(
"ha_language",
"ha_country",
"conv_language",
"pipeline_language",
"stt_language",
"tts_language",
),
[
("en", None, "en", "en", "en", "en"),
("de", "de", "de", "de", "de", "de"),
("de", "ch", "de-CH", "de", "de-CH", "de-CH"),
("en", "us", "en", "en", "en", "en"),
("en", "uk", "en", "en", "en", "en"),
("pt", "pt", "pt", "pt", "pt", "pt"),
("pt", "br", "pt-br", "pt", "pt-br", "pt-br"),
],
)
async def test_default_pipeline(
hass: HomeAssistant,
init_supporting_components,
mock_stt_provider: MockSttProvider,
mock_tts_provider: MockTTSProvider,
ha_language: str,
ha_country: str | None,
conv_language: str,
pipeline_language: str,
stt_language: str,
tts_language: str,
) -> None:
"""Test async_get_pipeline."""
hass.config.country = ha_country
hass.config.language = ha_language
with patch.object(
mock_stt_provider, "_supported_languages", MANY_LANGUAGES
), patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES):
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language=conv_language,
id=pipeline.id,
language=pipeline_language,
name="Home Assistant",
stt_engine="test",
stt_language=stt_language,
tts_engine="test",
tts_language=tts_language,
tts_voice=None,
)
async def test_default_pipeline_unsupported_stt_language(
hass: HomeAssistant,
init_supporting_components,
mock_stt_provider: MockSttProvider,
) -> None:
"""Test async_get_pipeline."""
with patch.object(mock_stt_provider, "_supported_languages", ["smurfish"]):
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=pipeline.id,
language="en",
name="Home Assistant",
stt_engine=None,
stt_language=None,
tts_engine="test",
tts_language="en-US",
tts_voice="james_earl_jones",
)
async def test_default_pipeline_unsupported_tts_language(
hass: HomeAssistant,
init_supporting_components,
mock_tts_provider: MockTTSProvider,
) -> None:
"""Test async_get_pipeline."""
with patch.object(mock_tts_provider, "_supported_languages", ["smurfish"]):
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=pipeline.id,
language="en",
name="Home Assistant",
stt_engine="test",
stt_language="en-US",
tts_engine=None,
tts_language=None,
tts_voice=None,
)
async def test_default_pipeline_cloud(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_tts_provider: MockTTSProvider,
) -> None:
"""Test async_get_pipeline."""
mock_integration(hass, MockModule("cloud"))
mock_platform(
hass,
"cloud.tts",
MockTTSPlatform(
async_get_engine=AsyncMock(return_value=mock_tts_provider),
),
)
mock_platform(
hass,
"cloud.stt",
MockSttPlatform(
async_get_engine=AsyncMock(return_value=mock_stt_provider),
),
)
mock_platform(hass, "test.config_flow")
assert await async_setup_component(hass, "tts", {"tts": {"platform": "cloud"}})
assert await async_setup_component(hass, "stt", {"stt": {"platform": "cloud"}})
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=pipeline.id,
language="en",
name="Home Assistant Cloud",
stt_engine="cloud",
stt_language="en-US",
tts_engine="cloud",
tts_language="en-US",
tts_voice="james_earl_jones",
)

View File

@ -92,6 +92,7 @@ async def test_select_entity_changing_pipelines(
assert state.state == "preferred"
assert state.attributes["options"] == [
"preferred",
"Home Assistant",
pipeline_1.name,
pipeline_2.name,
]
@ -122,4 +123,8 @@ async def test_select_entity_changing_pipelines(
state = hass.states.get("select.assist_pipeline_test_pipeline")
assert state.state == "preferred"
assert state.attributes["options"] == ["preferred", pipeline_1.name]
assert state.attributes["options"] == [
"preferred",
"Home Assistant",
pipeline_1.name,
]

View File

@ -610,7 +610,7 @@ async def test_add_pipeline(
"tts_voice": "Arnold Schwarzenegger",
}
assert len(pipeline_store.data) == 1
assert len(pipeline_store.data) == 2
pipeline = pipeline_store.data[msg["result"]["id"]]
assert pipeline == Pipeline(
conversation_engine="test_conversation_engine",
@ -643,6 +643,7 @@ async def test_add_pipeline_missing_language(
client = await hass_ws_client(hass)
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
assert len(pipeline_store.data) == 1
await client.send_json_auto_id(
{
@ -660,7 +661,7 @@ async def test_add_pipeline_missing_language(
)
msg = await client.receive_json()
assert not msg["success"]
assert len(pipeline_store.data) == 0
assert len(pipeline_store.data) == 1
await client.send_json_auto_id(
{
@ -678,7 +679,7 @@ async def test_add_pipeline_missing_language(
)
msg = await client.receive_json()
assert not msg["success"]
assert len(pipeline_store.data) == 0
assert len(pipeline_store.data) == 1
async def test_delete_pipeline(
@ -725,7 +726,16 @@ async def test_delete_pipeline(
assert msg["success"]
pipeline_id_2 = msg["result"]["id"]
assert len(pipeline_store.data) == 2
assert len(pipeline_store.data) == 3
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/set_preferred",
"pipeline_id": pipeline_id_1,
}
)
msg = await client.receive_json()
assert msg["success"]
await client.send_json_auto_id(
{
@ -748,7 +758,7 @@ async def test_delete_pipeline(
)
msg = await client.receive_json()
assert msg["success"]
assert len(pipeline_store.data) == 1
assert len(pipeline_store.data) == 2
await client.send_json_auto_id(
{
@ -778,10 +788,18 @@ async def test_get_pipeline(
}
)
msg = await client.receive_json()
assert not msg["success"]
assert msg["error"] == {
"code": "not_found",
"message": "Unable to find pipeline_id None",
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "homeassistant",
"conversation_language": "en",
"id": ANY,
"language": "en",
"name": "Home Assistant",
"stt_engine": "test",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "james_earl_jones",
}
await client.send_json_auto_id(
@ -814,27 +832,7 @@ async def test_get_pipeline(
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
assert len(pipeline_store.data) == 1
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/get",
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {
"conversation_engine": "test_conversation_engine",
"conversation_language": "test_language",
"id": pipeline_id,
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"stt_language": "test_language",
"tts_engine": "test_tts_engine",
"tts_language": "test_language",
"tts_voice": "Arnold Schwarzenegger",
}
assert len(pipeline_store.data) == 2
await client.send_json_auto_id(
{
@ -863,31 +861,7 @@ async def test_list_pipelines(
) -> None:
"""Test we can list pipelines."""
client = await hass_ws_client(hass)
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"pipelines": [], "preferred_pipeline": None}
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "test_conversation_engine",
"conversation_language": "test_language",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"stt_language": "test_language",
"tts_engine": "test_tts_engine",
"tts_language": "test_language",
"tts_voice": "Arnold Schwarzenegger",
}
)
msg = await client.receive_json()
assert msg["success"]
assert len(pipeline_store.data) == 1
hass.data[DOMAIN]
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json()
@ -895,16 +869,16 @@ async def test_list_pipelines(
assert msg["result"] == {
"pipelines": [
{
"conversation_engine": "test_conversation_engine",
"conversation_language": "test_language",
"conversation_engine": "homeassistant",
"conversation_language": "en",
"id": ANY,
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"stt_language": "test_language",
"tts_engine": "test_tts_engine",
"tts_language": "test_language",
"tts_voice": "Arnold Schwarzenegger",
"language": "en",
"name": "Home Assistant",
"stt_engine": "test",
"stt_language": "en-US",
"tts_engine": "test",
"tts_language": "en-US",
"tts_voice": "james_earl_jones",
}
],
"preferred_pipeline": ANY,
@ -958,7 +932,7 @@ async def test_update_pipeline(
msg = await client.receive_json()
assert msg["success"]
pipeline_id = msg["result"]["id"]
assert len(pipeline_store.data) == 1
assert len(pipeline_store.data) == 2
await client.send_json_auto_id(
{
@ -990,7 +964,7 @@ async def test_update_pipeline(
"tts_voice": "new_tts_voice",
}
assert len(pipeline_store.data) == 1
assert len(pipeline_store.data) == 2
pipeline = pipeline_store.data[pipeline_id]
assert pipeline == Pipeline(
conversation_engine="new_conversation_engine",
@ -1076,36 +1050,18 @@ async def test_set_preferred_pipeline(
assert msg["success"]
pipeline_id_1 = msg["result"]["id"]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/create",
"conversation_engine": "test_conversation_engine",
"conversation_language": "test_language",
"language": "test_language",
"name": "test_name",
"stt_engine": "test_stt_engine",
"stt_language": "test_language",
"tts_engine": "test_tts_engine",
"tts_language": "test_language",
"tts_voice": "Arnold Schwarzenegger",
}
)
msg = await client.receive_json()
assert msg["success"]
pipeline_id_2 = msg["result"]["id"]
assert pipeline_store.async_get_preferred_item() == pipeline_id_1
assert pipeline_store.async_get_preferred_item() != pipeline_id_1
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline/set_preferred",
"pipeline_id": pipeline_id_2,
"pipeline_id": pipeline_id_1,
}
)
msg = await client.receive_json()
assert msg["success"]
assert pipeline_store.async_get_preferred_item() == pipeline_id_2
assert pipeline_store.async_get_preferred_item() == pipeline_id_1
async def test_set_preferred_pipeline_wrong_id(