mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Create a default assist pipeline on start (#91947)
* Create a default assist pipeline on start * Minor adjustments * Address review comments * Remove tts.async_get_agent * Fix bugs, improve test coverage
This commit is contained in:
parent
4b619f7251
commit
b601fb17d3
@ -51,7 +51,7 @@ async def async_pipeline_from_audio_stream(
|
|||||||
tts_audio_output: str | None = None,
|
tts_audio_output: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Create an audio pipeline from an audio stream."""
|
"""Create an audio pipeline from an audio stream."""
|
||||||
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
|
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
raise PipelineNotFound(
|
raise PipelineNotFound(
|
||||||
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
"pipeline_not_found", f"Pipeline {pipeline_id} not found"
|
||||||
|
@ -24,7 +24,11 @@ from homeassistant.helpers.collection import (
|
|||||||
StorageCollectionWebsocket,
|
StorageCollectionWebsocket,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.storage import Store
|
from homeassistant.helpers.storage import Store
|
||||||
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
from homeassistant.util import (
|
||||||
|
dt as dt_util,
|
||||||
|
language as language_util,
|
||||||
|
ulid as ulid_util,
|
||||||
|
)
|
||||||
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
from homeassistant.util.limited_size_dict import LimitedSizeDict
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
@ -71,37 +75,109 @@ STORED_PIPELINE_RUNS = 10
|
|||||||
SAVE_DELAY = 10
|
SAVE_DELAY = 10
|
||||||
|
|
||||||
|
|
||||||
async def async_get_pipeline(
|
async def _async_create_default_pipeline(
|
||||||
|
hass: HomeAssistant, pipeline_store: PipelineStorageCollection
|
||||||
|
) -> Pipeline:
|
||||||
|
"""Create a default pipeline.
|
||||||
|
|
||||||
|
The default pipeline will use the homeassistant conversation agent and the
|
||||||
|
default stt / tts engines.
|
||||||
|
"""
|
||||||
|
conversation_language = "en"
|
||||||
|
pipeline_language = "en"
|
||||||
|
pipeline_name = "Home Assistant"
|
||||||
|
stt_engine_id = None
|
||||||
|
stt_language = None
|
||||||
|
tts_engine_id = None
|
||||||
|
tts_language = None
|
||||||
|
tts_voice = None
|
||||||
|
|
||||||
|
# Find a matching language supported by the Home Assistant conversation agent
|
||||||
|
conversation_languages = language_util.matches(
|
||||||
|
hass.config.language,
|
||||||
|
await conversation.async_get_conversation_languages(
|
||||||
|
hass, conversation.HOME_ASSISTANT_AGENT
|
||||||
|
),
|
||||||
|
country=hass.config.country,
|
||||||
|
)
|
||||||
|
if conversation_languages:
|
||||||
|
pipeline_language = hass.config.language
|
||||||
|
conversation_language = conversation_languages[0]
|
||||||
|
|
||||||
|
if (stt_engine_id := stt.async_default_engine(hass)) is not None and (
|
||||||
|
stt_engine := stt.async_get_speech_to_text_engine(
|
||||||
|
hass,
|
||||||
|
stt_engine_id,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
stt_languages = language_util.matches(
|
||||||
|
pipeline_language,
|
||||||
|
stt_engine.supported_languages,
|
||||||
|
country=hass.config.country,
|
||||||
|
)
|
||||||
|
if stt_languages:
|
||||||
|
stt_language = stt_languages[0]
|
||||||
|
else:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Speech to text engine '%s' does not support language '%s'",
|
||||||
|
stt_engine_id,
|
||||||
|
pipeline_language,
|
||||||
|
)
|
||||||
|
stt_engine_id = None
|
||||||
|
|
||||||
|
if (tts_engine_id := tts.async_default_engine(hass)) is not None and (
|
||||||
|
tts_engine := tts.get_engine_instance(
|
||||||
|
hass,
|
||||||
|
tts_engine_id,
|
||||||
|
)
|
||||||
|
):
|
||||||
|
tts_languages = language_util.matches(
|
||||||
|
pipeline_language,
|
||||||
|
tts_engine.supported_languages,
|
||||||
|
country=hass.config.country,
|
||||||
|
)
|
||||||
|
if tts_languages:
|
||||||
|
tts_language = tts_languages[0]
|
||||||
|
tts_voices = tts_engine.async_get_supported_voices(tts_language)
|
||||||
|
if tts_voices:
|
||||||
|
tts_voice = tts_voices[0].voice_id
|
||||||
|
else:
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Text to speech engine '%s' does not support language '%s'",
|
||||||
|
tts_engine_id,
|
||||||
|
pipeline_language,
|
||||||
|
)
|
||||||
|
tts_engine_id = None
|
||||||
|
|
||||||
|
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
|
||||||
|
pipeline_name = "Home Assistant Cloud"
|
||||||
|
|
||||||
|
return await pipeline_store.async_create_item(
|
||||||
|
{
|
||||||
|
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
|
||||||
|
"conversation_language": conversation_language,
|
||||||
|
"language": hass.config.language,
|
||||||
|
"name": pipeline_name,
|
||||||
|
"stt_engine": stt_engine_id,
|
||||||
|
"stt_language": stt_language,
|
||||||
|
"tts_engine": tts_engine_id,
|
||||||
|
"tts_language": tts_language,
|
||||||
|
"tts_voice": tts_voice,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_pipeline(
|
||||||
hass: HomeAssistant, pipeline_id: str | None = None
|
hass: HomeAssistant, pipeline_id: str | None = None
|
||||||
) -> Pipeline | None:
|
) -> Pipeline | None:
|
||||||
"""Get a pipeline by id or create one for a language."""
|
"""Get a pipeline by id or the preferred pipeline."""
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
|
||||||
if pipeline_id is None:
|
if pipeline_id is None:
|
||||||
# A pipeline was not specified, use the preferred one
|
# A pipeline was not specified, use the preferred one
|
||||||
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
|
pipeline_id = pipeline_data.pipeline_store.async_get_preferred_item()
|
||||||
|
|
||||||
if pipeline_id is None:
|
|
||||||
# There's no preferred pipeline, construct a pipeline for the
|
|
||||||
# configured language
|
|
||||||
stt_engine = stt.async_default_provider(hass)
|
|
||||||
stt_language = hass.config.language if stt_engine else None
|
|
||||||
tts_engine = tts.async_default_engine(hass)
|
|
||||||
tts_language = hass.config.language if tts_engine else None
|
|
||||||
return await pipeline_data.pipeline_store.async_create_item(
|
|
||||||
{
|
|
||||||
"conversation_engine": None,
|
|
||||||
"conversation_language": None,
|
|
||||||
"language": hass.config.language,
|
|
||||||
"name": hass.config.language,
|
|
||||||
"stt_engine": stt_engine,
|
|
||||||
"stt_language": stt_language,
|
|
||||||
"tts_engine": tts_engine,
|
|
||||||
"tts_language": tts_language,
|
|
||||||
"tts_voice": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return pipeline_data.pipeline_store.data.get(pipeline_id)
|
return pipeline_data.pipeline_store.data.get(pipeline_id)
|
||||||
|
|
||||||
|
|
||||||
@ -635,7 +711,7 @@ class PipelinePreferred(CollectionError):
|
|||||||
class SerializedPipelineStorageCollection(SerializedStorageCollection):
|
class SerializedPipelineStorageCollection(SerializedStorageCollection):
|
||||||
"""Serialized pipeline storage collection."""
|
"""Serialized pipeline storage collection."""
|
||||||
|
|
||||||
preferred_item: str | None
|
preferred_item: str
|
||||||
|
|
||||||
|
|
||||||
class PipelineStorageCollection(
|
class PipelineStorageCollection(
|
||||||
@ -643,11 +719,13 @@ class PipelineStorageCollection(
|
|||||||
):
|
):
|
||||||
"""Pipeline storage collection."""
|
"""Pipeline storage collection."""
|
||||||
|
|
||||||
_preferred_item: str | None = None
|
_preferred_item: str
|
||||||
|
|
||||||
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
|
async def _async_load_data(self) -> SerializedPipelineStorageCollection | None:
|
||||||
"""Load the data."""
|
"""Load the data."""
|
||||||
if not (data := await super()._async_load_data()):
|
if not (data := await super()._async_load_data()):
|
||||||
|
pipeline = await _async_create_default_pipeline(self.hass, self)
|
||||||
|
self._preferred_item = pipeline.id
|
||||||
return data
|
return data
|
||||||
|
|
||||||
self._preferred_item = data["preferred_item"]
|
self._preferred_item = data["preferred_item"]
|
||||||
@ -671,8 +749,6 @@ class PipelineStorageCollection(
|
|||||||
|
|
||||||
def _create_item(self, item_id: str, data: dict) -> Pipeline:
|
def _create_item(self, item_id: str, data: dict) -> Pipeline:
|
||||||
"""Create an item from validated config."""
|
"""Create an item from validated config."""
|
||||||
if self._preferred_item is None:
|
|
||||||
self._preferred_item = item_id
|
|
||||||
return Pipeline(id=item_id, **data)
|
return Pipeline(id=item_id, **data)
|
||||||
|
|
||||||
def _deserialize_item(self, data: dict) -> Pipeline:
|
def _deserialize_item(self, data: dict) -> Pipeline:
|
||||||
@ -690,7 +766,7 @@ class PipelineStorageCollection(
|
|||||||
await super().async_delete_item(item_id)
|
await super().async_delete_item(item_id)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_preferred_item(self) -> str | None:
|
def async_get_preferred_item(self) -> str:
|
||||||
"""Get the id of the preferred item."""
|
"""Get the id of the preferred item."""
|
||||||
return self._preferred_item
|
return self._preferred_item
|
||||||
|
|
||||||
|
@ -85,7 +85,7 @@ async def websocket_run(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Run a pipeline."""
|
"""Run a pipeline."""
|
||||||
pipeline_id = msg.get("pipeline")
|
pipeline_id = msg.get("pipeline")
|
||||||
pipeline = await async_get_pipeline(hass, pipeline_id=pipeline_id)
|
pipeline = async_get_pipeline(hass, pipeline_id=pipeline_id)
|
||||||
if pipeline is None:
|
if pipeline is None:
|
||||||
connection.send_error(
|
connection.send_error(
|
||||||
msg["id"],
|
msg["id"],
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Iterable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
@ -117,14 +118,25 @@ def async_unset_agent(
|
|||||||
|
|
||||||
|
|
||||||
async def async_get_conversation_languages(
|
async def async_get_conversation_languages(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant, agent_id: str | None = None
|
||||||
) -> set[str] | Literal["*"]:
|
) -> set[str] | Literal["*"]:
|
||||||
"""Return a set with the union of languages supported by conversation agents."""
|
"""Return languages supported by conversation agents.
|
||||||
|
|
||||||
|
If an agent is specified, returns a set of languages supported by that agent.
|
||||||
|
If no agent is specified, return a set with the union of languages supported by
|
||||||
|
all conversation agents.
|
||||||
|
"""
|
||||||
agent_manager = _get_agent_manager(hass)
|
agent_manager = _get_agent_manager(hass)
|
||||||
languages = set()
|
languages = set()
|
||||||
|
|
||||||
for agent_info in agent_manager.async_get_agent_info():
|
agent_ids: Iterable[str]
|
||||||
agent = await agent_manager.async_get_agent(agent_info.id)
|
if agent_id is None:
|
||||||
|
agent_ids = iter(info.id for info in agent_manager.async_get_agent_info())
|
||||||
|
else:
|
||||||
|
agent_ids = (agent_id,)
|
||||||
|
|
||||||
|
for _agent_id in agent_ids:
|
||||||
|
agent = await agent_manager.async_get_agent(_agent_id)
|
||||||
if agent.supported_languages == MATCH_ALL:
|
if agent.supported_languages == MATCH_ALL:
|
||||||
return MATCH_ALL
|
return MATCH_ALL
|
||||||
for language_tag in agent.supported_languages:
|
for language_tag in agent.supported_languages:
|
||||||
|
@ -1 +1,56 @@
|
|||||||
"""Tests for the Voice Assistant integration."""
|
"""Tests for the Voice Assistant integration."""
|
||||||
|
|
||||||
|
MANY_LANGUAGES = [
|
||||||
|
"ar",
|
||||||
|
"bg",
|
||||||
|
"bn",
|
||||||
|
"ca",
|
||||||
|
"cs",
|
||||||
|
"da",
|
||||||
|
"de",
|
||||||
|
"de-CH",
|
||||||
|
"el",
|
||||||
|
"en",
|
||||||
|
"es",
|
||||||
|
"fa",
|
||||||
|
"fi",
|
||||||
|
"fr",
|
||||||
|
"fr-CA",
|
||||||
|
"gl",
|
||||||
|
"gu",
|
||||||
|
"he",
|
||||||
|
"hi",
|
||||||
|
"hr",
|
||||||
|
"hu",
|
||||||
|
"id",
|
||||||
|
"is",
|
||||||
|
"it",
|
||||||
|
"ka",
|
||||||
|
"kn",
|
||||||
|
"lb",
|
||||||
|
"lt",
|
||||||
|
"lv",
|
||||||
|
"ml",
|
||||||
|
"mn",
|
||||||
|
"ms",
|
||||||
|
"nb",
|
||||||
|
"nl",
|
||||||
|
"pl",
|
||||||
|
"pt",
|
||||||
|
"pt-br",
|
||||||
|
"ro",
|
||||||
|
"ru",
|
||||||
|
"sk",
|
||||||
|
"sl",
|
||||||
|
"sr",
|
||||||
|
"sv",
|
||||||
|
"sw",
|
||||||
|
"te",
|
||||||
|
"tr",
|
||||||
|
"uk",
|
||||||
|
"ur",
|
||||||
|
"vi",
|
||||||
|
"zh-cn",
|
||||||
|
"zh-hk",
|
||||||
|
"zh-tw",
|
||||||
|
]
|
||||||
|
@ -11,9 +11,8 @@ from homeassistant.components import stt, tts
|
|||||||
from homeassistant.components.assist_pipeline import DOMAIN
|
from homeassistant.components.assist_pipeline import DOMAIN
|
||||||
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
from homeassistant.components.assist_pipeline.pipeline import PipelineStorageCollection
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
from homeassistant.config_entries import ConfigEntry, ConfigFlow
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import (
|
from tests.common import (
|
||||||
@ -36,6 +35,8 @@ _TRANSCRIPT = "test transcript"
|
|||||||
class BaseProvider:
|
class BaseProvider:
|
||||||
"""Mock STT provider."""
|
"""Mock STT provider."""
|
||||||
|
|
||||||
|
_supported_languages = ["en-US"]
|
||||||
|
|
||||||
def __init__(self, text: str) -> None:
|
def __init__(self, text: str) -> None:
|
||||||
"""Init test provider."""
|
"""Init test provider."""
|
||||||
self.text = text
|
self.text = text
|
||||||
@ -44,7 +45,7 @@ class BaseProvider:
|
|||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
return ["en-US"]
|
return self._supported_languages
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_formats(self) -> list[stt.AudioFormats]:
|
def supported_formats(self) -> list[stt.AudioFormats]:
|
||||||
@ -96,6 +97,13 @@ class MockTTSProvider(tts.Provider):
|
|||||||
"""Mock TTS provider."""
|
"""Mock TTS provider."""
|
||||||
|
|
||||||
name = "Test"
|
name = "Test"
|
||||||
|
_supported_languages = ["en-US"]
|
||||||
|
_supported_voices = {
|
||||||
|
"en-US": [
|
||||||
|
tts.Voice("james_earl_jones", "James Earl Jones"),
|
||||||
|
tts.Voice("fran_drescher", "Fran Drescher"),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_language(self) -> str:
|
def default_language(self) -> str:
|
||||||
@ -105,7 +113,12 @@ class MockTTSProvider(tts.Provider):
|
|||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return list of supported languages."""
|
"""Return list of supported languages."""
|
||||||
return ["en-US"]
|
return self._supported_languages
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
|
||||||
|
"""Return a list of supported voices for a language."""
|
||||||
|
return self._supported_voices.get(language)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supported_options(self) -> list[str]:
|
def supported_options(self) -> list[str]:
|
||||||
@ -119,19 +132,21 @@ class MockTTSProvider(tts.Provider):
|
|||||||
return ("mp3", b"")
|
return ("mp3", b"")
|
||||||
|
|
||||||
|
|
||||||
class MockTTS(MockPlatform):
|
class MockTTSPlatform(MockPlatform):
|
||||||
"""A mock TTS platform."""
|
"""A mock TTS platform."""
|
||||||
|
|
||||||
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
|
PLATFORM_SCHEMA = tts.PLATFORM_SCHEMA
|
||||||
|
|
||||||
async def async_get_engine(
|
def __init__(self, *, async_get_engine, **kwargs):
|
||||||
self,
|
"""Initialize the tts platform."""
|
||||||
hass: HomeAssistant,
|
super().__init__(**kwargs)
|
||||||
config: ConfigType,
|
self.async_get_engine = async_get_engine
|
||||||
discovery_info: DiscoveryInfoType | None = None,
|
|
||||||
) -> tts.Provider:
|
|
||||||
"""Set up a mock speech component."""
|
@pytest.fixture
|
||||||
return MockTTSProvider()
|
async def mock_tts_provider(hass) -> MockTTSProvider:
|
||||||
|
"""Mock TTS provider."""
|
||||||
|
return MockTTSProvider()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -169,10 +184,11 @@ def config_flow_fixture(hass: HomeAssistant) -> Generator[None, None, None]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def init_components(
|
async def init_supporting_components(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSttProvider,
|
||||||
mock_stt_provider_entity: MockSttProviderEntity,
|
mock_stt_provider_entity: MockSttProviderEntity,
|
||||||
|
mock_tts_provider: MockTTSProvider,
|
||||||
config_flow_fixture,
|
config_flow_fixture,
|
||||||
init_cache_dir_side_effect, # noqa: F811
|
init_cache_dir_side_effect, # noqa: F811
|
||||||
mock_get_cache_files, # noqa: F811
|
mock_get_cache_files, # noqa: F811
|
||||||
@ -210,7 +226,13 @@ async def init_components(
|
|||||||
async_unload_entry=async_unload_entry_init,
|
async_unload_entry=async_unload_entry_init,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
mock_platform(hass, "test.tts", MockTTS())
|
mock_platform(
|
||||||
|
hass,
|
||||||
|
"test.tts",
|
||||||
|
MockTTSPlatform(
|
||||||
|
async_get_engine=AsyncMock(return_value=mock_tts_provider),
|
||||||
|
),
|
||||||
|
)
|
||||||
mock_platform(
|
mock_platform(
|
||||||
hass,
|
hass,
|
||||||
"test.stt",
|
"test.stt",
|
||||||
@ -224,7 +246,6 @@ async def init_components(
|
|||||||
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
assert await async_setup_component(hass, tts.DOMAIN, {"tts": {"platform": "test"}})
|
||||||
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
assert await async_setup_component(hass, stt.DOMAIN, {"stt": {"platform": "test"}})
|
||||||
assert await async_setup_component(hass, "media_source", {})
|
assert await async_setup_component(hass, "media_source", {})
|
||||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
|
||||||
|
|
||||||
config_entry = MockConfigEntry(domain="test")
|
config_entry = MockConfigEntry(domain="test")
|
||||||
config_entry.add_to_hass(hass)
|
config_entry.add_to_hass(hass)
|
||||||
@ -232,6 +253,13 @@ async def init_components(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def init_components(hass: HomeAssistant, init_supporting_components):
|
||||||
|
"""Initialize relevant components with empty configs."""
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
|
def pipeline_storage(hass: HomeAssistant, init_components) -> PipelineStorageCollection:
|
||||||
"""Return pipeline storage collection."""
|
"""Return pipeline storage collection."""
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': 'en',
|
'pipeline': 'Home Assistant',
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
}),
|
}),
|
||||||
@ -34,7 +34,7 @@
|
|||||||
'data': dict({
|
'data': dict({
|
||||||
'engine': 'homeassistant',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
}),
|
}),
|
||||||
@ -64,18 +64,18 @@
|
|||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'engine': 'test',
|
'engine': 'test',
|
||||||
'language': 'en',
|
'language': 'en-US',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': None,
|
'voice': 'james_earl_jones',
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
'type': <PipelineEventType.TTS_START: 'tts-start'>,
|
||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
'type': <PipelineEventType.TTS_END: 'tts-end'>,
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
# name: test_audio_pipeline
|
# name: test_audio_pipeline
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': 'en',
|
'pipeline': 'Home Assistant',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
@ -33,7 +33,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'engine': 'homeassistant',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline.4
|
# name: test_audio_pipeline.4
|
||||||
@ -61,24 +61,24 @@
|
|||||||
# name: test_audio_pipeline.5
|
# name: test_audio_pipeline.5
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'test',
|
||||||
'language': 'en',
|
'language': 'en-US',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': None,
|
'voice': 'james_earl_jones',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline.6
|
# name: test_audio_pipeline.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_debug
|
# name: test_audio_pipeline_debug
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': 'en',
|
'pipeline': 'Home Assistant',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
@ -109,7 +109,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'engine': 'homeassistant',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_debug.4
|
# name: test_audio_pipeline_debug.4
|
||||||
@ -137,24 +137,24 @@
|
|||||||
# name: test_audio_pipeline_debug.5
|
# name: test_audio_pipeline_debug.5
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'test',
|
||||||
'language': 'en',
|
'language': 'en-US',
|
||||||
'tts_input': "Sorry, I couldn't understand that",
|
'tts_input': "Sorry, I couldn't understand that",
|
||||||
'voice': None,
|
'voice': 'james_earl_jones',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_debug.6
|
# name: test_audio_pipeline_debug.6
|
||||||
dict({
|
dict({
|
||||||
'tts_output': dict({
|
'tts_output': dict({
|
||||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en",
|
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&voice=james_earl_jones",
|
||||||
'mime_type': 'audio/mpeg',
|
'mime_type': 'audio/mpeg',
|
||||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3',
|
||||||
}),
|
}),
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_intent_failed
|
# name: test_intent_failed
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': 'en',
|
'pipeline': 'Home Assistant',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
@ -165,13 +165,13 @@
|
|||||||
dict({
|
dict({
|
||||||
'engine': 'homeassistant',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'Are the lights on?',
|
'intent_input': 'Are the lights on?',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_intent_timeout
|
# name: test_intent_timeout
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': 'en',
|
'pipeline': 'Home Assistant',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
'timeout': 0.1,
|
'timeout': 0.1,
|
||||||
@ -182,7 +182,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'engine': 'homeassistant',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'Are the lights on?',
|
'intent_input': 'Are the lights on?',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_intent_timeout.2
|
# name: test_intent_timeout.2
|
||||||
@ -217,7 +217,7 @@
|
|||||||
# name: test_stt_stream_failed
|
# name: test_stt_stream_failed
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': 'en',
|
'pipeline': 'Home Assistant',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': 1,
|
'stt_binary_handler_id': 1,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
@ -240,7 +240,7 @@
|
|||||||
# name: test_text_only_pipeline
|
# name: test_text_only_pipeline
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': 'en',
|
'pipeline': 'Home Assistant',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
@ -251,7 +251,7 @@
|
|||||||
dict({
|
dict({
|
||||||
'engine': 'homeassistant',
|
'engine': 'homeassistant',
|
||||||
'intent_input': 'Are the lights on?',
|
'intent_input': 'Are the lights on?',
|
||||||
'language': None,
|
'language': 'en',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
# name: test_text_only_pipeline.2
|
# name: test_text_only_pipeline.2
|
||||||
@ -285,7 +285,7 @@
|
|||||||
# name: test_tts_failed
|
# name: test_tts_failed
|
||||||
dict({
|
dict({
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': 'en',
|
'pipeline': 'Home Assistant',
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
'stt_binary_handler_id': None,
|
'stt_binary_handler_id': None,
|
||||||
'timeout': 30,
|
'timeout': 30,
|
||||||
@ -295,8 +295,8 @@
|
|||||||
# name: test_tts_failed.1
|
# name: test_tts_failed.1
|
||||||
dict({
|
dict({
|
||||||
'engine': 'test',
|
'engine': 'test',
|
||||||
'language': 'en',
|
'language': 'en-US',
|
||||||
'tts_input': 'Lights are on.',
|
'tts_input': 'Lights are on.',
|
||||||
'voice': None,
|
'voice': 'james_earl_jones',
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
"""Websocket tests for Voice Assistant integration."""
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||||
from homeassistant.components.assist_pipeline.pipeline import (
|
from homeassistant.components.assist_pipeline.pipeline import (
|
||||||
STORAGE_KEY,
|
STORAGE_KEY,
|
||||||
STORAGE_VERSION,
|
STORAGE_VERSION,
|
||||||
|
Pipeline,
|
||||||
PipelineData,
|
PipelineData,
|
||||||
PipelineStorageCollection,
|
PipelineStorageCollection,
|
||||||
async_get_pipeline,
|
async_get_pipeline,
|
||||||
@ -13,7 +17,10 @@ from homeassistant.core import HomeAssistant
|
|||||||
from homeassistant.helpers.storage import Store
|
from homeassistant.helpers.storage import Store
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from tests.common import flush_store
|
from . import MANY_LANGUAGES
|
||||||
|
from .conftest import MockSttPlatform, MockSttProvider, MockTTSPlatform, MockTTSProvider
|
||||||
|
|
||||||
|
from tests.common import MockModule, flush_store, mock_integration, mock_platform
|
||||||
|
|
||||||
|
|
||||||
async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
||||||
@ -60,17 +67,17 @@ async def test_load_datasets(hass: HomeAssistant, init_components) -> None:
|
|||||||
store1 = pipeline_data.pipeline_store
|
store1 = pipeline_data.pipeline_store
|
||||||
for pipeline in pipelines:
|
for pipeline in pipelines:
|
||||||
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
|
pipeline_ids.append((await store1.async_create_item(pipeline)).id)
|
||||||
assert len(store1.data) == 3
|
assert len(store1.data) == 4 # 3 manually created plus a default pipeline
|
||||||
assert store1.async_get_preferred_item() == list(store1.data)[0]
|
assert store1.async_get_preferred_item() == list(store1.data)[0]
|
||||||
|
|
||||||
await store1.async_delete_item(pipeline_ids[1])
|
await store1.async_delete_item(pipeline_ids[1])
|
||||||
assert len(store1.data) == 2
|
assert len(store1.data) == 3
|
||||||
|
|
||||||
store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY))
|
store2 = PipelineStorageCollection(Store(hass, STORAGE_VERSION, STORAGE_KEY))
|
||||||
await flush_store(store1.store)
|
await flush_store(store1.store)
|
||||||
await store2.async_load()
|
await store2.async_load()
|
||||||
|
|
||||||
assert len(store2.data) == 2
|
assert len(store2.data) == 3
|
||||||
|
|
||||||
assert store1.data is not store2.data
|
assert store1.data is not store2.data
|
||||||
assert store1.data == store2.data
|
assert store1.data == store2.data
|
||||||
@ -142,16 +149,221 @@ async def test_get_pipeline(hass: HomeAssistant) -> None:
|
|||||||
|
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
store = pipeline_data.pipeline_store
|
store = pipeline_data.pipeline_store
|
||||||
assert len(store.data) == 0
|
|
||||||
|
|
||||||
# Test a pipeline is created
|
|
||||||
pipeline = await async_get_pipeline(hass, None)
|
|
||||||
assert len(store.data) == 1
|
assert len(store.data) == 1
|
||||||
|
|
||||||
# Test we get the same pipeline again
|
# Test we get the preferred pipeline if none is specified
|
||||||
assert pipeline is await async_get_pipeline(hass, None)
|
pipeline = async_get_pipeline(hass, None)
|
||||||
assert len(store.data) == 1
|
assert pipeline.id == store.async_get_preferred_item()
|
||||||
|
|
||||||
# Test getting a specific pipeline
|
# Test getting a specific pipeline
|
||||||
assert pipeline is await async_get_pipeline(hass, pipeline.id)
|
assert pipeline is async_get_pipeline(hass, pipeline.id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("ha_language", "ha_country", "conv_language", "pipeline_language"),
|
||||||
|
[
|
||||||
|
("en", None, "en", "en"),
|
||||||
|
("de", "de", "de", "de"),
|
||||||
|
("de", "ch", "de-CH", "de"),
|
||||||
|
("en", "us", "en", "en"),
|
||||||
|
("en", "uk", "en", "en"),
|
||||||
|
("pt", "pt", "pt", "pt"),
|
||||||
|
("pt", "br", "pt-br", "pt"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_default_pipeline_no_stt_tts(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
ha_language: str,
|
||||||
|
ha_country: str | None,
|
||||||
|
conv_language: str,
|
||||||
|
pipeline_language: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_get_pipeline."""
|
||||||
|
hass.config.country = ha_country
|
||||||
|
hass.config.language = ha_language
|
||||||
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
store = pipeline_data.pipeline_store
|
||||||
assert len(store.data) == 1
|
assert len(store.data) == 1
|
||||||
|
|
||||||
|
# Check the default pipeline
|
||||||
|
pipeline = async_get_pipeline(hass, None)
|
||||||
|
assert pipeline == Pipeline(
|
||||||
|
conversation_engine="homeassistant",
|
||||||
|
conversation_language=conv_language,
|
||||||
|
id=pipeline.id,
|
||||||
|
language=pipeline_language,
|
||||||
|
name="Home Assistant",
|
||||||
|
stt_engine=None,
|
||||||
|
stt_language=None,
|
||||||
|
tts_engine=None,
|
||||||
|
tts_language=None,
|
||||||
|
tts_voice=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
(
|
||||||
|
"ha_language",
|
||||||
|
"ha_country",
|
||||||
|
"conv_language",
|
||||||
|
"pipeline_language",
|
||||||
|
"stt_language",
|
||||||
|
"tts_language",
|
||||||
|
),
|
||||||
|
[
|
||||||
|
("en", None, "en", "en", "en", "en"),
|
||||||
|
("de", "de", "de", "de", "de", "de"),
|
||||||
|
("de", "ch", "de-CH", "de", "de-CH", "de-CH"),
|
||||||
|
("en", "us", "en", "en", "en", "en"),
|
||||||
|
("en", "uk", "en", "en", "en", "en"),
|
||||||
|
("pt", "pt", "pt", "pt", "pt", "pt"),
|
||||||
|
("pt", "br", "pt-br", "pt", "pt-br", "pt-br"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_default_pipeline(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_supporting_components,
|
||||||
|
mock_stt_provider: MockSttProvider,
|
||||||
|
mock_tts_provider: MockTTSProvider,
|
||||||
|
ha_language: str,
|
||||||
|
ha_country: str | None,
|
||||||
|
conv_language: str,
|
||||||
|
pipeline_language: str,
|
||||||
|
stt_language: str,
|
||||||
|
tts_language: str,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_get_pipeline."""
|
||||||
|
hass.config.country = ha_country
|
||||||
|
hass.config.language = ha_language
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
mock_stt_provider, "_supported_languages", MANY_LANGUAGES
|
||||||
|
), patch.object(mock_tts_provider, "_supported_languages", MANY_LANGUAGES):
|
||||||
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
store = pipeline_data.pipeline_store
|
||||||
|
assert len(store.data) == 1
|
||||||
|
|
||||||
|
# Check the default pipeline
|
||||||
|
pipeline = async_get_pipeline(hass, None)
|
||||||
|
assert pipeline == Pipeline(
|
||||||
|
conversation_engine="homeassistant",
|
||||||
|
conversation_language=conv_language,
|
||||||
|
id=pipeline.id,
|
||||||
|
language=pipeline_language,
|
||||||
|
name="Home Assistant",
|
||||||
|
stt_engine="test",
|
||||||
|
stt_language=stt_language,
|
||||||
|
tts_engine="test",
|
||||||
|
tts_language=tts_language,
|
||||||
|
tts_voice=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_default_pipeline_unsupported_stt_language(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_supporting_components,
|
||||||
|
mock_stt_provider: MockSttProvider,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_get_pipeline."""
|
||||||
|
with patch.object(mock_stt_provider, "_supported_languages", ["smurfish"]):
|
||||||
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
store = pipeline_data.pipeline_store
|
||||||
|
assert len(store.data) == 1
|
||||||
|
|
||||||
|
# Check the default pipeline
|
||||||
|
pipeline = async_get_pipeline(hass, None)
|
||||||
|
assert pipeline == Pipeline(
|
||||||
|
conversation_engine="homeassistant",
|
||||||
|
conversation_language="en",
|
||||||
|
id=pipeline.id,
|
||||||
|
language="en",
|
||||||
|
name="Home Assistant",
|
||||||
|
stt_engine=None,
|
||||||
|
stt_language=None,
|
||||||
|
tts_engine="test",
|
||||||
|
tts_language="en-US",
|
||||||
|
tts_voice="james_earl_jones",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_default_pipeline_unsupported_tts_language(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_supporting_components,
|
||||||
|
mock_tts_provider: MockTTSProvider,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_get_pipeline."""
|
||||||
|
with patch.object(mock_tts_provider, "_supported_languages", ["smurfish"]):
|
||||||
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
store = pipeline_data.pipeline_store
|
||||||
|
assert len(store.data) == 1
|
||||||
|
|
||||||
|
# Check the default pipeline
|
||||||
|
pipeline = async_get_pipeline(hass, None)
|
||||||
|
assert pipeline == Pipeline(
|
||||||
|
conversation_engine="homeassistant",
|
||||||
|
conversation_language="en",
|
||||||
|
id=pipeline.id,
|
||||||
|
language="en",
|
||||||
|
name="Home Assistant",
|
||||||
|
stt_engine="test",
|
||||||
|
stt_language="en-US",
|
||||||
|
tts_engine=None,
|
||||||
|
tts_language=None,
|
||||||
|
tts_voice=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_default_pipeline_cloud(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_stt_provider: MockSttProvider,
|
||||||
|
mock_tts_provider: MockTTSProvider,
|
||||||
|
) -> None:
|
||||||
|
"""Test async_get_pipeline."""
|
||||||
|
|
||||||
|
mock_integration(hass, MockModule("cloud"))
|
||||||
|
mock_platform(
|
||||||
|
hass,
|
||||||
|
"cloud.tts",
|
||||||
|
MockTTSPlatform(
|
||||||
|
async_get_engine=AsyncMock(return_value=mock_tts_provider),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_platform(
|
||||||
|
hass,
|
||||||
|
"cloud.stt",
|
||||||
|
MockSttPlatform(
|
||||||
|
async_get_engine=AsyncMock(return_value=mock_stt_provider),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
mock_platform(hass, "test.config_flow")
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, "tts", {"tts": {"platform": "cloud"}})
|
||||||
|
assert await async_setup_component(hass, "stt", {"stt": {"platform": "cloud"}})
|
||||||
|
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||||
|
|
||||||
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
|
store = pipeline_data.pipeline_store
|
||||||
|
assert len(store.data) == 1
|
||||||
|
|
||||||
|
# Check the default pipeline
|
||||||
|
pipeline = async_get_pipeline(hass, None)
|
||||||
|
assert pipeline == Pipeline(
|
||||||
|
conversation_engine="homeassistant",
|
||||||
|
conversation_language="en",
|
||||||
|
id=pipeline.id,
|
||||||
|
language="en",
|
||||||
|
name="Home Assistant Cloud",
|
||||||
|
stt_engine="cloud",
|
||||||
|
stt_language="en-US",
|
||||||
|
tts_engine="cloud",
|
||||||
|
tts_language="en-US",
|
||||||
|
tts_voice="james_earl_jones",
|
||||||
|
)
|
||||||
|
@ -92,6 +92,7 @@ async def test_select_entity_changing_pipelines(
|
|||||||
assert state.state == "preferred"
|
assert state.state == "preferred"
|
||||||
assert state.attributes["options"] == [
|
assert state.attributes["options"] == [
|
||||||
"preferred",
|
"preferred",
|
||||||
|
"Home Assistant",
|
||||||
pipeline_1.name,
|
pipeline_1.name,
|
||||||
pipeline_2.name,
|
pipeline_2.name,
|
||||||
]
|
]
|
||||||
@ -122,4 +123,8 @@ async def test_select_entity_changing_pipelines(
|
|||||||
|
|
||||||
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
state = hass.states.get("select.assist_pipeline_test_pipeline")
|
||||||
assert state.state == "preferred"
|
assert state.state == "preferred"
|
||||||
assert state.attributes["options"] == ["preferred", pipeline_1.name]
|
assert state.attributes["options"] == [
|
||||||
|
"preferred",
|
||||||
|
"Home Assistant",
|
||||||
|
pipeline_1.name,
|
||||||
|
]
|
||||||
|
@ -610,7 +610,7 @@ async def test_add_pipeline(
|
|||||||
"tts_voice": "Arnold Schwarzenegger",
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert len(pipeline_store.data) == 1
|
assert len(pipeline_store.data) == 2
|
||||||
pipeline = pipeline_store.data[msg["result"]["id"]]
|
pipeline = pipeline_store.data[msg["result"]["id"]]
|
||||||
assert pipeline == Pipeline(
|
assert pipeline == Pipeline(
|
||||||
conversation_engine="test_conversation_engine",
|
conversation_engine="test_conversation_engine",
|
||||||
@ -643,6 +643,7 @@ async def test_add_pipeline_missing_language(
|
|||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
assert len(pipeline_store.data) == 1
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
@ -660,7 +661,7 @@ async def test_add_pipeline_missing_language(
|
|||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert not msg["success"]
|
assert not msg["success"]
|
||||||
assert len(pipeline_store.data) == 0
|
assert len(pipeline_store.data) == 1
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
@ -678,7 +679,7 @@ async def test_add_pipeline_missing_language(
|
|||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert not msg["success"]
|
assert not msg["success"]
|
||||||
assert len(pipeline_store.data) == 0
|
assert len(pipeline_store.data) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_delete_pipeline(
|
async def test_delete_pipeline(
|
||||||
@ -725,7 +726,16 @@ async def test_delete_pipeline(
|
|||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
pipeline_id_2 = msg["result"]["id"]
|
pipeline_id_2 = msg["result"]["id"]
|
||||||
|
|
||||||
assert len(pipeline_store.data) == 2
|
assert len(pipeline_store.data) == 3
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/set_preferred",
|
||||||
|
"pipeline_id": pipeline_id_1,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
@ -748,7 +758,7 @@ async def test_delete_pipeline(
|
|||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
assert len(pipeline_store.data) == 1
|
assert len(pipeline_store.data) == 2
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
@ -778,10 +788,18 @@ async def test_get_pipeline(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert not msg["success"]
|
assert msg["success"]
|
||||||
assert msg["error"] == {
|
assert msg["result"] == {
|
||||||
"code": "not_found",
|
"conversation_engine": "homeassistant",
|
||||||
"message": "Unable to find pipeline_id None",
|
"conversation_language": "en",
|
||||||
|
"id": ANY,
|
||||||
|
"language": "en",
|
||||||
|
"name": "Home Assistant",
|
||||||
|
"stt_engine": "test",
|
||||||
|
"stt_language": "en-US",
|
||||||
|
"tts_engine": "test",
|
||||||
|
"tts_language": "en-US",
|
||||||
|
"tts_voice": "james_earl_jones",
|
||||||
}
|
}
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
@ -814,27 +832,7 @@ async def test_get_pipeline(
|
|||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
pipeline_id = msg["result"]["id"]
|
pipeline_id = msg["result"]["id"]
|
||||||
assert len(pipeline_store.data) == 1
|
assert len(pipeline_store.data) == 2
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
|
||||||
{
|
|
||||||
"type": "assist_pipeline/pipeline/get",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
assert msg["result"] == {
|
|
||||||
"conversation_engine": "test_conversation_engine",
|
|
||||||
"conversation_language": "test_language",
|
|
||||||
"id": pipeline_id,
|
|
||||||
"language": "test_language",
|
|
||||||
"name": "test_name",
|
|
||||||
"stt_engine": "test_stt_engine",
|
|
||||||
"stt_language": "test_language",
|
|
||||||
"tts_engine": "test_tts_engine",
|
|
||||||
"tts_language": "test_language",
|
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
|
||||||
}
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
@ -863,31 +861,7 @@ async def test_list_pipelines(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test we can list pipelines."""
|
"""Test we can list pipelines."""
|
||||||
client = await hass_ws_client(hass)
|
client = await hass_ws_client(hass)
|
||||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
hass.data[DOMAIN]
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
|
||||||
|
|
||||||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
assert msg["result"] == {"pipelines": [], "preferred_pipeline": None}
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
|
||||||
{
|
|
||||||
"type": "assist_pipeline/pipeline/create",
|
|
||||||
"conversation_engine": "test_conversation_engine",
|
|
||||||
"conversation_language": "test_language",
|
|
||||||
"language": "test_language",
|
|
||||||
"name": "test_name",
|
|
||||||
"stt_engine": "test_stt_engine",
|
|
||||||
"stt_language": "test_language",
|
|
||||||
"tts_engine": "test_tts_engine",
|
|
||||||
"tts_language": "test_language",
|
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
assert len(pipeline_store.data) == 1
|
|
||||||
|
|
||||||
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
@ -895,16 +869,16 @@ async def test_list_pipelines(
|
|||||||
assert msg["result"] == {
|
assert msg["result"] == {
|
||||||
"pipelines": [
|
"pipelines": [
|
||||||
{
|
{
|
||||||
"conversation_engine": "test_conversation_engine",
|
"conversation_engine": "homeassistant",
|
||||||
"conversation_language": "test_language",
|
"conversation_language": "en",
|
||||||
"id": ANY,
|
"id": ANY,
|
||||||
"language": "test_language",
|
"language": "en",
|
||||||
"name": "test_name",
|
"name": "Home Assistant",
|
||||||
"stt_engine": "test_stt_engine",
|
"stt_engine": "test",
|
||||||
"stt_language": "test_language",
|
"stt_language": "en-US",
|
||||||
"tts_engine": "test_tts_engine",
|
"tts_engine": "test",
|
||||||
"tts_language": "test_language",
|
"tts_language": "en-US",
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
"tts_voice": "james_earl_jones",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"preferred_pipeline": ANY,
|
"preferred_pipeline": ANY,
|
||||||
@ -958,7 +932,7 @@ async def test_update_pipeline(
|
|||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
pipeline_id = msg["result"]["id"]
|
pipeline_id = msg["result"]["id"]
|
||||||
assert len(pipeline_store.data) == 1
|
assert len(pipeline_store.data) == 2
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
@ -990,7 +964,7 @@ async def test_update_pipeline(
|
|||||||
"tts_voice": "new_tts_voice",
|
"tts_voice": "new_tts_voice",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert len(pipeline_store.data) == 1
|
assert len(pipeline_store.data) == 2
|
||||||
pipeline = pipeline_store.data[pipeline_id]
|
pipeline = pipeline_store.data[pipeline_id]
|
||||||
assert pipeline == Pipeline(
|
assert pipeline == Pipeline(
|
||||||
conversation_engine="new_conversation_engine",
|
conversation_engine="new_conversation_engine",
|
||||||
@ -1076,36 +1050,18 @@ async def test_set_preferred_pipeline(
|
|||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
pipeline_id_1 = msg["result"]["id"]
|
pipeline_id_1 = msg["result"]["id"]
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
assert pipeline_store.async_get_preferred_item() != pipeline_id_1
|
||||||
{
|
|
||||||
"type": "assist_pipeline/pipeline/create",
|
|
||||||
"conversation_engine": "test_conversation_engine",
|
|
||||||
"conversation_language": "test_language",
|
|
||||||
"language": "test_language",
|
|
||||||
"name": "test_name",
|
|
||||||
"stt_engine": "test_stt_engine",
|
|
||||||
"stt_language": "test_language",
|
|
||||||
"tts_engine": "test_tts_engine",
|
|
||||||
"tts_language": "test_language",
|
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
pipeline_id_2 = msg["result"]["id"]
|
|
||||||
|
|
||||||
assert pipeline_store.async_get_preferred_item() == pipeline_id_1
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
await client.send_json_auto_id(
|
||||||
{
|
{
|
||||||
"type": "assist_pipeline/pipeline/set_preferred",
|
"type": "assist_pipeline/pipeline/set_preferred",
|
||||||
"pipeline_id": pipeline_id_2,
|
"pipeline_id": pipeline_id_1,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
msg = await client.receive_json()
|
msg = await client.receive_json()
|
||||||
assert msg["success"]
|
assert msg["success"]
|
||||||
|
|
||||||
assert pipeline_store.async_get_preferred_item() == pipeline_id_2
|
assert pipeline_store.async_get_preferred_item() == pipeline_id_1
|
||||||
|
|
||||||
|
|
||||||
async def test_set_preferred_pipeline_wrong_id(
|
async def test_set_preferred_pipeline_wrong_id(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user