mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Remove cloud details from assist pipeline (#105687)
* Remove cloud details from assist pipeline * Update assist pipeline tests * Update cloud tests
This commit is contained in:
parent
82f0b28e89
commit
2e448d2d13
@ -115,6 +115,7 @@ async def _async_resolve_default_pipeline_settings(
|
||||
hass: HomeAssistant,
|
||||
stt_engine_id: str | None,
|
||||
tts_engine_id: str | None,
|
||||
pipeline_name: str,
|
||||
) -> dict[str, str | None]:
|
||||
"""Resolve settings for a default pipeline.
|
||||
|
||||
@ -123,7 +124,6 @@ async def _async_resolve_default_pipeline_settings(
|
||||
"""
|
||||
conversation_language = "en"
|
||||
pipeline_language = "en"
|
||||
pipeline_name = "Home Assistant"
|
||||
stt_engine = None
|
||||
stt_language = None
|
||||
tts_engine = None
|
||||
@ -195,9 +195,6 @@ async def _async_resolve_default_pipeline_settings(
|
||||
)
|
||||
tts_engine_id = None
|
||||
|
||||
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
|
||||
pipeline_name = "Home Assistant Cloud"
|
||||
|
||||
return {
|
||||
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
|
||||
"conversation_language": conversation_language,
|
||||
@ -221,12 +218,17 @@ async def _async_create_default_pipeline(
|
||||
The default pipeline will use the homeassistant conversation agent and the
|
||||
default stt / tts engines.
|
||||
"""
|
||||
pipeline_settings = await _async_resolve_default_pipeline_settings(hass, None, None)
|
||||
pipeline_settings = await _async_resolve_default_pipeline_settings(
|
||||
hass, stt_engine_id=None, tts_engine_id=None, pipeline_name="Home Assistant"
|
||||
)
|
||||
return await pipeline_store.async_create_item(pipeline_settings)
|
||||
|
||||
|
||||
async def async_create_default_pipeline(
|
||||
hass: HomeAssistant, stt_engine_id: str, tts_engine_id: str
|
||||
hass: HomeAssistant,
|
||||
stt_engine_id: str,
|
||||
tts_engine_id: str,
|
||||
pipeline_name: str,
|
||||
) -> Pipeline | None:
|
||||
"""Create a pipeline with default settings.
|
||||
|
||||
@ -236,7 +238,7 @@ async def async_create_default_pipeline(
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_settings = await _async_resolve_default_pipeline_settings(
|
||||
hass, stt_engine_id, tts_engine_id
|
||||
hass, stt_engine_id, tts_engine_id, pipeline_name=pipeline_name
|
||||
)
|
||||
if (
|
||||
pipeline_settings["stt_engine"] != stt_engine_id
|
||||
|
@ -232,7 +232,10 @@ class CloudLoginView(HomeAssistantView):
|
||||
new_cloud_pipeline_id: str | None = None
|
||||
if (cloud_assist_pipeline(hass)) is None:
|
||||
if cloud_pipeline := await assist_pipeline.async_create_default_pipeline(
|
||||
hass, DOMAIN, DOMAIN
|
||||
hass,
|
||||
stt_engine_id=DOMAIN,
|
||||
tts_engine_id=DOMAIN,
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
):
|
||||
new_cloud_pipeline_id = cloud_pipeline.id
|
||||
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Websocket tests for Voice Assistant integration."""
|
||||
from typing import Any
|
||||
from unittest.mock import ANY, AsyncMock, patch
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -21,9 +21,9 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import MANY_LANGUAGES
|
||||
from .conftest import MockSttPlatform, MockSttProvider, MockTTSPlatform, MockTTSProvider
|
||||
from .conftest import MockSttProvider, MockTTSProvider
|
||||
|
||||
from tests.common import MockModule, flush_store, mock_integration, mock_platform
|
||||
from tests.common import flush_store
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -237,13 +237,26 @@ async def test_create_default_pipeline(
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 1
|
||||
|
||||
assert await async_create_default_pipeline(hass, "bla", "bla") is None
|
||||
assert await async_create_default_pipeline(hass, "test", "test") == Pipeline(
|
||||
assert (
|
||||
await async_create_default_pipeline(
|
||||
hass,
|
||||
stt_engine_id="bla",
|
||||
tts_engine_id="bla",
|
||||
pipeline_name="Bla pipeline",
|
||||
)
|
||||
is None
|
||||
)
|
||||
assert await async_create_default_pipeline(
|
||||
hass,
|
||||
stt_engine_id="test",
|
||||
tts_engine_id="test",
|
||||
pipeline_name="Test pipeline",
|
||||
) == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_language="en",
|
||||
id=ANY,
|
||||
language="en",
|
||||
name="Home Assistant",
|
||||
name="Test pipeline",
|
||||
stt_engine="test",
|
||||
stt_language="en-US",
|
||||
tts_engine="test",
|
||||
@ -465,53 +478,3 @@ async def test_default_pipeline_unsupported_tts_language(
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
||||
|
||||
async def test_default_pipeline_cloud(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSttProvider,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
) -> None:
|
||||
"""Test async_get_pipeline."""
|
||||
|
||||
mock_integration(hass, MockModule("cloud"))
|
||||
mock_platform(
|
||||
hass,
|
||||
"cloud.tts",
|
||||
MockTTSPlatform(
|
||||
async_get_engine=AsyncMock(return_value=mock_tts_provider),
|
||||
),
|
||||
)
|
||||
mock_platform(
|
||||
hass,
|
||||
"cloud.stt",
|
||||
MockSttPlatform(
|
||||
async_get_engine=AsyncMock(return_value=mock_stt_provider),
|
||||
),
|
||||
)
|
||||
mock_platform(hass, "test.config_flow")
|
||||
|
||||
assert await async_setup_component(hass, "tts", {"tts": {"platform": "cloud"}})
|
||||
assert await async_setup_component(hass, "stt", {"stt": {"platform": "cloud"}})
|
||||
assert await async_setup_component(hass, "assist_pipeline", {})
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
store = pipeline_data.pipeline_store
|
||||
assert len(store.data) == 1
|
||||
|
||||
# Check the default pipeline
|
||||
pipeline = async_get_pipeline(hass, None)
|
||||
assert pipeline == Pipeline(
|
||||
conversation_engine="homeassistant",
|
||||
conversation_language="en",
|
||||
id=pipeline.id,
|
||||
language="en",
|
||||
name="Home Assistant Cloud",
|
||||
stt_engine="cloud",
|
||||
stt_language="en-US",
|
||||
tts_engine="cloud",
|
||||
tts_language="en-US",
|
||||
tts_voice="james_earl_jones",
|
||||
wake_word_entity=None,
|
||||
wake_word_id=None,
|
||||
)
|
||||
|
@ -193,7 +193,12 @@ async def test_login_view_create_pipeline(
|
||||
assert req.status == HTTPStatus.OK
|
||||
result = await req.json()
|
||||
assert result == {"success": True, "cloud_pipeline": "12345"}
|
||||
create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud")
|
||||
create_pipeline_mock.assert_awaited_once_with(
|
||||
hass,
|
||||
stt_engine_id="cloud",
|
||||
tts_engine_id="cloud",
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
)
|
||||
|
||||
|
||||
async def test_login_view_create_pipeline_fail(
|
||||
@ -227,7 +232,12 @@ async def test_login_view_create_pipeline_fail(
|
||||
assert req.status == HTTPStatus.OK
|
||||
result = await req.json()
|
||||
assert result == {"success": True, "cloud_pipeline": None}
|
||||
create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud")
|
||||
create_pipeline_mock.assert_awaited_once_with(
|
||||
hass,
|
||||
stt_engine_id="cloud",
|
||||
tts_engine_id="cloud",
|
||||
pipeline_name="Home Assistant Cloud",
|
||||
)
|
||||
|
||||
|
||||
async def test_login_view_random_exception(
|
||||
|
Loading…
x
Reference in New Issue
Block a user