diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index 3f8df691573..37ab3123919 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -1,9 +1,10 @@ """esphome session fixtures.""" from __future__ import annotations +from asyncio import Event from unittest.mock import AsyncMock, Mock, patch -from aioesphomeapi import APIClient, APIVersion, DeviceInfo +from aioesphomeapi import APIClient, APIVersion, DeviceInfo, ReconnectLogic import pytest from zeroconf import Zeroconf @@ -160,10 +161,18 @@ async def mock_voice_assistant_entry( mock_client.device_info = AsyncMock(return_value=device_info) mock_client.subscribe_voice_assistant = AsyncMock(return_value=Mock()) - await hass.config_entries.async_setup(entry.entry_id) - await hass.async_block_till_done() - await hass.async_block_till_done() - await hass.async_block_till_done() + try_connect_done = Event() + real_try_connect = ReconnectLogic._try_connect + + async def mock_try_connect(self): + """Set an event when ReconnectLogic._try_connect has been awaited.""" + result = await real_try_connect(self) + try_connect_done.set() + return result + + with patch.object(ReconnectLogic, "_try_connect", mock_try_connect): + await hass.config_entries.async_setup(entry.entry_id) + await try_connect_done.wait() return entry