Remove cloud details from assist pipeline (#105687)

* Remove cloud details from assist pipeline

* Update assist pipeline tests

* Update cloud tests
This commit is contained in:
Martin Hjelmare 2023-12-14 10:15:59 +01:00 committed by GitHub
parent 82f0b28e89
commit 2e448d2d13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 44 additions and 66 deletions

View File

@ -115,6 +115,7 @@ async def _async_resolve_default_pipeline_settings(
hass: HomeAssistant,
stt_engine_id: str | None,
tts_engine_id: str | None,
pipeline_name: str,
) -> dict[str, str | None]:
"""Resolve settings for a default pipeline.
@ -123,7 +124,6 @@ async def _async_resolve_default_pipeline_settings(
"""
conversation_language = "en"
pipeline_language = "en"
pipeline_name = "Home Assistant"
stt_engine = None
stt_language = None
tts_engine = None
@ -195,9 +195,6 @@ async def _async_resolve_default_pipeline_settings(
)
tts_engine_id = None
if stt_engine_id == "cloud" and tts_engine_id == "cloud":
pipeline_name = "Home Assistant Cloud"
return {
"conversation_engine": conversation.HOME_ASSISTANT_AGENT,
"conversation_language": conversation_language,
@ -221,12 +218,17 @@ async def _async_create_default_pipeline(
The default pipeline will use the homeassistant conversation agent and the
default stt / tts engines.
"""
pipeline_settings = await _async_resolve_default_pipeline_settings(hass, None, None)
pipeline_settings = await _async_resolve_default_pipeline_settings(
hass, stt_engine_id=None, tts_engine_id=None, pipeline_name="Home Assistant"
)
return await pipeline_store.async_create_item(pipeline_settings)
async def async_create_default_pipeline(
hass: HomeAssistant, stt_engine_id: str, tts_engine_id: str
hass: HomeAssistant,
stt_engine_id: str,
tts_engine_id: str,
pipeline_name: str,
) -> Pipeline | None:
"""Create a pipeline with default settings.
@ -236,7 +238,7 @@ async def async_create_default_pipeline(
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_store = pipeline_data.pipeline_store
pipeline_settings = await _async_resolve_default_pipeline_settings(
hass, stt_engine_id, tts_engine_id
hass, stt_engine_id, tts_engine_id, pipeline_name=pipeline_name
)
if (
pipeline_settings["stt_engine"] != stt_engine_id

View File

@ -232,7 +232,10 @@ class CloudLoginView(HomeAssistantView):
new_cloud_pipeline_id: str | None = None
if (cloud_assist_pipeline(hass)) is None:
if cloud_pipeline := await assist_pipeline.async_create_default_pipeline(
hass, DOMAIN, DOMAIN
hass,
stt_engine_id=DOMAIN,
tts_engine_id=DOMAIN,
pipeline_name="Home Assistant Cloud",
):
new_cloud_pipeline_id = cloud_pipeline.id
return self.json({"success": True, "cloud_pipeline": new_cloud_pipeline_id})

View File

@ -1,6 +1,6 @@
"""Websocket tests for Voice Assistant integration."""
from typing import Any
from unittest.mock import ANY, AsyncMock, patch
from unittest.mock import ANY, patch
import pytest
@ -21,9 +21,9 @@ from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES
from .conftest import MockSttPlatform, MockSttProvider, MockTTSPlatform, MockTTSProvider
from .conftest import MockSttProvider, MockTTSProvider
from tests.common import MockModule, flush_store, mock_integration, mock_platform
from tests.common import flush_store
@pytest.fixture(autouse=True)
@ -237,13 +237,26 @@ async def test_create_default_pipeline(
store = pipeline_data.pipeline_store
assert len(store.data) == 1
assert await async_create_default_pipeline(hass, "bla", "bla") is None
assert await async_create_default_pipeline(hass, "test", "test") == Pipeline(
assert (
await async_create_default_pipeline(
hass,
stt_engine_id="bla",
tts_engine_id="bla",
pipeline_name="Bla pipeline",
)
is None
)
assert await async_create_default_pipeline(
hass,
stt_engine_id="test",
tts_engine_id="test",
pipeline_name="Test pipeline",
) == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=ANY,
language="en",
name="Home Assistant",
name="Test pipeline",
stt_engine="test",
stt_language="en-US",
tts_engine="test",
@ -465,53 +478,3 @@ async def test_default_pipeline_unsupported_tts_language(
wake_word_entity=None,
wake_word_id=None,
)
async def test_default_pipeline_cloud(
hass: HomeAssistant,
mock_stt_provider: MockSttProvider,
mock_tts_provider: MockTTSProvider,
) -> None:
"""Test async_get_pipeline."""
mock_integration(hass, MockModule("cloud"))
mock_platform(
hass,
"cloud.tts",
MockTTSPlatform(
async_get_engine=AsyncMock(return_value=mock_tts_provider),
),
)
mock_platform(
hass,
"cloud.stt",
MockSttPlatform(
async_get_engine=AsyncMock(return_value=mock_stt_provider),
),
)
mock_platform(hass, "test.config_flow")
assert await async_setup_component(hass, "tts", {"tts": {"platform": "cloud"}})
assert await async_setup_component(hass, "stt", {"stt": {"platform": "cloud"}})
assert await async_setup_component(hass, "assist_pipeline", {})
pipeline_data: PipelineData = hass.data[DOMAIN]
store = pipeline_data.pipeline_store
assert len(store.data) == 1
# Check the default pipeline
pipeline = async_get_pipeline(hass, None)
assert pipeline == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=pipeline.id,
language="en",
name="Home Assistant Cloud",
stt_engine="cloud",
stt_language="en-US",
tts_engine="cloud",
tts_language="en-US",
tts_voice="james_earl_jones",
wake_word_entity=None,
wake_word_id=None,
)

View File

@ -193,7 +193,12 @@ async def test_login_view_create_pipeline(
assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": "12345"}
create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud")
create_pipeline_mock.assert_awaited_once_with(
hass,
stt_engine_id="cloud",
tts_engine_id="cloud",
pipeline_name="Home Assistant Cloud",
)
async def test_login_view_create_pipeline_fail(
@ -227,7 +232,12 @@ async def test_login_view_create_pipeline_fail(
assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": None}
create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud")
create_pipeline_mock.assert_awaited_once_with(
hass,
stt_engine_id="cloud",
tts_engine_id="cloud",
pipeline_name="Home Assistant Cloud",
)
async def test_login_view_random_exception(