mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 09:17:53 +00:00
Create a default assist pipeline on start (#91947)
* Create a default assist pipeline on start * Minor adjustments * Address review comments * Remove tts.async_get_agent * Fix bugs, improve test coverage
This commit is contained in:
parent
4b619f7251
commit
b601fb17d3
@ -51,7 +51,7 @@ async def async_pipeline_from_audio_stream(
|
||||
tts_audio_output: str | None = None,
|
||||
) -> None:
|
||||
"""Create an audio pipeline from an audio stream."""
|
||||
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||
if pipeline is None:
|
||||
raise PipelineNotFound(
|
||||
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
||||
|
@ -24,7 +24,11 @@ from homeassistant.helpers.collection import (
|
||||
StorageCollectionWebsocket,
|
||||
)
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
||||
from homeassistant.util import (
|
||||
dt as dt_util,
|
||||
language as language_util,
|
||||
ulid as ulid_util,
|
||||
)
|
||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||
|
||||
from .const import DOMAIN
|
||||
@ -71,37 +75,109 @@ STORED_PIPELINE_RUNS = 10
|
||||
SAVE_DELAY = 10
|
||||
|
||||
|
||||
async def async_get_pipeline(
|
||||
async def _async_create_default_pipeline(
|
||||
hass: HomeAssistant, pipeline_store: PipelineStorageCollection
|
||||
) -> Pipeline:
|
||||
"""Create a default pipeline.
|
||||
|
||||
The default pipeline will use the homeassistant conversation agent and the
|
||||
default stt / tts engines.
|
||||
"""
|
||||
conversation_language = "en"
|
||||
pipeline_language = "en"
|
||||
pipeline_name = "Home Assistant"
|
||||
stt_engine_id = None
|
||||
stt_language = None
|
||||
tts_engine_id = None
|
||||
tts_language = None
|
||||
tts_voice = None
|
||||
|
||||
# Find a matching language supported by the Home Assistant conversation agent
|
||||
conversation_languages = language_util.matches(
|
||||
hass.config.language,
|
||||
await conversation.async_get_conversation_languages(
|
||||
hass, conversation.HOME_ASSISTANT_AGENT
|
||||
),
|
||||
country=hass.config.country,
|
||||
)
|
||||
if conversation_languages:
|
||||
pipeline_language = hass.config.language
|
||||
conversation_language = conversation_languages[0]
|
||||
|
||||
if (stt_engine_id := stt.async_default_engine(hass)) is not None and (
|
||||
stt_engine := stt.async_get_speech_to_text_engine(
|
||||
hass,
|
||||
stt_engine_id,
|
||||
)
|
||||
):
|
||||
stt_languages = language_util.matches(
|
||||
pipeline_language,
|
||||
stt_engine.supported_languages,
|
||||
country=hass.config.country,
|
||||
)
|
||||
if stt_languages:
|
||||
stt_language = stt_languages[0]
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Speech to text engine '%s' does not support language '%s'",
|
||||
stt_engine_id,
|
||||
pipeline_language,
|
||||
)
|
||||
stt_engine_id = None
|
||||
|
||||
if (tts_engine_id := tts.async_default_engine(hass)) is not None and (
|
||||
tts_engine := tts.get_engine_instance(
|
||||
hass,
|
||||
tts_engine_id,
|
||||
)
|
||||
):
|
||||
tts_languages = language_util.matches(
|
||||
pipeline_language,
|
||||
tts_engine.supported_languages,
|
||||
country=hass.config.country,
|
||||
)
|
||||
if tts_languages:
|
||||
tts_language = tts_languages[0]
|
||||
tts_voices = tts_engine.async_get_supported_voices(tts_language)
|
||||
if tts_voices:
|
||||
tts_voice = tts_voices[0].voice_id
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Text to speech engine '%s' does not support language '%s'",
|
||||
tts_engine_id,
|
||||
pipeline_language,
|
||||
)
|
||||
tts_engine_id = None
|
||||
|
||||
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
|
||||
pipeline_name = "Home Assistant Cloud"
|
||||
|
||||
return await pipeline_store.async_create_item(
|
||||
{
|
||||
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
|
||||
"conversation_language": conversation_language,
|
||||
"language": hass.config.language,
|
||||
"name": pipeline_name,
|
||||
"stt_engine": stt_engine_id,
|
||||
"stt_language": stt_language,
|
||||
"tts_engine": tts_engine_id,
|
||||
"tts_language": tts_language,
|
||||
"tts_voice": tts_voice,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_pipeline(
|
||||
hass: HomeAssistant, pipeline_id: str | None = None
|
||||
) -> Pipeline | None:
|
||||
"""Get a pipeline by id or create one for a language."""
|
||||
"""Get a pipeline by id or the preferred pipeline."""
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
|
||||
if pipeline_id is None:
|
||||
# A pipeline was not specified, use the preferred one
|
||||
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
|
||||
|
||||
if pipeline_id is None:
|
||||
# There's no preferred pipeline, construct a pipeline for the
|
||||
# configured language
|
||||
stt_engine = stt.async_default_provider(hass)
|
||||
stt_language = hass.config.language if stt_engine else None
|
||||
tts_engine = tts.async_default_engine(hass)
|
||||
tts_language = hass.config.language if tts_engine else None
|
||||
return await pipeline_data.pipeline_store.async_create_item(
|
||||
{
|
||||
"conversation_engine": None,
|
||||
"conversation_language": None,
|
||||
"language": hass.config.language,
|
||||
"name": hass.config.language,
|
||||
"stt_engine": stt_engine,
|
||||
"stt_language": stt_language,
|
||||
"tts_engine": tts_engine,
|
||||
"tts_language": tts_language,
|
||||
"tts_voice": None,
|
||||
}
|
||||
)
|
||||
|
||||
return pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||
|
||||
|
||||
@ -635,7 +711,7 @@ class PipelinePreferred(CollectionError):
|
||||
class SerializedPipelineStorageCollection(SerializedStorageCollection):
|
||||
"""Serialized pipeline storage collection."""
|
||||
|
||||
preferred_item: str | None
|
||||
preferred_item: str
|
||||
|
||||
|
||||
class PipelineStorageCollection(
|
||||
@ -643,11 +719,13 @@ class PipelineStorageCollection(
|
||||
):
|
||||
"""Pipeline storage collection."""
|
||||
|
||||
_preferred_item: str | None = None
|
||||
_preferred_item: str
|
||||
|
||||
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
|
||||
"""Load the data."""
|
||||
if not (data := await super()._async_load_data()):
|
||||
pipeline = await _async_create_default_pipeline(self.hass, self)
|
||||
self._preferred_item = pipeline.id
|
||||
return data
|
||||
|
||||
self._preferred_item = data["preferred_item"]
|
||||
@ -671,8 +749,6 @@ class PipelineStorageCollection(
|
||||
|
||||
def _create_item(self, item_id: str, data: dict) -> Pipeline:
|
||||
"""Create an item from validated config."""
|
||||
if self._preferred_item is None:
|
||||
self._preferred_item = item_id
|
||||
return Pipeline(id=item_id, **data)
|
||||
|
||||
def _deserialize_item(self, data: dict) -> Pipeline:
|
||||
@ -690,7 +766,7 @@ class PipelineStorageCollection(
|
||||
await super().async_delete_item(item_id)
|
||||
|
||||
@callback
|
||||
def async_get_preferred_item(self) -> str | None:
|
||||
def async_get_preferred_item(self) -> str:
|
||||
"""Get the id of the preferred item."""
|
||||
return self._preferred_item
|
||||
|
||||
|
@ -85,7 +85,7 @@ async def websocket_run(
|
||||
) -> None:
|
||||
"""Run a pipeline."""
|
||||
pipeline_id = msg.get("pipeline")
|
||||
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||
if pipeline is None:
|
||||
connection.send_error(
|
||||
msg["id"],
|
||||
|
@ -2,6 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import re
|
||||
@ -117,14 +118,25 @@ def async_unset_agent(
|
||||
|
||||
|
||||
async def async_get_conversation_languages(
|
||||
hass: HomeAssistant,
|
||||
hass: HomeAssistant, agent_id: str | None = None
|
||||
) -> set[str] | Literal["*"]:
|
||||
"""Return a set with the union of languages supported by conversation agents."""
|
||||
"""Return languages supported by conversation agents.
|
||||
|
||||
If an agent is specified, returns a set of languages supported by that agent.
|
||||
If no agent is specified, return a set with the union of languages supported by
|
||||
all conversation agents.
|
||||
"""
|
||||
agent_manager = _get_agent_manager(hass)
|
||||
languages = set()
|
||||
|
||||
for agent_info in agent_manager.async_get_agent_info():
|
||||
agent = await agent_manager.async_get_agent(agent_info.id)
|
||||
agent_ids: Iterable[str]
|
||||
if agent_id is None:
|
||||
agent_ids = iter(info.id for info in agent_manager.async_get_agent_info())
|
||||
else:
|
||||
agent_ids = (agent_id,)
|
||||
|
||||
for _agent_id in agent_ids:
|
||||
agent = await agent_manager.async_get_agent(_agent_id)
|
||||
if agent.supported_languages == MATCH_ALL:
|
||||
return MATCH_ALL
|
||||
for language_tag in agent.supported_languages:
|
||||
|
@ -1 +1,56 @@
|
||||
"""Tests for the Voice Assistant integration."""
|
||||
|
||||
MANY_LANGUAGES = [
|
||||
"ar",
|
||||
"bg",
|
||||
"bn",
|
||||
"ca",
|
||||
"cs",
|
||||
"da",
|
||||
"de",
|
||||
"de-CH",
|
||||
"el",
|
||||
"en",
|
||||
"es",
|
||||
"fa",
|
||||
"fi",
|
||||
"fr",
|
||||
"fr-CA",
|
||||
"gl",
|
||||
"gu",
|
||||
"he",
|
||||
"hi",
|
||||
"hr",
|
||||
"hu",
|
||||
"id",
|
||||
"is",
|
||||
"it",
|
||||
"ka",
|
||||
"kn",
|
||||
"lb",
|
||||
"lt",
|
||||
"lv",
|
||||
"ml",
|
||||
"mn",
|
||||
"ms",
|
||||
"nb",
|
||||
"nl",
|
||||
"pl",
|
||||
"pt",
|
||||
"pt-br",
|
||||
"ro",
|
||||
"ru",
|
||||
"sk",
|
||||
"sl",
|
||||
"sr",
|
||||
"sv",
|
||||
"sw",
|
||||
"te",
|
||||
"tr",
|
||||
"uk",
|
||||
"ur",
|
||||
"vi",
|
||||
"zh-cn",
|
||||
"zh-hk",
|
||||
"zh-tw",
|
||||
]
|
||||
|
@ -11,9 +11,8 @@ from homeassistant.components import stt, tts
|
||||
from homeassistant.components.assist_pipeline import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import (
|
||||
@ -36,6 +35,8 @@ _TRANSCRIPT = "test transcript"
|
||||
class BaseProvider:
|
||||
"""Mock STT provider."""
|
||||
|
||||
_supported_languages = ["en-US"]
|
||||
|
||||
def __init__(self, text: str) -> None:
|
||||
"""Init test provider."""
|
||||
self.text = text
|
||||
@ -44,7 +45,7 @@ class BaseProvider:
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return a list of supported languages."""
|
||||
return ["en-US"]
|
||||
return self._supported_languages
|
||||
|
||||
@property
|
||||
def supported_formats(self) -> list[stt.AudioFormats]:
|
||||
@ -96,6 +97,13 @@ class MockTTSProvider(tts.Provider):
|
||||
"""Mock TTS provider."""
|
||||
|
||||
name = "Test"
|
||||
_supported_languages = ["en-US"]
|
||||
_supported_voices = {
|
||||
"en-US": [
|
||||
tts.Voice("james_earl_jones", "James Earl Jones"),
|
||||
tts.Voice("fran_drescher", "Fran Drescher"),
|
||||
]
|
||||
}
|
||||
|
||||
@property
|
||||
def default_language(self) -> str:
|
||||
@ -105,7 +113,12 @@ class MockTTSProvider(tts.Provider):
|
||||
@property
|
||||
def supported_languages(self) -> list[str]:
|
||||
"""Return list of supported languages."""
|
||||
return ["en-US"]
|
||||
return self._supported_languages
|
||||
|
||||
@callback
|
||||
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
|
||||
"""Return a list of supported voices for a language."""
|
||||
return self._supported_voices.get(language)
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
@ -119,19 +132,21 @@ class MockTTSProvider(tts.Provider):
|
||||
return ("mp3", b"")
|
||||
|
||||
|
||||
class MockTTS(MockPlatform):
|
||||
class MockTTSPlatform(MockPlatform):
|
||||
"""A mock TTS platform."""
|
||||
|
||||
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
|
||||
|
||||
async def async_get_engine(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> tts.Provider:
|
||||
"""Set up a mock speech component."""
|
||||
return MockTTSProvider()
|
||||
def __init__(self, *, async_get_engine, **kwargs):
|
||||
"""Initialize the tts platform."""
|
||||
super().__init__(**kwargs)
|
||||
self.async_get_engine = async_get_engine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_tts_provider(hass) -> MockTTSProvider:
|
||||
"""Mock TTS provider."""
|
||||
return MockTTSProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -169,10 +184,11 @@ def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(
|
||||
async def init_supporting_components(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_stt_provider_entity: MockSttProviderEntity,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
config_flow_fixture,
|
||||
init_cache_dir_side_effect, # noqa: F811
|
||||
mock_get_cache_files, # noqa: F811
|
||||
@ -210,7 +226,13 @@ async def init_components(
|
||||
async_unload_entry=async_unload_entry_init,
|
||||
),
|
||||
)
|
||||
mock_platform(hass, "test.tts", MockTTS())
|
||||
mock_platform(
|
||||
hass,
|
||||
"test.tts",
|
||||
MockTTSPlatform(
|
||||
async_get_engine=AsyncMock(return_value=mock_tts_provider),
|
||||
),
|
||||
)
|
||||
mock_platform(
|
||||
hass,
|
||||
"test.stt",
|
||||
@ -224,7 +246,6 @@ async def init_components(
|
||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
||||
assert await async_setup_component(hass, "media_source", {})
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
config_entry = MockConfigEntry(domain="test")
|
||||
config_entry.add_to_hass(hass)
|
||||
@ -232,6 +253,13 @@ async def init_components(
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_components(hass: HomeAssistant, init_supporting_components):
|
||||
"""Initialize relevant components with empty configs."""
|
||||
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
|
||||
"""Return pipeline storage collection."""
|
||||
|
@ -4,7 +4,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'language': 'en',
|
||||
'pipeline': 'en',
|
||||
'pipeline': 'Home Assistant',
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
@ -34,7 +34,7 @@
|
||||
'data': dict({
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
@ -64,18 +64,18 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'engine': 'test',
|
||||
'language': 'en',
|
||||
'language': 'en-US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': None,
|
||||
'voice': 'james_earl_jones',
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||
|
@ -2,7 +2,7 @@
|
||||
# name: test_audio_pipeline
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': 'en',
|
||||
'pipeline': 'Home Assistant',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
@ -33,7 +33,7 @@
|
||||
dict({
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.4
|
||||
@ -61,24 +61,24 @@
|
||||
# name: test_audio_pipeline.5
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en',
|
||||
'language': 'en-US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': None,
|
||||
'voice': 'james_earl_jones',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': 'en',
|
||||
'pipeline': 'Home Assistant',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
@ -109,7 +109,7 @@
|
||||
dict({
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'test transcript',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.4
|
||||
@ -137,24 +137,24 @@
|
||||
# name: test_audio_pipeline_debug.5
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en',
|
||||
'language': 'en-US',
|
||||
'tts_input': "Sorry, I couldn't understand that",
|
||||
'voice': None,
|
||||
'voice': 'james_earl_jones',
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_failed
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': 'en',
|
||||
'pipeline': 'Home Assistant',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 30,
|
||||
@ -165,13 +165,13 @@
|
||||
dict({
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_timeout
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': 'en',
|
||||
'pipeline': 'Home Assistant',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 0.1,
|
||||
@ -182,7 +182,7 @@
|
||||
dict({
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_intent_timeout.2
|
||||
@ -217,7 +217,7 @@
|
||||
# name: test_stt_stream_failed
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': 'en',
|
||||
'pipeline': 'Home Assistant',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': 1,
|
||||
'timeout': 30,
|
||||
@ -240,7 +240,7 @@
|
||||
# name: test_text_only_pipeline
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': 'en',
|
||||
'pipeline': 'Home Assistant',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 30,
|
||||
@ -251,7 +251,7 @@
|
||||
dict({
|
||||
'engine': 'homeassistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
'language': None,
|
||||
'language': 'en',
|
||||
})
|
||||
# ---
|
||||
# name: test_text_only_pipeline.2
|
||||
@ -285,7 +285,7 @@
|
||||
# name: test_tts_failed
|
||||
dict({
|
||||
'language': 'en',
|
||||
'pipeline': 'en',
|
||||
'pipeline': 'Home Assistant',
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 30,
|
||||
@ -295,8 +295,8 @@
|
||||
# name: test_tts_failed.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en',
|
||||
'language': 'en-US',
|
||||
'tts_input': 'Lights are on.',
|
||||
'voice': None,
|
||||
'voice': 'james_earl_jones',
|
||||
})
|
||||
# ---
|
||||
|
@ -1,10 +1,14 @@
|
||||
"""Websocket tests for Voice Assistant integration."""
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
STORAGE_KEY,
|
||||
STORAGE_VERSION,
|
||||
Pipeline,
|
||||
PipelineData,
|
||||
PipelineStorageCollection,
|
||||
async_get_pipeline,
|
||||
@ -13,7 +17,10 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import flush_store
|
||||
from . import MANY_LANGUAGES
|
||||
from .conftest import MockSttPlatform, MockSttProvider, MockTTSPlatform, MockTTSProvider
|
||||
|
||||
from tests.common import MockModule, flush_store, mock_integration, mock_platform
|
||||
|
||||
|
||||
async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
||||
@ -60,17 +67,17 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
||||
store1 = pipeline_data.pipeline_store
|
||||
for pipeline in pipelines:
|
||||
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
|
||||
assert len(store1.data) == 3
|
||||
assert len(store1.data) == 4 # 3 manually created plus a default pipeline
|
||||
assert store1.async_get_preferred_item() == list(store1.data)[0]
|
||||
|
||||
await store1.async_delete_item(pipeline_ids[1])
|
||||
assert len(store1.data) == 2
|
||||
assert len(store1.data) == 3
|
||||
|
||||
store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY))
|
||||
await flush_store(store1.store)
|
||||
await store2.async_load()
|
||||
|
||||
assert len(store2.data) == 2
|
||||
assert len(store2.data) == 3
|
||||
|
||||
assert store1.data is not store2.data
|
||||
assert store1.data == store2.data
|
||||
@ -142,16 +149,221 @@ async def test_get_pipeline(hass: HomeAssistant) -> None:
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 0
|
||||
|
||||
# Test a pipeline is created
|
||||
pipeline = await async_get_pipeline(hass, None)
|
||||
assert len(store.data) == 1
|
||||
|
||||
# Test we get the same pipeline again
|
||||
assert pipeline is await async_get_pipeline(hass, None)
|
||||
assert len(store.data) == 1
|
||||
# Test we get the preferred pipeline if none is specified
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline.id == store.async_get_preferred_item()
|
||||
|
||||
# Test getting a specific pipeline
|
||||
assert pipeline is await async_get_pipeline(hass, pipeline.id)
|
||||
assert pipeline is async_get_pipeline(hass, pipeline.id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("ha_language", "ha_country", "conv_language", "pipeline_language"),
|
||||
[
|
||||
("en", None, "en", "en"),
|
||||
("de", "de", "de", "de"),
|
||||
("de", "ch", "de-CH", "de"),
|
||||
("en", "us", "en", "en"),
|
||||
("en", "uk", "en", "en"),
|
||||
("pt", "pt", "pt", "pt"),
|
||||
("pt", "br", "pt-br", "pt"),
|
||||
],
|
||||
)
|
||||
async def test_default_pipeline_no_stt_tts(
|
||||
hass: HomeAssistant,
|
||||
ha_language: str,
|
||||
ha_country: str | None,
|
||||
conv_language: str,
|
||||
pipeline_language: str,
|
||||
) -> None:
|
||||
"""Test async_get_pipeline."""
|
||||
hass.config.country = ha_country
|
||||
hass.config.language = ha_language
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 1
|
||||
|
||||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_language=conv_language,
|
||||
id=pipeline.id,
|
||||
language=pipeline_language,
|
||||
name="Home Assistant",
|
||||
stt_engine=None,
|
||||
stt_language=None,
|
||||
tts_engine=None,
|
||||
tts_language=None,
|
||||
tts_voice=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
"ha_language",
|
||||
"ha_country",
|
||||
"conv_language",
|
||||
"pipeline_language",
|
||||
"stt_language",
|
||||
"tts_language",
|
||||
),
|
||||
[
|
||||
("en", None, "en", "en", "en", "en"),
|
||||
("de", "de", "de", "de", "de", "de"),
|
||||
("de", "ch", "de-CH", "de", "de-CH", "de-CH"),
|
||||
("en", "us", "en", "en", "en", "en"),
|
||||
("en", "uk", "en", "en", "en", "en"),
|
||||
("pt", "pt", "pt", "pt", "pt", "pt"),
|
||||
("pt", "br", "pt-br", "pt", "pt-br", "pt-br"),
|
||||
],
|
||||
)
|
||||
async def test_default_pipeline(
|
||||
hass: HomeAssistant,
|
||||
init_supporting_components,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
ha_language: str,
|
||||
ha_country: str | None,
|
||||
conv_language: str,
|
||||
pipeline_language: str,
|
||||
stt_language: str,
|
||||
tts_language: str,
|
||||
) -> None:
|
||||
"""Test async_get_pipeline."""
|
||||
hass.config.country = ha_country
|
||||
hass.config.language = ha_language
|
||||
|
||||
with patch.object(
|
||||
mock_stt_provider, "_supported_languages", MANY_LANGUAGES
|
||||
), patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES):
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 1
|
||||
|
||||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_language=conv_language,
|
||||
id=pipeline.id,
|
||||
language=pipeline_language,
|
||||
name="Home Assistant",
|
||||
stt_engine="test",
|
||||
stt_language=stt_language,
|
||||
tts_engine="test",
|
||||
tts_language=tts_language,
|
||||
tts_voice=None,
|
||||
)
|
||||
|
||||
|
||||
async def test_default_pipeline_unsupported_stt_language(
|
||||
hass: HomeAssistant,
|
||||
init_supporting_components,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
) -> None:
|
||||
"""Test async_get_pipeline."""
|
||||
with patch.object(mock_stt_provider, "_supported_languages", ["smurfish"]):
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 1
|
||||
|
||||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_language="en",
|
||||
id=pipeline.id,
|
||||
language="en",
|
||||
name="Home Assistant",
|
||||
stt_engine=None,
|
||||
stt_language=None,
|
||||
tts_engine="test",
|
||||
tts_language="en-US",
|
||||
tts_voice="james_earl_jones",
|
||||
)
|
||||
|
||||
|
||||
async def test_default_pipeline_unsupported_tts_language(
|
||||
hass: HomeAssistant,
|
||||
init_supporting_components,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
) -> None:
|
||||
"""Test async_get_pipeline."""
|
||||
with patch.object(mock_tts_provider, "_supported_languages", ["smurfish"]):
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 1
|
||||
|
||||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_language="en",
|
||||
id=pipeline.id,
|
||||
language="en",
|
||||
name="Home Assistant",
|
||||
stt_engine="test",
|
||||
stt_language="en-US",
|
||||
tts_engine=None,
|
||||
tts_language=None,
|
||||
tts_voice=None,
|
||||
)
|
||||
|
||||
|
||||
async def test_default_pipeline_cloud(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
) -> None:
|
||||
"""Test async_get_pipeline."""
|
||||
|
||||
mock_integration(hass, MockModule("cloud"))
|
||||
mock_platform(
|
||||
hass,
|
||||
"cloud.tts",
|
||||
MockTTSPlatform(
|
||||
async_get_engine=AsyncMock(return_value=mock_tts_provider),
|
||||
),
|
||||
)
|
||||
mock_platform(
|
||||
hass,
|
||||
"cloud.stt",
|
||||
MockSttPlatform(
|
||||
async_get_engine=AsyncMock(return_value=mock_stt_provider),
|
||||
),
|
||||
)
|
||||
mock_platform(hass, "test.config_flow")
|
||||
|
||||
assert await async_setup_component(hass, "tts", {"tts": {"platform": "cloud"}})
|
||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": "cloud"}})
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 1
|
||||
|
||||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_language="en",
|
||||
id=pipeline.id,
|
||||
language="en",
|
||||
name="Home Assistant Cloud",
|
||||
stt_engine="cloud",
|
||||
stt_language="en-US",
|
||||
tts_engine="cloud",
|
||||
tts_language="en-US",
|
||||
tts_voice="james_earl_jones",
|
||||
)
|
||||
|
@ -92,6 +92,7 @@ async def test_select_entity_changing_pipelines(
|
||||
assert state.state == "preferred"
|
||||
assert state.attributes["options"] == [
|
||||
"preferred",
|
||||
"Home Assistant",
|
||||
pipeline_1.name,
|
||||
pipeline_2.name,
|
||||
]
|
||||
@ -122,4 +123,8 @@ async def test_select_entity_changing_pipelines(
|
||||
|
||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||
assert state.state == "preferred"
|
||||
assert state.attributes["options"] == ["preferred", pipeline_1.name]
|
||||
assert state.attributes["options"] == [
|
||||
"preferred",
|
||||
"Home Assistant",
|
||||
pipeline_1.name,
|
||||
]
|
||||
|
@ -610,7 +610,7 @@ async def test_add_pipeline(
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
}
|
||||
|
||||
assert len(pipeline_store.data) == 1
|
||||
assert len(pipeline_store.data) == 2
|
||||
pipeline = pipeline_store.data[msg["result"]["id"]]
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="test_conversation_engine",
|
||||
@ -643,6 +643,7 @@ async def test_add_pipeline_missing_language(
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
@ -660,7 +661,7 @@ async def test_add_pipeline_missing_language(
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert len(pipeline_store.data) == 0
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
@ -678,7 +679,7 @@ async def test_add_pipeline_missing_language(
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert len(pipeline_store.data) == 0
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
|
||||
async def test_delete_pipeline(
|
||||
@ -725,7 +726,16 @@ async def test_delete_pipeline(
|
||||
assert msg["success"]
|
||||
pipeline_id_2 = msg["result"]["id"]
|
||||
|
||||
assert len(pipeline_store.data) == 2
|
||||
assert len(pipeline_store.data) == 3
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/set_preferred",
|
||||
"pipeline_id": pipeline_id_1,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
@ -748,7 +758,7 @@ async def test_delete_pipeline(
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
assert len(pipeline_store.data) == 2
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
@ -778,10 +788,18 @@ async def test_get_pipeline(
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"] == {
|
||||
"code": "not_found",
|
||||
"message": "Unable to find pipeline_id None",
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": "en",
|
||||
"id": ANY,
|
||||
"language": "en",
|
||||
"name": "Home Assistant",
|
||||
"stt_engine": "test",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
"tts_voice": "james_earl_jones",
|
||||
}
|
||||
|
||||
await client.send_json_auto_id(
|
||||
@ -814,27 +832,7 @@ async def test_get_pipeline(
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/get",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"conversation_language": "test_language",
|
||||
"id": pipeline_id,
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"stt_language": "test_language",
|
||||
"tts_engine": "test_tts_engine",
|
||||
"tts_language": "test_language",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
}
|
||||
assert len(pipeline_store.data) == 2
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
@ -863,31 +861,7 @@ async def test_list_pipelines(
|
||||
) -> None:
|
||||
"""Test we can list pipelines."""
|
||||
client = await hass_ws_client(hass)
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
|
||||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"pipelines": [], "preferred_pipeline": None}
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"conversation_language": "test_language",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"stt_language": "test_language",
|
||||
"tts_engine": "test_tts_engine",
|
||||
"tts_language": "test_language",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
hass.data[DOMAIN]
|
||||
|
||||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||
msg = await client.receive_json()
|
||||
@ -895,16 +869,16 @@ async def test_list_pipelines(
|
||||
assert msg["result"] == {
|
||||
"pipelines": [
|
||||
{
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"conversation_language": "test_language",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": "en",
|
||||
"id": ANY,
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"stt_language": "test_language",
|
||||
"tts_engine": "test_tts_engine",
|
||||
"tts_language": "test_language",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
"language": "en",
|
||||
"name": "Home Assistant",
|
||||
"stt_engine": "test",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
"tts_voice": "james_earl_jones",
|
||||
}
|
||||
],
|
||||
"preferred_pipeline": ANY,
|
||||
@ -958,7 +932,7 @@ async def test_update_pipeline(
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
assert len(pipeline_store.data) == 1
|
||||
assert len(pipeline_store.data) == 2
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
@ -990,7 +964,7 @@ async def test_update_pipeline(
|
||||
"tts_voice": "new_tts_voice",
|
||||
}
|
||||
|
||||
assert len(pipeline_store.data) == 1
|
||||
assert len(pipeline_store.data) == 2
|
||||
pipeline = pipeline_store.data[pipeline_id]
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="new_conversation_engine",
|
||||
@ -1076,36 +1050,18 @@ async def test_set_preferred_pipeline(
|
||||
assert msg["success"]
|
||||
pipeline_id_1 = msg["result"]["id"]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "test_conversation_engine",
|
||||
"conversation_language": "test_language",
|
||||
"language": "test_language",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test_stt_engine",
|
||||
"stt_language": "test_language",
|
||||
"tts_engine": "test_tts_engine",
|
||||
"tts_language": "test_language",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id_2 = msg["result"]["id"]
|
||||
|
||||
assert pipeline_store.async_get_preferred_item() == pipeline_id_1
|
||||
assert pipeline_store.async_get_preferred_item() != pipeline_id_1
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/set_preferred",
|
||||
"pipeline_id": pipeline_id_2,
|
||||
"pipeline_id": pipeline_id_1,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
assert pipeline_store.async_get_preferred_item() == pipeline_id_2
|
||||
assert pipeline_store.async_get_preferred_item() == pipeline_id_1
|
||||
|
||||
|
||||
async def test_set_preferred_pipeline_wrong_id(
|
||||
|
Loading…
x
Reference in New Issue
Block a user