Remove cloud assist pipeline setup from cloud client (#92056)

This commit is contained in:
Erik Montnemery 2023-04-26 12:53:58 +02:00 committed by GitHub
parent 6b931b208f
commit ed737f306b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 119 additions and 133 deletions

View File

@ -238,25 +238,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
await prefs.async_initialize() await prefs.async_initialize()
# Initialize Cloud # Initialize Cloud
loaded = False
async def _discover_platforms():
"""Discover platforms."""
nonlocal loaded
# Prevent multiple discovery
if loaded:
return
loaded = True
await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
await async_load_platform(hass, Platform.STT, DOMAIN, {}, config)
await async_load_platform(hass, Platform.TTS, DOMAIN, {}, config)
websession = async_get_clientsession(hass) websession = async_get_clientsession(hass)
client = CloudClient( client = CloudClient(hass, prefs, websession, alexa_conf, google_conf)
hass, prefs, websession, alexa_conf, google_conf, _discover_platforms
)
cloud = hass.data[DOMAIN] = Cloud(client, **kwargs) cloud = hass.data[DOMAIN] = Cloud(client, **kwargs)
cloud.iot.register_on_connect(client.on_cloud_connected) cloud.iot.register_on_connect(client.on_cloud_connected)
@ -288,6 +271,27 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if subscription_info := await async_subscription_info(cloud): if subscription_info := await async_subscription_info(cloud):
async_manage_legacy_subscription_issue(hass, subscription_info) async_manage_legacy_subscription_issue(hass, subscription_info)
loaded = False
async def _on_start():
"""Discover platforms."""
nonlocal loaded
# Prevent multiple discovery
if loaded:
return
loaded = True
stt_platform_loaded = asyncio.Event()
tts_platform_loaded = asyncio.Event()
stt_info = {"platform_loaded": stt_platform_loaded}
tts_info = {"platform_loaded": tts_platform_loaded}
await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config)
await async_load_platform(hass, Platform.STT, DOMAIN, stt_info, config)
await async_load_platform(hass, Platform.TTS, DOMAIN, tts_info, config)
await asyncio.gather(stt_platform_loaded.wait(), tts_platform_loaded.wait())
async def _on_connect(): async def _on_connect():
"""Handle cloud connect.""" """Handle cloud connect."""
async_dispatcher_send( async_dispatcher_send(
@ -304,6 +308,7 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Update preferences.""" """Update preferences."""
await prefs.async_update(remote_domain=cloud.remote.instance_domain) await prefs.async_update(remote_domain=cloud.remote.instance_domain)
cloud.register_on_start(_on_start)
cloud.iot.register_on_connect(_on_connect) cloud.iot.register_on_connect(_on_connect)
cloud.iot.register_on_disconnect(_on_disconnect) cloud.iot.register_on_disconnect(_on_disconnect)
cloud.register_on_initialized(_on_initialized) cloud.register_on_initialized(_on_initialized)

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable, Coroutine
from http import HTTPStatus from http import HTTPStatus
import logging import logging
from pathlib import Path from pathlib import Path
@ -11,13 +10,7 @@ from typing import Any
import aiohttp import aiohttp
from hass_nabucasa.client import CloudClient as Interface from hass_nabucasa.client import CloudClient as Interface
from homeassistant.components import ( from homeassistant.components import google_assistant, persistent_notification, webhook
assist_pipeline,
conversation,
google_assistant,
persistent_notification,
webhook,
)
from homeassistant.components.alexa import ( from homeassistant.components.alexa import (
errors as alexa_errors, errors as alexa_errors,
smart_home as alexa_smart_home, smart_home as alexa_smart_home,
@ -43,7 +36,6 @@ class CloudClient(Interface):
websession: aiohttp.ClientSession, websession: aiohttp.ClientSession,
alexa_user_config: dict[str, Any], alexa_user_config: dict[str, Any],
google_user_config: dict[str, Any], google_user_config: dict[str, Any],
on_started_cb: Callable[[], Coroutine[Any, Any, None]],
) -> None: ) -> None:
"""Initialize client interface to Cloud.""" """Initialize client interface to Cloud."""
self._hass = hass self._hass = hass
@ -56,10 +48,6 @@ class CloudClient(Interface):
self._alexa_config_init_lock = asyncio.Lock() self._alexa_config_init_lock = asyncio.Lock()
self._google_config_init_lock = asyncio.Lock() self._google_config_init_lock = asyncio.Lock()
self._relayer_region: str | None = None self._relayer_region: str | None = None
self._on_started_cb = on_started_cb
self.cloud_pipeline = self._cloud_assist_pipeline()
self.stt_platform_loaded = asyncio.Event()
self.tts_platform_loaded = asyncio.Event()
@property @property
def base_path(self) -> Path: def base_path(self) -> Path:
@ -148,22 +136,6 @@ class CloudClient(Interface):
return self._google_config return self._google_config
def _cloud_assist_pipeline(self) -> str | None:
"""Return the ID of a cloud-enabled assist pipeline or None."""
for pipeline in assist_pipeline.async_get_pipelines(self._hass):
if (
pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
and pipeline.stt_engine == DOMAIN
and pipeline.tts_engine == DOMAIN
):
return pipeline.id
return None
async def create_cloud_assist_pipeline(self) -> None:
"""Create a cloud-enabled assist pipeline."""
await assist_pipeline.async_create_default_pipeline(self._hass, DOMAIN, DOMAIN)
self.cloud_pipeline = self._cloud_assist_pipeline()
async def on_cloud_connected(self) -> None: async def on_cloud_connected(self) -> None:
"""When cloud is connected.""" """When cloud is connected."""
is_new_user = await self.prefs.async_set_username(self.cloud.username) is_new_user = await self.prefs.async_set_username(self.cloud.username)
@ -211,11 +183,6 @@ class CloudClient(Interface):
async def cloud_started(self) -> None: async def cloud_started(self) -> None:
"""When cloud is started.""" """When cloud is started."""
await self._on_started_cb()
await asyncio.gather(
self.stt_platform_loaded.wait(),
self.tts_platform_loaded.wait(),
)
async def cloud_stopped(self) -> None: async def cloud_stopped(self) -> None:
"""When the cloud is stopped.""" """When the cloud is stopped."""

View File

@ -15,7 +15,7 @@ from hass_nabucasa.const import STATE_DISCONNECTED
from hass_nabucasa.voice import MAP_VOICE from hass_nabucasa.voice import MAP_VOICE
import voluptuous as vol import voluptuous as vol
from homeassistant.components import websocket_api from homeassistant.components import assist_pipeline, conversation, websocket_api
from homeassistant.components.alexa import ( from homeassistant.components.alexa import (
entities as alexa_entities, entities as alexa_entities,
errors as alexa_errors, errors as alexa_errors,
@ -182,15 +182,28 @@ class CloudLoginView(HomeAssistantView):
) )
async def post(self, request, data): async def post(self, request, data):
"""Handle login request.""" """Handle login request."""
def cloud_assist_pipeline(hass: HomeAssistant) -> str | None:
"""Return the ID of a cloud-enabled assist pipeline or None."""
for pipeline in assist_pipeline.async_get_pipelines(hass):
if (
pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT
and pipeline.stt_engine == DOMAIN
and pipeline.tts_engine == DOMAIN
):
return pipeline.id
return None
hass = request.app["hass"] hass = request.app["hass"]
cloud = hass.data[DOMAIN] cloud = hass.data[DOMAIN]
await cloud.login(data["email"], data["password"]) await cloud.login(data["email"], data["password"])
if cloud.client.cloud_pipeline is None: if (cloud_pipeline_id := cloud_assist_pipeline(hass)) is None:
await cloud.client.create_cloud_assist_pipeline() if cloud_pipeline := await assist_pipeline.async_create_default_pipeline(
return self.json( hass, DOMAIN, DOMAIN
{"success": True, "cloud_pipeline": cloud.client.cloud_pipeline} ):
) cloud_pipeline_id = cloud_pipeline.id
return self.json({"success": True, "cloud_pipeline": cloud_pipeline_id})
class CloudLogoutView(HomeAssistantView): class CloudLogoutView(HomeAssistantView):

View File

@ -29,7 +29,8 @@ async def async_get_engine(hass, config, discovery_info=None):
cloud: Cloud = hass.data[DOMAIN] cloud: Cloud = hass.data[DOMAIN]
cloud_provider = CloudProvider(cloud) cloud_provider = CloudProvider(cloud)
cloud.client.stt_platform_loaded.set() if discovery_info is not None:
discovery_info["platform_loaded"].set()
return cloud_provider return cloud_provider

View File

@ -64,7 +64,8 @@ async def async_get_engine(hass, config, discovery_info=None):
gender = config[ATTR_GENDER] gender = config[ATTR_GENDER]
cloud_provider = CloudProvider(cloud, language, gender) cloud_provider = CloudProvider(cloud, language, gender)
cloud.client.tts_platform_loaded.set() if discovery_info is not None:
discovery_info["platform_loaded"].set()
return cloud_provider return cloud_provider

View File

@ -6,11 +6,6 @@ import aiohttp
from aiohttp import web from aiohttp import web
import pytest import pytest
from homeassistant.components.assist_pipeline import (
Pipeline,
async_get_pipeline,
async_get_pipelines,
)
from homeassistant.components.cloud import DOMAIN from homeassistant.components.cloud import DOMAIN
from homeassistant.components.cloud.client import CloudClient from homeassistant.components.cloud.client import CloudClient
from homeassistant.components.cloud.const import ( from homeassistant.components.cloud.const import (
@ -303,18 +298,14 @@ async def test_google_config_should_2fa(
assert not gconf.should_2fa(state) assert not gconf.should_2fa(state)
@patch( async def test_set_username(hass: HomeAssistant) -> None:
"homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines",
return_value=[],
)
async def test_set_username(async_get_pipelines, hass: HomeAssistant) -> None:
"""Test we set username during login.""" """Test we set username during login."""
prefs = MagicMock( prefs = MagicMock(
alexa_enabled=False, alexa_enabled=False,
google_enabled=False, google_enabled=False,
async_set_username=AsyncMock(return_value=None), async_set_username=AsyncMock(return_value=None),
) )
client = CloudClient(hass, prefs, None, {}, {}, AsyncMock()) client = CloudClient(hass, prefs, None, {}, {})
client.cloud = MagicMock(is_logged_in=True, username="mock-username") client.cloud = MagicMock(is_logged_in=True, username="mock-username")
await client.on_cloud_connected() await client.on_cloud_connected()
@ -322,12 +313,8 @@ async def test_set_username(async_get_pipelines, hass: HomeAssistant) -> None:
assert prefs.async_set_username.mock_calls[0][1][0] == "mock-username" assert prefs.async_set_username.mock_calls[0][1][0] == "mock-username"
@patch(
"homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines",
return_value=[],
)
async def test_login_recovers_bad_internet( async def test_login_recovers_bad_internet(
async_get_pipelines, hass: HomeAssistant, caplog: pytest.LogCaptureFixture hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None: ) -> None:
"""Test Alexa can recover bad auth.""" """Test Alexa can recover bad auth."""
prefs = Mock( prefs = Mock(
@ -335,7 +322,7 @@ async def test_login_recovers_bad_internet(
google_enabled=False, google_enabled=False,
async_set_username=AsyncMock(return_value=None), async_set_username=AsyncMock(return_value=None),
) )
client = CloudClient(hass, prefs, None, {}, {}, AsyncMock()) client = CloudClient(hass, prefs, None, {}, {})
client.cloud = Mock() client.cloud = Mock()
client._alexa_config = Mock( client._alexa_config = Mock(
async_enable_proactive_mode=Mock(side_effect=aiohttp.ClientError) async_enable_proactive_mode=Mock(side_effect=aiohttp.ClientError)
@ -367,29 +354,3 @@ async def test_system_msg(hass: HomeAssistant) -> None:
assert response is None assert response is None
assert cloud.client.relayer_region == "xx-earth-616" assert cloud.client.relayer_region == "xx-earth-616"
async def test_create_cloud_assist_pipeline(
hass: HomeAssistant, mock_cloud_setup, mock_cloud_login
) -> None:
"""Test creating a cloud enabled assist pipeline."""
cloud_client: CloudClient = hass.data[DOMAIN].client
await cloud_client.cloud_started()
assert cloud_client.cloud_pipeline is None
assert len(async_get_pipelines(hass)) == 1
await cloud_client.create_cloud_assist_pipeline()
assert cloud_client.cloud_pipeline is not None
assert len(async_get_pipelines(hass)) == 2
assert async_get_pipeline(hass, cloud_client.cloud_pipeline) == Pipeline(
conversation_engine="homeassistant",
conversation_language="en",
id=cloud_client.cloud_pipeline,
language="en",
name="Home Assistant Cloud",
stt_engine="cloud",
stt_language="en-US",
tts_engine="cloud",
tts_language="en-US",
tts_voice="JennyNeural",
)

View File

@ -104,16 +104,22 @@ async def test_google_actions_sync_fails(
async def test_login_view(hass: HomeAssistant, cloud_client) -> None: async def test_login_view(hass: HomeAssistant, cloud_client) -> None:
"""Test logging in.""" """Test logging in when an assist pipeline is available."""
create_cloud_assist_pipeline_mock = AsyncMock() hass.data["cloud"] = MagicMock(login=AsyncMock())
hass.data["cloud"] = MagicMock(
login=AsyncMock(),
client=Mock(
cloud_pipeline="12345",
create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock,
),
)
with patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_get_pipelines",
return_value=[
Mock(
conversation_engine="homeassistant",
id="12345",
stt_engine=DOMAIN,
tts_engine=DOMAIN,
)
],
), patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
) as create_pipeline_mock:
req = await cloud_client.post( req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"} "/api/cloud/login", json={"email": "my_username", "password": "my_password"}
) )
@ -121,20 +127,37 @@ async def test_login_view(hass: HomeAssistant, cloud_client) -> None:
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
result = await req.json() result = await req.json()
assert result == {"success": True, "cloud_pipeline": "12345"} assert result == {"success": True, "cloud_pipeline": "12345"}
create_cloud_assist_pipeline_mock.assert_not_awaited() create_pipeline_mock.assert_not_awaited()
async def test_login_view_create_pipeline(hass: HomeAssistant, cloud_client) -> None: async def test_login_view_create_pipeline(hass: HomeAssistant, cloud_client) -> None:
"""Test logging in when no assist pipeline is available.""" """Test logging in when no assist pipeline is available."""
create_cloud_assist_pipeline_mock = AsyncMock() hass.data["cloud"] = MagicMock(login=AsyncMock())
hass.data["cloud"] = MagicMock(
login=AsyncMock(), with patch(
client=Mock( "homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
cloud_pipeline=None, return_value=AsyncMock(id="12345"),
create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock, ) as create_pipeline_mock:
), req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"}
) )
assert req.status == HTTPStatus.OK
result = await req.json()
assert result == {"success": True, "cloud_pipeline": "12345"}
create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud")
async def test_login_view_create_pipeline_fail(
hass: HomeAssistant, cloud_client
) -> None:
"""Test logging in when no assist pipeline is available."""
hass.data["cloud"] = MagicMock(login=AsyncMock())
with patch(
"homeassistant.components.cloud.http_api.assist_pipeline.async_create_default_pipeline",
return_value=None,
) as create_pipeline_mock:
req = await cloud_client.post( req = await cloud_client.post(
"/api/cloud/login", json={"email": "my_username", "password": "my_password"} "/api/cloud/login", json={"email": "my_username", "password": "my_password"}
) )
@ -142,7 +165,7 @@ async def test_login_view_create_pipeline(hass: HomeAssistant, cloud_client) ->
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
result = await req.json() result = await req.json()
assert result == {"success": True, "cloud_pipeline": None} assert result == {"success": True, "cloud_pipeline": None}
create_cloud_assist_pipeline_mock.assert_awaited_once() create_pipeline_mock.assert_awaited_once_with(hass, "cloud", "cloud")
async def test_login_view_random_exception(cloud_client) -> None: async def test_login_view_random_exception(cloud_client) -> None:

View File

@ -155,17 +155,24 @@ async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None:
assert len(hass.states.async_entity_ids("binary_sensor")) == 0 assert len(hass.states.async_entity_ids("binary_sensor")) == 0
await cl.client.cloud_started() # The on_start callback discovers the binary sensor platform
assert "async_setup" in str(cl._on_start[-1])
await cl._on_start[-1]()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(hass.states.async_entity_ids("binary_sensor")) == 1 assert len(hass.states.async_entity_ids("binary_sensor")) == 1
with patch("homeassistant.helpers.discovery.async_load_platform") as mock_load: with patch("homeassistant.helpers.discovery.async_load_platform") as mock_load:
await cl.iot._on_connect[-1]() await cl._on_start[-1]()
await hass.async_block_till_done() await hass.async_block_till_done()
assert len(mock_load.mock_calls) == 0 assert len(mock_load.mock_calls) == 0
assert len(cloud_states) == 1
assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_CONNECTED
await cl.iot._on_connect[-1]()
await hass.async_block_till_done()
assert len(cloud_states) == 2 assert len(cloud_states) == 2
assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_CONNECTED assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_CONNECTED
@ -177,6 +184,11 @@ async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None:
assert len(cloud_states) == 3 assert len(cloud_states) == 3
assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_DISCONNECTED assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_DISCONNECTED
await cl.iot._on_disconnect[-1]()
await hass.async_block_till_done()
assert len(cloud_states) == 4
assert cloud_states[-1] == cloud.CloudConnectionState.CLOUD_DISCONNECTED
async def test_remote_ui_url(hass: HomeAssistant, mock_cloud_fixture) -> None: async def test_remote_ui_url(hass: HomeAssistant, mock_cloud_fixture) -> None:
"""Test getting remote ui url.""" """Test getting remote ui url."""

View File

@ -48,8 +48,9 @@ async def test_prefs_default_voice(
"""Test cloud provider uses the preferences.""" """Test cloud provider uses the preferences."""
assert cloud_prefs.tts_default_voice == ("en-US", "female") assert cloud_prefs.tts_default_voice == ("en-US", "female")
tts_info = {"platform_loaded": Mock()}
provider_pref = await tts.async_get_engine( provider_pref = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} Mock(data={const.DOMAIN: cloud_with_prefs}), None, tts_info
) )
provider_conf = await tts.async_get_engine( provider_conf = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), Mock(data={const.DOMAIN: cloud_with_prefs}),
@ -73,8 +74,9 @@ async def test_prefs_default_voice(
async def test_provider_properties(cloud_with_prefs) -> None: async def test_provider_properties(cloud_with_prefs) -> None:
"""Test cloud provider.""" """Test cloud provider."""
tts_info = {"platform_loaded": Mock()}
provider = await tts.async_get_engine( provider = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} Mock(data={const.DOMAIN: cloud_with_prefs}), None, tts_info
) )
assert provider.supported_options == ["gender", "voice", "audio_output"] assert provider.supported_options == ["gender", "voice", "audio_output"]
assert "nl-NL" in provider.supported_languages assert "nl-NL" in provider.supported_languages
@ -85,8 +87,9 @@ async def test_provider_properties(cloud_with_prefs) -> None:
async def test_get_tts_audio(cloud_with_prefs) -> None: async def test_get_tts_audio(cloud_with_prefs) -> None:
"""Test cloud provider.""" """Test cloud provider."""
tts_info = {"platform_loaded": Mock()}
provider = await tts.async_get_engine( provider = await tts.async_get_engine(
Mock(data={const.DOMAIN: cloud_with_prefs}), None, {} Mock(data={const.DOMAIN: cloud_with_prefs}), None, tts_info
) )
assert provider.supported_options == ["gender", "voice", "audio_output"] assert provider.supported_options == ["gender", "voice", "audio_output"]
assert "nl-NL" in provider.supported_languages assert "nl-NL" in provider.supported_languages