From 81514b0d1cb4c19f5eeef3b1e212f7339d3207a2 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Fri, 23 Sep 2022 20:55:29 +0200 Subject: [PATCH] Move MQTT debug_info to dataclass (#78788) * Add MQTT debug_info to dataclass * Remove total attr, assign factory * Rename typed dict to MqttDebugInfo and use helper * Split entity and trigger debug info * Refactor * More rework --- homeassistant/components/mqtt/__init__.py | 1 - homeassistant/components/mqtt/debug_info.py | 133 +++++++++++--------- homeassistant/components/mqtt/mixins.py | 1 + homeassistant/components/mqtt/models.py | 32 ++++- tests/components/mqtt/test_common.py | 6 +- 5 files changed, 109 insertions(+), 64 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 03e4093bb01..306132c4f36 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -170,7 +170,6 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: websocket_api.async_register_command(hass, websocket_subscribe) websocket_api.async_register_command(hass, websocket_mqtt_info) - debug_info.initialize(hass) if conf: conf = dict(conf) diff --git a/homeassistant/components/mqtt/debug_info.py b/homeassistant/components/mqtt/debug_info.py index 17dbc27f0c4..5fae98eaea5 100644 --- a/homeassistant/components/mqtt/debug_info.py +++ b/homeassistant/components/mqtt/debug_info.py @@ -11,29 +11,26 @@ import attr from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er +from homeassistant.helpers.typing import DiscoveryInfoType from homeassistant.util import dt as dt_util from .const import ATTR_DISCOVERY_PAYLOAD, ATTR_DISCOVERY_TOPIC from .models import MessageCallbackType, PublishPayloadType +from .util import get_mqtt_data -DATA_MQTT_DEBUG_INFO = "mqtt_debug_info" STORED_MESSAGES = 10 -def initialize(hass: HomeAssistant): - """Initialize MQTT debug info.""" - hass.data[DATA_MQTT_DEBUG_INFO] = {"entities": {}, "triggers": {}} - - def log_messages( hass: HomeAssistant, entity_id: str ) -> Callable[[MessageCallbackType], MessageCallbackType]: """Wrap an MQTT message callback to support message logging.""" + debug_info_entities = get_mqtt_data(hass).debug_info_entities + def _log_message(msg): """Log message.""" - debug_info = hass.data[DATA_MQTT_DEBUG_INFO] - messages = debug_info["entities"][entity_id]["subscriptions"][ + messages = debug_info_entities[entity_id]["subscriptions"][ msg.subscribed_topic ]["messages"] if msg not in messages: @@ -72,8 +69,7 @@ def log_message( retain: bool, ) -> None: """Log an outgoing MQTT message.""" - debug_info = hass.data[DATA_MQTT_DEBUG_INFO] - entity_info = debug_info["entities"].setdefault( + entity_info = get_mqtt_data(hass).debug_info_entities.setdefault( entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}} ) if topic not in entity_info["transmitted"]: @@ -86,11 +82,14 @@ def log_message( entity_info["transmitted"][topic]["messages"].append(msg) -def add_subscription(hass, message_callback, subscription): +def add_subscription( + hass: HomeAssistant, + message_callback: MessageCallbackType, + subscription: str, +) -> None: """Prepare debug data for subscription.""" if entity_id := getattr(message_callback, "__entity_id", None): - debug_info = hass.data[DATA_MQTT_DEBUG_INFO] - entity_info = debug_info["entities"].setdefault( + entity_info = get_mqtt_data(hass).debug_info_entities.setdefault( entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}} ) if subscription not in entity_info["subscriptions"]: @@ -101,65 +100,81 @@ def add_subscription(hass, message_callback, subscription): entity_info["subscriptions"][subscription]["count"] += 1 -def remove_subscription(hass, message_callback, subscription): +def remove_subscription( + hass: HomeAssistant, + message_callback: MessageCallbackType, + subscription: str, +) -> None: """Remove debug data for subscription if it exists.""" - entity_id = getattr(message_callback, "__entity_id", None) - if entity_id and entity_id in hass.data[DATA_MQTT_DEBUG_INFO]["entities"]: - hass.data[DATA_MQTT_DEBUG_INFO]["entities"][entity_id]["subscriptions"][ - subscription - ]["count"] -= 1 - if not hass.data[DATA_MQTT_DEBUG_INFO]["entities"][entity_id]["subscriptions"][ - subscription - ]["count"]: - hass.data[DATA_MQTT_DEBUG_INFO]["entities"][entity_id]["subscriptions"].pop( - subscription - ) + if (entity_id := getattr(message_callback, "__entity_id", None)) and entity_id in ( + debug_info_entities := get_mqtt_data(hass).debug_info_entities + ): + debug_info_entities[entity_id]["subscriptions"][subscription]["count"] -= 1 + if not debug_info_entities[entity_id]["subscriptions"][subscription]["count"]: + debug_info_entities[entity_id]["subscriptions"].pop(subscription) -def add_entity_discovery_data(hass, discovery_data, entity_id): +def add_entity_discovery_data( + hass: HomeAssistant, discovery_data: DiscoveryInfoType, entity_id: str +) -> None: """Add discovery data.""" - debug_info = hass.data[DATA_MQTT_DEBUG_INFO] - entity_info = debug_info["entities"].setdefault( + entity_info = get_mqtt_data(hass).debug_info_entities.setdefault( entity_id, {"subscriptions": {}, "discovery_data": {}, "transmitted": {}} ) entity_info["discovery_data"] = discovery_data -def update_entity_discovery_data(hass, discovery_payload, entity_id): +def update_entity_discovery_data( + hass: HomeAssistant, discovery_payload: DiscoveryInfoType, entity_id: str +) -> None: """Update discovery data.""" - entity_info = hass.data[DATA_MQTT_DEBUG_INFO]["entities"][entity_id] - entity_info["discovery_data"][ATTR_DISCOVERY_PAYLOAD] = discovery_payload + assert ( + discovery_data := get_mqtt_data(hass).debug_info_entities[entity_id][ + "discovery_data" + ] + ) is not None + discovery_data[ATTR_DISCOVERY_PAYLOAD] = discovery_payload -def remove_entity_data(hass, entity_id): +def remove_entity_data(hass: HomeAssistant, entity_id: str) -> None: """Remove discovery data.""" - if entity_id in hass.data[DATA_MQTT_DEBUG_INFO]["entities"]: - hass.data[DATA_MQTT_DEBUG_INFO]["entities"].pop(entity_id) + if entity_id in (debug_info_entities := get_mqtt_data(hass).debug_info_entities): + debug_info_entities.pop(entity_id) -def add_trigger_discovery_data(hass, discovery_hash, discovery_data, device_id): +def add_trigger_discovery_data( + hass: HomeAssistant, + discovery_hash: tuple[str, str], + discovery_data: DiscoveryInfoType, + device_id: str, +) -> None: """Add discovery data.""" - debug_info = hass.data[DATA_MQTT_DEBUG_INFO] - debug_info["triggers"][discovery_hash] = { + get_mqtt_data(hass).debug_info_triggers[discovery_hash] = { "device_id": device_id, "discovery_data": discovery_data, } -def update_trigger_discovery_data(hass, discovery_hash, discovery_payload): +def update_trigger_discovery_data( + hass: HomeAssistant, + discovery_hash: tuple[str, str], + discovery_payload: DiscoveryInfoType, +) -> None: """Update discovery data.""" - trigger_info = hass.data[DATA_MQTT_DEBUG_INFO]["triggers"][discovery_hash] - trigger_info["discovery_data"][ATTR_DISCOVERY_PAYLOAD] = discovery_payload + get_mqtt_data(hass).debug_info_triggers[discovery_hash]["discovery_data"][ + ATTR_DISCOVERY_PAYLOAD + ] = discovery_payload -def remove_trigger_discovery_data(hass, discovery_hash): +def remove_trigger_discovery_data( + hass: HomeAssistant, discovery_hash: tuple[str, str] +) -> None: """Remove discovery data.""" - hass.data[DATA_MQTT_DEBUG_INFO]["triggers"].pop(discovery_hash) + get_mqtt_data(hass).debug_info_triggers.pop(discovery_hash) def _info_for_entity(hass: HomeAssistant, entity_id: str) -> dict[str, Any]: - mqtt_debug_info = hass.data[DATA_MQTT_DEBUG_INFO] - entity_info = mqtt_debug_info["entities"][entity_id] + entity_info = get_mqtt_data(hass).debug_info_entities[entity_id] subscriptions = [ { "topic": topic, @@ -205,9 +220,10 @@ def _info_for_entity(hass: HomeAssistant, entity_id: str) -> dict[str, Any]: } -def _info_for_trigger(hass: HomeAssistant, trigger_key: str) -> dict[str, Any]: - mqtt_debug_info = hass.data[DATA_MQTT_DEBUG_INFO] - trigger = mqtt_debug_info["triggers"][trigger_key] +def _info_for_trigger( + hass: HomeAssistant, trigger_key: tuple[str, str] +) -> dict[str, Any]: + trigger = get_mqtt_data(hass).debug_info_triggers[trigger_key] discovery_data = None if trigger["discovery_data"] is not None: discovery_data = { @@ -217,36 +233,39 @@ def _info_for_trigger(hass: HomeAssistant, trigger_key: str) -> dict[str, Any]: return {"discovery_data": discovery_data, "trigger_key": trigger_key} -def info_for_config_entry(hass): +def info_for_config_entry(hass: HomeAssistant) -> dict[str, list[Any]]: """Get debug info for all entities and triggers.""" - mqtt_info = {"entities": [], "triggers": []} - mqtt_debug_info = hass.data[DATA_MQTT_DEBUG_INFO] - for entity_id in mqtt_debug_info["entities"]: + mqtt_data = get_mqtt_data(hass) + mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []} + + for entity_id in mqtt_data.debug_info_entities: mqtt_info["entities"].append(_info_for_entity(hass, entity_id)) - for trigger_key in mqtt_debug_info["triggers"]: + for trigger_key in mqtt_data.debug_info_triggers: mqtt_info["triggers"].append(_info_for_trigger(hass, trigger_key)) return mqtt_info -def info_for_device(hass, device_id): +def info_for_device(hass: HomeAssistant, device_id: str) -> dict[str, list[Any]]: """Get debug info for a device.""" - mqtt_info = {"entities": [], "triggers": []} + + mqtt_data = get_mqtt_data(hass) + + mqtt_info: dict[str, list[Any]] = {"entities": [], "triggers": []} entity_registry = er.async_get(hass) entries = er.async_entries_for_device( entity_registry, device_id, include_disabled_entities=True ) - mqtt_debug_info = hass.data[DATA_MQTT_DEBUG_INFO] for entry in entries: - if entry.entity_id not in mqtt_debug_info["entities"]: + if entry.entity_id not in mqtt_data.debug_info_entities: continue mqtt_info["entities"].append(_info_for_entity(hass, entry.entity_id)) - for trigger_key, trigger in mqtt_debug_info["triggers"].items(): + for trigger_key, trigger in mqtt_data.debug_info_triggers.items(): if trigger["device_id"] != device_id: continue diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 141d93666c5..8022a6e91ae 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -865,6 +865,7 @@ class MqttDiscoveryUpdate(Entity): send_discovery_done(self.hass, self._discovery_data) if discovery_hash: + assert self._discovery_data is not None debug_info.add_entity_discovery_data( self.hass, self._discovery_data, self.entity_id ) diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index 2cff89f93a1..566f18bc791 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -2,10 +2,11 @@ from __future__ import annotations from ast import literal_eval +from collections import deque from collections.abc import Callable, Coroutine from dataclasses import dataclass, field import datetime as dt -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, TypedDict, Union import attr @@ -14,10 +15,11 @@ from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import template from homeassistant.helpers.entity import Entity from homeassistant.helpers.service_info.mqtt import ReceivePayloadType -from homeassistant.helpers.typing import ConfigType, TemplateVarsType +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType if TYPE_CHECKING: from .client import MQTT, Subscription + from .debug_info import TimestampedPublishMessage from .device_trigger import Trigger _SENTINEL = object() @@ -53,6 +55,28 @@ AsyncMessageCallbackType = Callable[[ReceiveMessage], Coroutine[Any, Any, None]] MessageCallbackType = Callable[[ReceiveMessage], None] +class SubscriptionDebugInfo(TypedDict): + """Class for holding subscription debug info.""" + + messages: deque[ReceiveMessage] + count: int + + +class EntityDebugInfo(TypedDict): + """Class for holding entity based debug info.""" + + subscriptions: dict[str, SubscriptionDebugInfo] + discovery_data: DiscoveryInfoType + transmitted: dict[str, dict[str, deque[TimestampedPublishMessage]]] + + +class TriggerDebugInfo(TypedDict): + """Class for holding trigger based debug info.""" + + device_id: str + discovery_data: DiscoveryInfoType + + class MqttCommandTemplate: """Class for rendering MQTT payload with command templates.""" @@ -187,6 +211,10 @@ class MqttData: client: MQTT | None = None config: ConfigType | None = None + debug_info_entities: dict[str, EntityDebugInfo] = field(default_factory=dict) + debug_info_triggers: dict[tuple[str, str], TriggerDebugInfo] = field( + default_factory=dict + ) device_triggers: dict[str, Trigger] = field(default_factory=dict) discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field( default_factory=dict diff --git a/tests/components/mqtt/test_common.py b/tests/components/mqtt/test_common.py index e2411f9fc6c..0ac7e64d1bb 100644 --- a/tests/components/mqtt/test_common.py +++ b/tests/components/mqtt/test_common.py @@ -1391,7 +1391,7 @@ async def help_test_entity_debug_info_remove( debug_info_data = debug_info.info_for_device(hass, device.id) assert len(debug_info_data["entities"]) == 0 assert len(debug_info_data["triggers"]) == 0 - assert entity_id not in hass.data[debug_info.DATA_MQTT_DEBUG_INFO]["entities"] + assert entity_id not in hass.data["mqtt"].debug_info_entities async def help_test_entity_debug_info_update_entity_id( @@ -1449,9 +1449,7 @@ async def help_test_entity_debug_info_update_entity_id( "subscriptions" ] assert len(debug_info_data["triggers"]) == 0 - assert ( - f"{domain}.test" not in hass.data[debug_info.DATA_MQTT_DEBUG_INFO]["entities"] - ) + assert f"{domain}.test" not in hass.data["mqtt"].debug_info_entities async def help_test_entity_disabled_by_default(