Introduce a singleton decorator (#34803)

This commit is contained in:
Paulus Schoutsen 2020-04-30 16:47:14 -07:00 committed by GitHub
parent 76f392476b
commit 6056753a9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 60 additions and 44 deletions

View File

@ -83,8 +83,8 @@ if TYPE_CHECKING:
block_async_io.enable() block_async_io.enable()
fix_threading_exception_logging() fix_threading_exception_logging()
# pylint: disable=invalid-name
T = TypeVar("T") T = TypeVar("T")
# pylint: disable=invalid-name
CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable) CALLABLE_T = TypeVar("CALLABLE_T", bound=Callable)
CALLBACK_TYPE = Callable[[], None] CALLBACK_TYPE = Callable[[], None]
# pylint: enable=invalid-name # pylint: enable=invalid-name

View File

@ -1,8 +1,7 @@
"""Provide a way to connect entities belonging to one device.""" """Provide a way to connect entities belonging to one device."""
from asyncio import Event
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from typing import Any, Dict, List, Optional, cast from typing import Any, Dict, List, Optional
import uuid import uuid
import attr import attr
@ -10,6 +9,7 @@ import attr
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from .singleton import singleton
from .typing import HomeAssistantType from .typing import HomeAssistantType
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs # mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
@ -356,26 +356,12 @@ class DeviceRegistry:
@bind_hass @bind_hass
@singleton(DATA_REGISTRY)
async def async_get_registry(hass: HomeAssistantType) -> DeviceRegistry: async def async_get_registry(hass: HomeAssistantType) -> DeviceRegistry:
"""Return device registry instance.""" """Create entity registry."""
reg_or_evt = hass.data.get(DATA_REGISTRY) reg = DeviceRegistry(hass)
await reg.async_load()
if not reg_or_evt: return reg
evt = hass.data[DATA_REGISTRY] = Event()
reg = DeviceRegistry(hass)
await reg.async_load()
hass.data[DATA_REGISTRY] = reg
evt.set()
return reg
if isinstance(reg_or_evt, Event):
evt = reg_or_evt
await evt.wait()
return cast(DeviceRegistry, hass.data.get(DATA_REGISTRY))
return cast(DeviceRegistry, reg_or_evt)
@callback @callback

View File

@ -7,7 +7,6 @@ The Entity Registry will persist itself 10 seconds after a new entity is
registered. Registering a new entity while a timer is in progress resets the registered. Registering a new entity while a timer is in progress resets the
timer. timer.
""" """
import asyncio
from collections import OrderedDict from collections import OrderedDict
import logging import logging
from typing import ( from typing import (
@ -39,6 +38,7 @@ from homeassistant.loader import bind_hass
from homeassistant.util import slugify from homeassistant.util import slugify
from homeassistant.util.yaml import load_yaml from homeassistant.util.yaml import load_yaml
from .singleton import singleton
from .typing import HomeAssistantType from .typing import HomeAssistantType
if TYPE_CHECKING: if TYPE_CHECKING:
@ -492,26 +492,12 @@ class EntityRegistry:
@bind_hass @bind_hass
@singleton(DATA_REGISTRY)
async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry: async def async_get_registry(hass: HomeAssistantType) -> EntityRegistry:
"""Return entity registry instance.""" """Create entity registry."""
reg_or_evt = hass.data.get(DATA_REGISTRY) reg = EntityRegistry(hass)
await reg.async_load()
if not reg_or_evt: return reg
evt = hass.data[DATA_REGISTRY] = asyncio.Event()
reg = EntityRegistry(hass)
await reg.async_load()
hass.data[DATA_REGISTRY] = reg
evt.set()
return reg
if isinstance(reg_or_evt, asyncio.Event):
evt = reg_or_evt
await evt.wait()
return cast(EntityRegistry, hass.data.get(DATA_REGISTRY))
return cast(EntityRegistry, reg_or_evt)
@callback @callback
@ -621,4 +607,4 @@ async def async_migrate_entries(
updates = entry_callback(entry) updates = entry_callback(entry)
if updates is not None: if updates is not None:
ent_reg.async_update_entity(entry.entity_id, **updates) # type: ignore ent_reg.async_update_entity(entry.entity_id, **updates)

View File

@ -0,0 +1,44 @@
"""Helper to help coordinating calls."""
import asyncio
import functools
from typing import Awaitable, Callable, TypeVar, cast
from homeassistant.core import HomeAssistant
T = TypeVar("T")
FUNC = Callable[[HomeAssistant], Awaitable[T]]
def singleton(data_key: str) -> Callable[[FUNC], FUNC]:
"""Decorate a function that should be called once per instance.
Result will be cached and simultaneous calls will be handled.
"""
def wrapper(func: FUNC) -> FUNC:
"""Wrap a function with caching logic."""
@functools.wraps(func)
async def wrapped(hass: HomeAssistant) -> T:
obj_or_evt = hass.data.get(data_key)
if not obj_or_evt:
evt = hass.data[data_key] = asyncio.Event()
result = await func(hass)
hass.data[data_key] = result
evt.set()
return cast(T, result)
if isinstance(obj_or_evt, asyncio.Event):
evt = obj_or_evt
await evt.wait()
return cast(T, hass.data.get(data_key))
return cast(T, obj_or_evt)
return wrapped
return wrapper

View File

@ -8,7 +8,7 @@ persistent=no
extension-pkg-whitelist=ciso8601 extension-pkg-whitelist=ciso8601
[BASIC] [BASIC]
good-names=id,i,j,k,ex,Run,_,fp good-names=id,i,j,k,ex,Run,_,fp,T
[MESSAGES CONTROL] [MESSAGES CONTROL]
# Reasons disabled: # Reasons disabled: