Fix threading in get_test_home_assistant test helper (#124056)

This commit is contained in:
Erik Montnemery 2024-08-16 16:59:33 +02:00 committed by GitHub
parent 06209dd94c
commit 115c5d1704
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -13,7 +13,12 @@ from collections.abc import (
Mapping, Mapping,
Sequence, Sequence,
) )
from contextlib import asynccontextmanager, contextmanager, suppress from contextlib import (
AbstractAsyncContextManager,
asynccontextmanager,
contextmanager,
suppress,
)
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from enum import Enum from enum import Enum
import functools as ft import functools as ft
@ -177,24 +182,36 @@ def get_test_config_dir(*add_path):
@contextmanager @contextmanager
def get_test_home_assistant() -> Generator[HomeAssistant]: def get_test_home_assistant() -> Generator[HomeAssistant]:
"""Return a Home Assistant object pointing at test config directory.""" """Return a Home Assistant object pointing at test config directory."""
loop = asyncio.new_event_loop() hass_created_event = threading.Event()
asyncio.set_event_loop(loop)
context_manager = async_test_home_assistant(loop)
hass = loop.run_until_complete(context_manager.__aenter__())
loop_stop_event = threading.Event() loop_stop_event = threading.Event()
context_manager: AbstractAsyncContextManager = None
hass: HomeAssistant = None
loop: asyncio.AbstractEventLoop = None
orig_stop: Callable = None
def run_loop() -> None: def run_loop() -> None:
"""Run event loop.""" """Create and run event loop."""
nonlocal context_manager, hass, loop, orig_stop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
context_manager = async_test_home_assistant(loop)
hass = loop.run_until_complete(context_manager.__aenter__())
orig_stop = hass.stop
hass._stopped = Mock(set=loop.stop)
hass.start = start_hass
hass.stop = stop_hass
loop._thread_ident = threading.get_ident() loop._thread_ident = threading.get_ident()
hass_created_event.set()
hass.loop_thread_id = loop._thread_ident hass.loop_thread_id = loop._thread_ident
loop.run_forever() loop.run_forever()
loop_stop_event.set() loop_stop_event.set()
orig_stop = hass.stop
hass._stopped = Mock(set=loop.stop)
def start_hass(*mocks: Any) -> None: def start_hass(*mocks: Any) -> None:
"""Start hass.""" """Start hass."""
asyncio.run_coroutine_threadsafe(hass.async_start(), loop).result() asyncio.run_coroutine_threadsafe(hass.async_start(), loop).result()
@ -204,11 +221,10 @@ def get_test_home_assistant() -> Generator[HomeAssistant]:
orig_stop() orig_stop()
loop_stop_event.wait() loop_stop_event.wait()
hass.start = start_hass
hass.stop = stop_hass
threading.Thread(name="LoopThread", target=run_loop, daemon=False).start() threading.Thread(name="LoopThread", target=run_loop, daemon=False).start()
hass_created_event.wait()
try: try:
yield hass yield hass
finally: finally: