Add overload for async singleton call with HassKey (#134059)

This commit is contained in:
Marc Mueller 2025-01-17 19:22:48 +01:00 committed by GitHub
parent 2ec971ad9d
commit abc256fb3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 11 deletions

View File

@ -12,6 +12,7 @@ from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.singleton import singleton from homeassistant.helpers.singleton import singleton
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.util.hass_dict import HassKey
from .const import DOMAIN from .const import DOMAIN
from .coordinator import ESPHomeDashboardCoordinator from .coordinator import ESPHomeDashboardCoordinator
@ -19,7 +20,9 @@ from .coordinator import ESPHomeDashboardCoordinator
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
KEY_DASHBOARD_MANAGER = "esphome_dashboard_manager" KEY_DASHBOARD_MANAGER: HassKey[ESPHomeDashboardManager] = HassKey(
"esphome_dashboard_manager"
)
STORAGE_KEY = "esphome.dashboard" STORAGE_KEY = "esphome.dashboard"
STORAGE_VERSION = 1 STORAGE_VERSION = 1
@ -33,7 +36,7 @@ async def async_setup(hass: HomeAssistant) -> None:
await async_get_or_create_dashboard_manager(hass) await async_get_or_create_dashboard_manager(hass)
@singleton(KEY_DASHBOARD_MANAGER) @singleton(KEY_DASHBOARD_MANAGER, async_=True)
async def async_get_or_create_dashboard_manager( async def async_get_or_create_dashboard_manager(
hass: HomeAssistant, hass: HomeAssistant,
) -> ESPHomeDashboardManager: ) -> ESPHomeDashboardManager:
@ -140,7 +143,7 @@ def async_get_dashboard(hass: HomeAssistant) -> ESPHomeDashboardCoordinator | No
where manager can be an asyncio.Event instead of the actual manager where manager can be an asyncio.Event instead of the actual manager
because the singleton decorator is not yet done. because the singleton decorator is not yet done.
""" """
manager: ESPHomeDashboardManager | None = hass.data.get(KEY_DASHBOARD_MANAGER) manager = hass.data.get(KEY_DASHBOARD_MANAGER)
return manager.async_get() if manager else None return manager.async_get() if manager else None

View File

@ -3,15 +3,22 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Callable, Coroutine
import functools import functools
from typing import Any, cast, overload from typing import Any, Literal, assert_type, cast, overload
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.hass_dict import HassKey from homeassistant.util.hass_dict import HassKey
type _FuncType[_T] = Callable[[HomeAssistant], _T] type _FuncType[_T] = Callable[[HomeAssistant], _T]
type _Coro[_T] = Coroutine[Any, Any, _T]
@overload
def singleton[_T](
data_key: HassKey[_T], *, async_: Literal[True]
) -> Callable[[_FuncType[_Coro[_T]]], _FuncType[_Coro[_T]]]: ...
@overload @overload
@ -24,29 +31,37 @@ def singleton[_T](
def singleton[_T](data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ... def singleton[_T](data_key: str) -> Callable[[_FuncType[_T]], _FuncType[_T]]: ...
def singleton[_T](data_key: Any) -> Callable[[_FuncType[_T]], _FuncType[_T]]: def singleton[_S, _T, _U](
data_key: Any, *, async_: bool = False
) -> Callable[[_FuncType[_S]], _FuncType[_S]]:
"""Decorate a function that should be called once per instance. """Decorate a function that should be called once per instance.
Result will be cached and simultaneous calls will be handled. Result will be cached and simultaneous calls will be handled.
""" """
def wrapper(func: _FuncType[_T]) -> _FuncType[_T]: @overload
def wrapper(func: _FuncType[_Coro[_T]]) -> _FuncType[_Coro[_T]]: ...
@overload
def wrapper(func: _FuncType[_U]) -> _FuncType[_U]: ...
def wrapper(func: _FuncType[_Coro[_T] | _U]) -> _FuncType[_Coro[_T] | _U]:
"""Wrap a function with caching logic.""" """Wrap a function with caching logic."""
if not asyncio.iscoroutinefunction(func): if not asyncio.iscoroutinefunction(func):
@functools.lru_cache(maxsize=1) @functools.lru_cache(maxsize=1)
@bind_hass @bind_hass
@functools.wraps(func) @functools.wraps(func)
def wrapped(hass: HomeAssistant) -> _T: def wrapped(hass: HomeAssistant) -> _U:
if data_key not in hass.data: if data_key not in hass.data:
hass.data[data_key] = func(hass) hass.data[data_key] = func(hass)
return cast(_T, hass.data[data_key]) return cast(_U, hass.data[data_key])
return wrapped return wrapped
@bind_hass @bind_hass
@functools.wraps(func) @functools.wraps(func)
async def async_wrapped(hass: HomeAssistant) -> Any: async def async_wrapped(hass: HomeAssistant) -> _T:
if data_key not in hass.data: if data_key not in hass.data:
evt = hass.data[data_key] = asyncio.Event() evt = hass.data[data_key] = asyncio.Event()
result = await func(hass) result = await func(hass)
@ -62,6 +77,45 @@ def singleton[_T](data_key: Any) -> Callable[[_FuncType[_T]], _FuncType[_T]]:
return cast(_T, obj_or_evt) return cast(_T, obj_or_evt)
return async_wrapped # type: ignore[return-value] return async_wrapped
return wrapper return wrapper
async def _test_singleton_typing(hass: HomeAssistant) -> None:
"""Test singleton overloads work as intended.
This is tested during the mypy run. Do not move it to 'tests'!
"""
# Test HassKey
key = HassKey[int]("key")
@singleton(key)
def func(hass: HomeAssistant) -> int:
return 2
@singleton(key, async_=True)
async def async_func(hass: HomeAssistant) -> int:
return 2
assert_type(func(hass), int)
assert_type(await async_func(hass), int)
# Test invalid use of 'async_' with sync function
@singleton(key, async_=True) # type: ignore[arg-type]
def func_error(hass: HomeAssistant) -> int:
return 2
# Test string key
other_key = "key"
@singleton(other_key)
def func2(hass: HomeAssistant) -> str:
return ""
@singleton(other_key)
async def async_func2(hass: HomeAssistant) -> str:
return ""
assert_type(func2(hass), str)
assert_type(await async_func2(hass), str)