From 57a59d808b0858081c675c901df10bb6e753d7f4 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 26 Apr 2023 03:40:01 +0200 Subject: [PATCH] Automaticially create an assist pipeline using cloud stt + tts (#91991) * Automaticially create an assist pipeline using cloud stt + tts * Return the id of the cloud enabled pipeline * Wait for platforms to load * Fix typing * Fix startup race * Update tests * Create a cloud pipeline only when logging in * Fix tests * Tweak _async_resolve_default_pipeline_settings * Improve assist_pipeline test coverage * Improve cloud test coverage --- .../components/assist_pipeline/__init__.py | 4 + .../components/assist_pipeline/pipeline.py | 111 +++++++++++++----- homeassistant/components/cloud/__init__.py | 37 +++--- homeassistant/components/cloud/client.py | 42 ++++++- homeassistant/components/cloud/http_api.py | 6 +- homeassistant/components/cloud/manifest.json | 2 +- homeassistant/components/cloud/stt.py | 4 +- homeassistant/components/cloud/tts.py | 4 +- .../assist_pipeline/test_pipeline.py | 54 ++++++++- tests/components/cloud/conftest.py | 7 ++ tests/components/cloud/test_client.py | 51 +++++++- tests/components/cloud/test_http_api.py | 33 +++++- tests/components/cloud/test_init.py | 10 +- tests/components/stt/test_init.py | 6 + 14 files changed, 303 insertions(+), 68 deletions(-) diff --git a/homeassistant/components/assist_pipeline/__init__.py b/homeassistant/components/assist_pipeline/__init__.py index a56e535cc63..7af379804e1 100644 --- a/homeassistant/components/assist_pipeline/__init__.py +++ b/homeassistant/components/assist_pipeline/__init__.py @@ -17,13 +17,17 @@ from .pipeline import ( PipelineInput, PipelineRun, PipelineStage, + async_create_default_pipeline, async_get_pipeline, + async_get_pipelines, async_setup_pipeline_store, ) from .websocket_api import async_register_websocket_api __all__ = ( "DOMAIN", + "async_create_default_pipeline", + "async_get_pipelines", "async_setup", "async_pipeline_from_audio_stream", "Pipeline", diff --git a/homeassistant/components/assist_pipeline/pipeline.py b/homeassistant/components/assist_pipeline/pipeline.py index cb00fe8852e..e223ff324a2 100644 --- a/homeassistant/components/assist_pipeline/pipeline.py +++ b/homeassistant/components/assist_pipeline/pipeline.py @@ -2,7 +2,7 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncIterable, Callable +from collections.abc import AsyncIterable, Callable, Iterable from dataclasses import asdict, dataclass, field import logging from typing import Any @@ -75,20 +75,22 @@ STORED_PIPELINE_RUNS = 10 SAVE_DELAY = 10 -async def _async_create_default_pipeline( - hass: HomeAssistant, pipeline_store: PipelineStorageCollection -) -> Pipeline: - """Create a default pipeline. +async def _async_resolve_default_pipeline_settings( + hass: HomeAssistant, + stt_engine_id: str | None, + tts_engine_id: str | None, +) -> dict[str, str | None]: + """Resolve settings for a default pipeline. The default pipeline will use the homeassistant conversation agent and the - default stt / tts engines. + default stt / tts engines if none are specified. """ conversation_language = "en" pipeline_language = "en" pipeline_name = "Home Assistant" - stt_engine_id = None + stt_engine = None stt_language = None - tts_engine_id = None + tts_engine = None tts_language = None tts_voice = None @@ -104,12 +106,15 @@ async def _async_create_default_pipeline( pipeline_language = hass.config.language conversation_language = conversation_languages[0] - if (stt_engine_id := stt.async_default_engine(hass)) is not None and ( - stt_engine := stt.async_get_speech_to_text_engine( - hass, - stt_engine_id, - ) - ): + if stt_engine_id is None: + stt_engine_id = stt.async_default_engine(hass) + + if stt_engine_id is not None: + stt_engine = stt.async_get_speech_to_text_engine(hass, stt_engine_id) + if stt_engine is None: + stt_engine_id = None + + if stt_engine: stt_languages = language_util.matches( pipeline_language, stt_engine.supported_languages, @@ -125,12 +130,15 @@ async def _async_create_default_pipeline( ) stt_engine_id = None - if (tts_engine_id := tts.async_default_engine(hass)) is not None and ( - tts_engine := tts.get_engine_instance( - hass, - tts_engine_id, - ) - ): + if tts_engine_id is None: + tts_engine_id = tts.async_default_engine(hass) + + if tts_engine_id is not None: + tts_engine = tts.get_engine_instance(hass, tts_engine_id) + if tts_engine is None: + tts_engine_id = None + + if tts_engine: tts_languages = language_util.matches( pipeline_language, tts_engine.supported_languages, @@ -152,19 +160,50 @@ async def _async_create_default_pipeline( if stt_engine_id == "cloud" and tts_engine_id == "cloud": pipeline_name = "Home Assistant Cloud" - return await pipeline_store.async_create_item( - { - "conversation_engine": conversation.HOME_ASSISTANT_AGENT, - "conversation_language": conversation_language, - "language": hass.config.language, - "name": pipeline_name, - "stt_engine": stt_engine_id, - "stt_language": stt_language, - "tts_engine": tts_engine_id, - "tts_language": tts_language, - "tts_voice": tts_voice, - } + return { + "conversation_engine": conversation.HOME_ASSISTANT_AGENT, + "conversation_language": conversation_language, + "language": hass.config.language, + "name": pipeline_name, + "stt_engine": stt_engine_id, + "stt_language": stt_language, + "tts_engine": tts_engine_id, + "tts_language": tts_language, + "tts_voice": tts_voice, + } + + +async def _async_create_default_pipeline( + hass: HomeAssistant, pipeline_store: PipelineStorageCollection +) -> Pipeline: + """Create a default pipeline. + + The default pipeline will use the homeassistant conversation agent and the + default stt / tts engines. + """ + pipeline_settings = await _async_resolve_default_pipeline_settings(hass, None, None) + return await pipeline_store.async_create_item(pipeline_settings) + + +async def async_create_default_pipeline( + hass: HomeAssistant, stt_engine_id: str, tts_engine_id: str +) -> Pipeline | None: + """Create a pipeline with default settings. + + The default pipeline will use the homeassistant conversation agent and the + specified stt / tts engines. + """ + pipeline_data: PipelineData = hass.data[DOMAIN] + pipeline_store = pipeline_data.pipeline_store + pipeline_settings = await _async_resolve_default_pipeline_settings( + hass, stt_engine_id, tts_engine_id ) + if ( + pipeline_settings["stt_engine"] != stt_engine_id + or pipeline_settings["tts_engine"] != tts_engine_id + ): + return None + return await pipeline_store.async_create_item(pipeline_settings) @callback @@ -181,6 +220,14 @@ def async_get_pipeline( return pipeline_data.pipeline_store.data.get(pipeline_id) +@callback +def async_get_pipelines(hass: HomeAssistant) -> Iterable[Pipeline]: + """Get all pipelines.""" + pipeline_data: PipelineData = hass.data[DOMAIN] + + return pipeline_data.pipeline_store.data.values() + + class PipelineEventType(StrEnum): """Event types emitted during a pipeline run.""" diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index dbf0ca7e08a..ebfaf6b0baa 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -238,9 +238,27 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: await prefs.async_initialize() # Initialize Cloud + loaded = False + + async def _discover_platforms(): + """Discover platforms.""" + nonlocal loaded + + # Prevent multiple discovery + if loaded: + return + loaded = True + + await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config) + await async_load_platform(hass, Platform.STT, DOMAIN, {}, config) + await async_load_platform(hass, Platform.TTS, DOMAIN, {}, config) + websession = async_get_clientsession(hass) - client = CloudClient(hass, prefs, websession, alexa_conf, google_conf) + client = CloudClient( + hass, prefs, websession, alexa_conf, google_conf, _discover_platforms + ) cloud = hass.data[DOMAIN] = Cloud(client, **kwargs) + cloud.iot.register_on_connect(client.on_cloud_connected) async def _shutdown(event): """Shutdown event.""" @@ -262,8 +280,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: hass, DOMAIN, SERVICE_REMOTE_DISCONNECT, _service_handler ) - loaded = False - async def async_startup_repairs(_=None) -> None: """Create repair issues after startup.""" if not cloud.is_logged_in: @@ -272,23 +288,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: if subscription_info := await async_subscription_info(cloud): async_manage_legacy_subscription_issue(hass, subscription_info) - async def _discover_platforms(): - """Discover platforms.""" - nonlocal loaded - - # Prevent multiple discovery - if loaded: - return - loaded = True - - await async_load_platform(hass, Platform.BINARY_SENSOR, DOMAIN, {}, config) - await async_load_platform(hass, Platform.STT, DOMAIN, {}, config) - await async_load_platform(hass, Platform.TTS, DOMAIN, {}, config) - async def _on_connect(): """Handle cloud connect.""" - await _discover_platforms() - async_dispatcher_send( hass, SIGNAL_CLOUD_CONNECTION_STATE, CloudConnectionState.CLOUD_CONNECTED ) diff --git a/homeassistant/components/cloud/client.py b/homeassistant/components/cloud/client.py index 900779f6b01..7a0fada7e15 100644 --- a/homeassistant/components/cloud/client.py +++ b/homeassistant/components/cloud/client.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from collections.abc import Callable, Coroutine from http import HTTPStatus import logging from pathlib import Path @@ -10,7 +11,13 @@ from typing import Any import aiohttp from hass_nabucasa.client import CloudClient as Interface -from homeassistant.components import google_assistant, persistent_notification, webhook +from homeassistant.components import ( + assist_pipeline, + conversation, + google_assistant, + persistent_notification, + webhook, +) from homeassistant.components.alexa import ( errors as alexa_errors, smart_home as alexa_smart_home, @@ -36,6 +43,7 @@ class CloudClient(Interface): websession: aiohttp.ClientSession, alexa_user_config: dict[str, Any], google_user_config: dict[str, Any], + on_started_cb: Callable[[], Coroutine[Any, Any, None]], ) -> None: """Initialize client interface to Cloud.""" self._hass = hass @@ -48,6 +56,10 @@ class CloudClient(Interface): self._alexa_config_init_lock = asyncio.Lock() self._google_config_init_lock = asyncio.Lock() self._relayer_region: str | None = None + self._on_started_cb = on_started_cb + self.cloud_pipeline = self._cloud_assist_pipeline() + self.stt_platform_loaded = asyncio.Event() + self.tts_platform_loaded = asyncio.Event() @property def base_path(self) -> Path: @@ -136,8 +148,24 @@ class CloudClient(Interface): return self._google_config - async def cloud_started(self) -> None: - """When cloud is started.""" + def _cloud_assist_pipeline(self) -> str | None: + """Return the ID of a cloud-enabled assist pipeline or None.""" + for pipeline in assist_pipeline.async_get_pipelines(self._hass): + if ( + pipeline.conversation_engine == conversation.HOME_ASSISTANT_AGENT + and pipeline.stt_engine == DOMAIN + and pipeline.tts_engine == DOMAIN + ): + return pipeline.id + return None + + async def create_cloud_assist_pipeline(self) -> None: + """Create a cloud-enabled assist pipeline.""" + await assist_pipeline.async_create_default_pipeline(self._hass, DOMAIN, DOMAIN) + self.cloud_pipeline = self._cloud_assist_pipeline() + + async def on_cloud_connected(self) -> None: + """When cloud is connected.""" is_new_user = await self.prefs.async_set_username(self.cloud.username) async def enable_alexa(_): @@ -181,6 +209,14 @@ class CloudClient(Interface): if tasks: await asyncio.gather(*(task(None) for task in tasks)) + async def cloud_started(self) -> None: + """When cloud is started.""" + await self._on_started_cb() + await asyncio.gather( + self.stt_platform_loaded.wait(), + self.tts_platform_loaded.wait(), + ) + async def cloud_stopped(self) -> None: """When the cloud is stopped.""" diff --git a/homeassistant/components/cloud/http_api.py b/homeassistant/components/cloud/http_api.py index aef4efdb7a4..82eb64b3a3a 100644 --- a/homeassistant/components/cloud/http_api.py +++ b/homeassistant/components/cloud/http_api.py @@ -186,7 +186,11 @@ class CloudLoginView(HomeAssistantView): cloud = hass.data[DOMAIN] await cloud.login(data["email"], data["password"]) - return self.json({"success": True}) + if cloud.client.cloud_pipeline is None: + await cloud.client.create_cloud_assist_pipeline() + return self.json( + {"success": True, "cloud_pipeline": cloud.client.cloud_pipeline} + ) class CloudLogoutView(HomeAssistantView): diff --git a/homeassistant/components/cloud/manifest.json b/homeassistant/components/cloud/manifest.json index a5486888a01..25af2b8afd0 100644 --- a/homeassistant/components/cloud/manifest.json +++ b/homeassistant/components/cloud/manifest.json @@ -3,7 +3,7 @@ "name": "Home Assistant Cloud", "after_dependencies": ["google_assistant", "alexa"], "codeowners": ["@home-assistant/cloud"], - "dependencies": ["homeassistant", "http", "webhook"], + "dependencies": ["assist_pipeline", "homeassistant", "http", "webhook"], "documentation": "https://www.home-assistant.io/integrations/cloud", "integration_type": "system", "iot_class": "cloud_push", diff --git a/homeassistant/components/cloud/stt.py b/homeassistant/components/cloud/stt.py index 9499bc0cc11..8ccb932c545 100644 --- a/homeassistant/components/cloud/stt.py +++ b/homeassistant/components/cloud/stt.py @@ -28,7 +28,9 @@ async def async_get_engine(hass, config, discovery_info=None): """Set up Cloud speech component.""" cloud: Cloud = hass.data[DOMAIN] - return CloudProvider(cloud) + cloud_provider = CloudProvider(cloud) + cloud.client.stt_platform_loaded.set() + return cloud_provider class CloudProvider(Provider): diff --git a/homeassistant/components/cloud/tts.py b/homeassistant/components/cloud/tts.py index 6c29a1768b9..58e918b9679 100644 --- a/homeassistant/components/cloud/tts.py +++ b/homeassistant/components/cloud/tts.py @@ -63,7 +63,9 @@ async def async_get_engine(hass, config, discovery_info=None): language = config[CONF_LANG] gender = config[ATTR_GENDER] - return CloudProvider(cloud, language, gender) + cloud_provider = CloudProvider(cloud, language, gender) + cloud.client.tts_platform_loaded.set() + return cloud_provider class CloudProvider(Provider): diff --git a/tests/components/assist_pipeline/test_pipeline.py b/tests/components/assist_pipeline/test_pipeline.py index 6b2fa60102d..4c71b4aedbd 100644 --- a/tests/components/assist_pipeline/test_pipeline.py +++ b/tests/components/assist_pipeline/test_pipeline.py @@ -1,6 +1,6 @@ """Websocket tests for Voice Assistant integration.""" from typing import Any -from unittest.mock import AsyncMock, patch +from unittest.mock import ANY, AsyncMock, patch import pytest @@ -11,7 +11,9 @@ from homeassistant.components.assist_pipeline.pipeline import ( Pipeline, PipelineData, PipelineStorageCollection, + async_create_default_pipeline, async_get_pipeline, + async_get_pipelines, ) from homeassistant.core import HomeAssistant from homeassistant.helpers.storage import Store @@ -143,6 +145,31 @@ async def test_loading_datasets_from_storage( assert store.async_get_preferred_item() == "01GX8ZWBAQYWNB1XV3EXEZ75DY" +async def test_create_default_pipeline( + hass: HomeAssistant, init_supporting_components +) -> None: + """Test async_create_default_pipeline.""" + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store + assert len(store.data) == 1 + + assert await async_create_default_pipeline(hass, "bla", "bla") is None + assert await async_create_default_pipeline(hass, "test", "test") == Pipeline( + conversation_engine="homeassistant", + conversation_language="en", + id=ANY, + language="en", + name="Home Assistant", + stt_engine="test", + stt_language="en-US", + tts_engine="test", + tts_language="en-US", + tts_voice="james_earl_jones", + ) + + async def test_get_pipeline(hass: HomeAssistant) -> None: """Test async_get_pipeline.""" assert await async_setup_component(hass, "assist_pipeline", {}) @@ -159,6 +186,31 @@ async def test_get_pipeline(hass: HomeAssistant) -> None: assert pipeline is async_get_pipeline(hass, pipeline.id) +async def test_get_pipelines(hass: HomeAssistant) -> None: + """Test async_get_pipelines.""" + assert await async_setup_component(hass, "assist_pipeline", {}) + + pipeline_data: PipelineData = hass.data[DOMAIN] + store = pipeline_data.pipeline_store + assert len(store.data) == 1 + + pipelines = async_get_pipelines(hass) + assert list(pipelines) == [ + Pipeline( + conversation_engine="homeassistant", + conversation_language="en", + id=ANY, + language="en", + name="Home Assistant", + stt_engine=None, + stt_language=None, + tts_engine=None, + tts_language=None, + tts_voice=None, + ) + ] + + @pytest.mark.parametrize( ("ha_language", "ha_country", "conv_language", "pipeline_language"), [ diff --git a/tests/components/cloud/conftest.py b/tests/components/cloud/conftest.py index e16fb63b34a..93d3dc35bc3 100644 --- a/tests/components/cloud/conftest.py +++ b/tests/components/cloud/conftest.py @@ -8,6 +8,13 @@ from homeassistant.components.cloud import const, prefs from . import mock_cloud, mock_cloud_prefs +# Prevent TTS cache from being created +from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import + init_cache_dir_side_effect, + mock_get_cache_files, + mock_init_cache_dir, +) + @pytest.fixture(autouse=True) def mock_user_data(): diff --git a/tests/components/cloud/test_client.py b/tests/components/cloud/test_client.py index 1afe9956288..0e053941f51 100644 --- a/tests/components/cloud/test_client.py +++ b/tests/components/cloud/test_client.py @@ -6,6 +6,11 @@ import aiohttp from aiohttp import web import pytest +from homeassistant.components.assist_pipeline import ( + Pipeline, + async_get_pipeline, + async_get_pipelines, +) from homeassistant.components.cloud import DOMAIN from homeassistant.components.cloud.client import CloudClient from homeassistant.components.cloud.const import ( @@ -298,23 +303,31 @@ async def test_google_config_should_2fa( assert not gconf.should_2fa(state) -async def test_set_username(hass: HomeAssistant) -> None: +@patch( + "homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines", + return_value=[], +) +async def test_set_username(async_get_pipelines, hass: HomeAssistant) -> None: """Test we set username during login.""" prefs = MagicMock( alexa_enabled=False, google_enabled=False, async_set_username=AsyncMock(return_value=None), ) - client = CloudClient(hass, prefs, None, {}, {}) + client = CloudClient(hass, prefs, None, {}, {}, AsyncMock()) client.cloud = MagicMock(is_logged_in=True, username="mock-username") - await client.cloud_started() + await client.on_cloud_connected() assert len(prefs.async_set_username.mock_calls) == 1 assert prefs.async_set_username.mock_calls[0][1][0] == "mock-username" +@patch( + "homeassistant.components.cloud.client.assist_pipeline.async_get_pipelines", + return_value=[], +) async def test_login_recovers_bad_internet( - hass: HomeAssistant, caplog: pytest.LogCaptureFixture + async_get_pipelines, hass: HomeAssistant, caplog: pytest.LogCaptureFixture ) -> None: """Test Alexa can recover bad auth.""" prefs = Mock( @@ -322,12 +335,12 @@ async def test_login_recovers_bad_internet( google_enabled=False, async_set_username=AsyncMock(return_value=None), ) - client = CloudClient(hass, prefs, None, {}, {}) + client = CloudClient(hass, prefs, None, {}, {}, AsyncMock()) client.cloud = Mock() client._alexa_config = Mock( async_enable_proactive_mode=Mock(side_effect=aiohttp.ClientError) ) - await client.cloud_started() + await client.on_cloud_connected() assert len(client._alexa_config.async_enable_proactive_mode.mock_calls) == 1 assert "Unable to activate Alexa Report State" in caplog.text @@ -354,3 +367,29 @@ async def test_system_msg(hass: HomeAssistant) -> None: assert response is None assert cloud.client.relayer_region == "xx-earth-616" + + +async def test_create_cloud_assist_pipeline( + hass: HomeAssistant, mock_cloud_setup, mock_cloud_login +) -> None: + """Test creating a cloud enabled assist pipeline.""" + cloud_client: CloudClient = hass.data[DOMAIN].client + await cloud_client.cloud_started() + assert cloud_client.cloud_pipeline is None + assert len(async_get_pipelines(hass)) == 1 + + await cloud_client.create_cloud_assist_pipeline() + assert cloud_client.cloud_pipeline is not None + assert len(async_get_pipelines(hass)) == 2 + assert async_get_pipeline(hass, cloud_client.cloud_pipeline) == Pipeline( + conversation_engine="homeassistant", + conversation_language="en", + id=cloud_client.cloud_pipeline, + language="en", + name="Home Assistant Cloud", + stt_engine="cloud", + stt_language="en-US", + tts_engine="cloud", + tts_language="en-US", + tts_voice="JennyNeural", + ) diff --git a/tests/components/cloud/test_http_api.py b/tests/components/cloud/test_http_api.py index f748a6981ad..6dfca339182 100644 --- a/tests/components/cloud/test_http_api.py +++ b/tests/components/cloud/test_http_api.py @@ -105,7 +105,14 @@ async def test_google_actions_sync_fails( async def test_login_view(hass: HomeAssistant, cloud_client) -> None: """Test logging in.""" - hass.data["cloud"] = MagicMock(login=AsyncMock()) + create_cloud_assist_pipeline_mock = AsyncMock() + hass.data["cloud"] = MagicMock( + login=AsyncMock(), + client=Mock( + cloud_pipeline="12345", + create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock, + ), + ) req = await cloud_client.post( "/api/cloud/login", json={"email": "my_username", "password": "my_password"} @@ -113,7 +120,29 @@ async def test_login_view(hass: HomeAssistant, cloud_client) -> None: assert req.status == HTTPStatus.OK result = await req.json() - assert result == {"success": True} + assert result == {"success": True, "cloud_pipeline": "12345"} + create_cloud_assist_pipeline_mock.assert_not_awaited() + + +async def test_login_view_create_pipeline(hass: HomeAssistant, cloud_client) -> None: + """Test logging in when no assist pipeline is available.""" + create_cloud_assist_pipeline_mock = AsyncMock() + hass.data["cloud"] = MagicMock( + login=AsyncMock(), + client=Mock( + cloud_pipeline=None, + create_cloud_assist_pipeline=create_cloud_assist_pipeline_mock, + ), + ) + + req = await cloud_client.post( + "/api/cloud/login", json={"email": "my_username", "password": "my_password"} + ) + + assert req.status == HTTPStatus.OK + result = await req.json() + assert result == {"success": True, "cloud_pipeline": None} + create_cloud_assist_pipeline_mock.assert_awaited_once() async def test_login_view_random_exception(cloud_client) -> None: diff --git a/tests/components/cloud/test_init.py b/tests/components/cloud/test_init.py index b3ec96f08e6..9f6631f0cb9 100644 --- a/tests/components/cloud/test_init.py +++ b/tests/components/cloud/test_init.py @@ -2,6 +2,7 @@ from typing import Any from unittest.mock import patch +from hass_nabucasa import Cloud import pytest from homeassistant.components import cloud @@ -134,9 +135,9 @@ async def test_setup_existing_cloud_user( async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None: """Test cloud on connect triggers.""" - cl = hass.data["cloud"] + cl: Cloud = hass.data["cloud"] - assert len(cl.iot._on_connect) == 3 + assert len(cl.iot._on_connect) == 4 assert len(hass.states.async_entity_ids("binary_sensor")) == 0 @@ -152,6 +153,11 @@ async def test_on_connect(hass: HomeAssistant, mock_cloud_fixture) -> None: await cl.iot._on_connect[-1]() await hass.async_block_till_done() + assert len(hass.states.async_entity_ids("binary_sensor")) == 0 + + await cl.client.cloud_started() + await hass.async_block_till_done() + assert len(hass.states.async_entity_ids("binary_sensor")) == 1 with patch("homeassistant.helpers.discovery.async_load_platform") as mock_load: diff --git a/tests/components/stt/test_init.py b/tests/components/stt/test_init.py index 45983d08868..5a7e93a72a2 100644 --- a/tests/components/stt/test_init.py +++ b/tests/components/stt/test_init.py @@ -511,6 +511,12 @@ async def test_get_engine_legacy( TEST_DOMAIN, async_get_engine=AsyncMock(return_value=mock_provider), ) + mock_stt_platform( + hass, + tmp_path, + "cloud", + async_get_engine=AsyncMock(return_value=mock_provider), + ) assert await async_setup_component( hass, "stt", {"stt": [{"platform": TEST_DOMAIN}, {"platform": "cloud"}]} )