From eb81a4204e1ec32d15519524f93d9dc726298c3e Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Tue, 26 Mar 2024 10:38:29 +0100 Subject: [PATCH] Allow string formatting for dispatcher SignalType (#114174) --- homeassistant/components/mqtt/discovery.py | 11 ++++++++--- homeassistant/helpers/discovery.py | 10 ++++++++-- homeassistant/helpers/dispatcher.py | 18 ++++++++++++++++-- tests/common.py | 5 ++++- tests/helpers/test_dispatcher.py | 22 ++++++++++++++++++++++ 5 files changed, 58 insertions(+), 8 deletions(-) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 0e1b0a799d6..f50a4e4a3f7 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -18,6 +18,7 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.data_entry_flow import FlowResultType import homeassistant.helpers.config_validation as cv from homeassistant.helpers.dispatcher import ( + SignalTypeFormat, async_dispatcher_connect, async_dispatcher_send, ) @@ -79,10 +80,14 @@ SUPPORTED_COMPONENTS = { "water_heater", } -MQTT_DISCOVERY_UPDATED = "mqtt_discovery_updated_{}" -MQTT_DISCOVERY_NEW = "mqtt_discovery_new_{}_{}" +MQTT_DISCOVERY_UPDATED: SignalTypeFormat[MQTTDiscoveryPayload] = SignalTypeFormat( + "mqtt_discovery_updated_{}" +) +MQTT_DISCOVERY_NEW: SignalTypeFormat[MQTTDiscoveryPayload] = SignalTypeFormat( + "mqtt_discovery_new_{}_{}" +) MQTT_DISCOVERY_NEW_COMPONENT = "mqtt_discovery_new_component" -MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}" +MQTT_DISCOVERY_DONE: SignalTypeFormat[Any] = SignalTypeFormat("mqtt_discovery_done_{}") TOPIC_BASE = "~" diff --git a/homeassistant/helpers/discovery.py b/homeassistant/helpers/discovery.py index 98419ae6bf2..0bad52dff08 100644 --- a/homeassistant/helpers/discovery.py +++ b/homeassistant/helpers/discovery.py @@ -15,10 +15,16 @@ from homeassistant import core, setup from homeassistant.const import Platform from homeassistant.loader import bind_hass -from .dispatcher import async_dispatcher_connect, async_dispatcher_send +from .dispatcher import ( + SignalTypeFormat, + async_dispatcher_connect, + async_dispatcher_send, +) from .typing import ConfigType, DiscoveryInfoType -SIGNAL_PLATFORM_DISCOVERED = "discovery.platform_discovered_{}" +SIGNAL_PLATFORM_DISCOVERED: SignalTypeFormat[DiscoveryDict] = SignalTypeFormat( + "discovery.platform_discovered_{}" +) EVENT_LOAD_PLATFORM = "load_platform.{}" ATTR_PLATFORM = "platform" ATTR_DISCOVERED = "discovered" diff --git a/homeassistant/helpers/dispatcher.py b/homeassistant/helpers/dispatcher.py index 22d9c3bbab8..4633e81c78b 100644 --- a/homeassistant/helpers/dispatcher.py +++ b/homeassistant/helpers/dispatcher.py @@ -20,8 +20,8 @@ DATA_DISPATCHER = "dispatcher" @dataclass(frozen=True) -class SignalType(Generic[*_Ts]): - """Generic string class for signal to improve typing.""" +class _SignalTypeBase(Generic[*_Ts]): + """Generic base class for SignalType.""" name: str @@ -40,6 +40,20 @@ class SignalType(Generic[*_Ts]): return False +@dataclass(frozen=True, eq=False) +class SignalType(_SignalTypeBase[*_Ts]): + """Generic string class for signal to improve typing.""" + + +@dataclass(frozen=True, eq=False) +class SignalTypeFormat(_SignalTypeBase[*_Ts]): + """Generic string class for signal. Requires call to 'format' before use.""" + + def format(self, *args: Any, **kwargs: Any) -> SignalType[*_Ts]: + """Format name and return new SignalType instance.""" + return SignalType(self.name.format(*args, **kwargs)) + + _DispatcherDataType = dict[ SignalType[*_Ts] | str, dict[ diff --git a/tests/common.py b/tests/common.py index 5743b26ef62..c0733a7642b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -76,6 +76,7 @@ from homeassistant.helpers import ( translation, ) from homeassistant.helpers.dispatcher import ( + SignalType, async_dispatcher_connect, async_dispatcher_send, ) @@ -1497,7 +1498,9 @@ def async_capture_events(hass: HomeAssistant, event_name: str) -> list[Event]: @callback -def async_mock_signal(hass: HomeAssistant, signal: str) -> list[tuple[Any]]: +def async_mock_signal( + hass: HomeAssistant, signal: SignalType[Any] | str +) -> list[tuple[Any]]: """Catch all dispatches to a signal.""" calls = [] diff --git a/tests/helpers/test_dispatcher.py b/tests/helpers/test_dispatcher.py index 3dd708906b9..1e1abe6e154 100644 --- a/tests/helpers/test_dispatcher.py +++ b/tests/helpers/test_dispatcher.py @@ -7,6 +7,7 @@ import pytest from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.dispatcher import ( SignalType, + SignalTypeFormat, async_dispatcher_connect, async_dispatcher_send, ) @@ -58,6 +59,27 @@ async def test_signal_type(hass: HomeAssistant) -> None: assert calls == [("Hello", 2), ("World", 3), ("x", 4)] +async def test_signal_type_format(hass: HomeAssistant) -> None: + """Test dispatcher with SignalType and format.""" + signal: SignalTypeFormat[str, int] = SignalTypeFormat("test-{}") + calls: list[tuple[str, int]] = [] + + def test_funct(data1: str, data2: int) -> None: + calls.append((data1, data2)) + + async_dispatcher_connect(hass, signal.format("unique-id"), test_funct) + async_dispatcher_send(hass, signal.format("unique-id"), "Hello", 2) + await hass.async_block_till_done() + + assert calls == [("Hello", 2)] + + # Test compatibility with string keys + async_dispatcher_send(hass, "test-{}".format("unique-id"), "x", 4) + await hass.async_block_till_done() + + assert calls == [("Hello", 2), ("x", 4)] + + async def test_simple_function_unsub(hass: HomeAssistant) -> None: """Test simple function (executor) and unsub.""" calls1 = []