diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index c14266e296f..03e4093bb01 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -76,7 +76,6 @@ from .const import ( # noqa: F401 PLATFORMS, RELOADABLE_PLATFORMS, ) -from .mixins import MqttData from .models import ( # noqa: F401 MqttCommandTemplate, MqttValueTemplate, @@ -86,6 +85,7 @@ from .models import ( # noqa: F401 ) from .util import ( _VALID_QOS_SCHEMA, + get_mqtt_data, mqtt_config_entry_enabled, valid_publish_topic, valid_subscribe_topic, @@ -164,7 +164,7 @@ async def _async_setup_discovery( async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Start the MQTT protocol service.""" - mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + mqtt_data = get_mqtt_data(hass, True) conf: ConfigType | None = config.get(DOMAIN) @@ -249,7 +249,7 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) - Causes for this is config entry options changing. """ - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) assert (client := mqtt_data.client) is not None if (conf := mqtt_data.config) is None: @@ -267,7 +267,7 @@ async def _async_config_entry_updated(hass: HomeAssistant, entry: ConfigEntry) - async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | None: """Fetch fresh MQTT yaml config from the hass config when (re)loading the entry.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) if mqtt_data.reload_entry: hass_config = await conf_util.async_hass_config_yaml(hass) mqtt_data.config = CONFIG_SCHEMA_BASE(hass_config.get(DOMAIN, {})) @@ -307,7 +307,7 @@ async def async_fetch_config(hass: HomeAssistant, entry: ConfigEntry) -> dict | async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Load a config entry.""" - mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + mqtt_data = get_mqtt_data(hass, True) # Merge basic configuration, and add missing defaults for basic options if (conf := await async_fetch_config(hass, entry)) is None: @@ -593,7 +593,7 @@ def async_subscribe_connection_status( def is_connected(hass: HomeAssistant) -> bool: """Return if MQTT client is connected.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) assert mqtt_data.client is not None return mqtt_data.client.connected @@ -611,7 +611,7 @@ async def async_remove_config_entry_device( async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload MQTT dump and publish service when the config entry is unloaded.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) assert mqtt_data.client is not None mqtt_client = mqtt_data.client diff --git a/homeassistant/components/mqtt/client.py b/homeassistant/components/mqtt/client.py index 28887818133..7ede1e50494 100644 --- a/homeassistant/components/mqtt/client.py +++ b/homeassistant/components/mqtt/client.py @@ -46,7 +46,6 @@ from .const import ( CONF_KEEPALIVE, CONF_TLS_INSECURE, CONF_WILL_MESSAGE, - DATA_MQTT, DEFAULT_ENCODING, DEFAULT_QOS, MQTT_CONNECTED, @@ -61,15 +60,13 @@ from .models import ( ReceiveMessage, ReceivePayloadType, ) -from .util import mqtt_config_entry_enabled +from .util import get_mqtt_data, mqtt_config_entry_enabled if TYPE_CHECKING: # Only import for paho-mqtt type checking here, imports are done locally # because integrations should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt - from .mixins import MqttData - _LOGGER = logging.getLogger(__name__) @@ -100,11 +97,7 @@ async def async_publish( encoding: str | None = DEFAULT_ENCODING, ) -> None: """Publish message to a MQTT topic.""" - # Local import to avoid circular dependencies - # pylint: disable-next=import-outside-toplevel - from .mixins import MqttData - - mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + mqtt_data = get_mqtt_data(hass, True) if mqtt_data.client is None or not mqtt_config_entry_enabled(hass): raise HomeAssistantError( f"Cannot publish to topic '{topic}', MQTT is not enabled" @@ -190,11 +183,7 @@ async def async_subscribe( Call the return value to unsubscribe. """ - # Local import to avoid circular dependencies - # pylint: disable-next=import-outside-toplevel - from .mixins import MqttData - - mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + mqtt_data = get_mqtt_data(hass, True) if mqtt_data.client is None or not mqtt_config_entry_enabled(hass): raise HomeAssistantError( f"Cannot subscribe to topic '{topic}', MQTT is not enabled" @@ -332,7 +321,7 @@ class MQTT: # should be able to optionally rely on MQTT. import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel - self._mqtt_data: MqttData = hass.data[DATA_MQTT] + self._mqtt_data = get_mqtt_data(hass) self.hass = hass self.config_entry = config_entry diff --git a/homeassistant/components/mqtt/config_flow.py b/homeassistant/components/mqtt/config_flow.py index 12d97b41a74..afa2d98af2b 100644 --- a/homeassistant/components/mqtt/config_flow.py +++ b/homeassistant/components/mqtt/config_flow.py @@ -30,14 +30,12 @@ from .const import ( CONF_BIRTH_MESSAGE, CONF_BROKER, CONF_WILL_MESSAGE, - DATA_MQTT, DEFAULT_BIRTH, DEFAULT_DISCOVERY, DEFAULT_WILL, DOMAIN, ) -from .mixins import MqttData -from .util import MQTT_WILL_BIRTH_SCHEMA +from .util import MQTT_WILL_BIRTH_SCHEMA, get_mqtt_data MQTT_TIMEOUT = 5 @@ -165,7 +163,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Manage the MQTT broker configuration.""" - mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData()) + mqtt_data = get_mqtt_data(self.hass, True) errors = {} current_config = self.config_entry.data yaml_config = mqtt_data.config or {} @@ -216,7 +214,7 @@ class MQTTOptionsFlowHandler(config_entries.OptionsFlow): self, user_input: dict[str, Any] | None = None ) -> FlowResult: """Manage the MQTT options.""" - mqtt_data: MqttData = self.hass.data.setdefault(DATA_MQTT, MqttData()) + mqtt_data = get_mqtt_data(self.hass, True) errors = {} current_config = self.config_entry.data yaml_config = mqtt_data.config or {} @@ -351,7 +349,7 @@ def try_connection( import paho.mqtt.client as mqtt # pylint: disable=import-outside-toplevel # Get the config from configuration.yaml - mqtt_data: MqttData = hass.data.setdefault(DATA_MQTT, MqttData()) + mqtt_data = get_mqtt_data(hass, True) yaml_config = mqtt_data.config or {} entry_config = { CONF_BROKER: broker, diff --git a/homeassistant/components/mqtt/device_trigger.py b/homeassistant/components/mqtt/device_trigger.py index 7e37ed72821..f51731284cc 100644 --- a/homeassistant/components/mqtt/device_trigger.py +++ b/homeassistant/components/mqtt/device_trigger.py @@ -33,17 +33,16 @@ from .const import ( CONF_PAYLOAD, CONF_QOS, CONF_TOPIC, - DATA_MQTT, DOMAIN, ) from .discovery import MQTT_DISCOVERY_DONE from .mixins import ( MQTT_ENTITY_DEVICE_INFO_SCHEMA, - MqttData, MqttDiscoveryDeviceUpdate, send_discovery_done, update_device, ) +from .util import get_mqtt_data _LOGGER = logging.getLogger(__name__) @@ -203,7 +202,7 @@ class MqttDeviceTrigger(MqttDiscoveryDeviceUpdate): self.device_id = device_id self.discovery_data = discovery_data self.hass = hass - self._mqtt_data: MqttData = hass.data[DATA_MQTT] + self._mqtt_data = get_mqtt_data(hass) MqttDiscoveryDeviceUpdate.__init__( self, @@ -281,7 +280,7 @@ async def async_setup_trigger( async def async_removed_from_device(hass: HomeAssistant, device_id: str) -> None: """Handle Mqtt removed from a device.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) triggers = await async_get_triggers(hass, device_id) for trig in triggers: device_trigger: Trigger = mqtt_data.device_triggers.pop(trig[CONF_DISCOVERY_ID]) @@ -296,7 +295,7 @@ async def async_get_triggers( hass: HomeAssistant, device_id: str ) -> list[dict[str, str]]: """List device triggers for MQTT devices.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) triggers: list[dict[str, str]] = [] if not mqtt_data.device_triggers: @@ -325,7 +324,7 @@ async def async_attach_trigger( trigger_info: TriggerInfo, ) -> CALLBACK_TYPE: """Attach a trigger.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) device_id = config[CONF_DEVICE_ID] discovery_id = config[CONF_DISCOVERY_ID] diff --git a/homeassistant/components/mqtt/diagnostics.py b/homeassistant/components/mqtt/diagnostics.py index 2a6322cac63..173c583ca6a 100644 --- a/homeassistant/components/mqtt/diagnostics.py +++ b/homeassistant/components/mqtt/diagnostics.py @@ -16,7 +16,8 @@ from homeassistant.core import HomeAssistant, callback, split_entity_id from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.device_registry import DeviceEntry -from . import DATA_MQTT, MQTT, debug_info, is_connected +from . import debug_info, is_connected +from .util import get_mqtt_data REDACT_CONFIG = {CONF_PASSWORD, CONF_USERNAME} REDACT_STATE_DEVICE_TRACKER = {ATTR_LATITUDE, ATTR_LONGITUDE} @@ -43,7 +44,8 @@ def _async_get_diagnostics( device: DeviceEntry | None = None, ) -> dict[str, Any]: """Return diagnostics for a config entry.""" - mqtt_instance: MQTT = hass.data[DATA_MQTT].client + mqtt_instance = get_mqtt_data(hass).client + assert mqtt_instance is not None redacted_config = async_redact_data(mqtt_instance.conf, REDACT_CONFIG) diff --git a/homeassistant/components/mqtt/discovery.py b/homeassistant/components/mqtt/discovery.py index 65051ce54fc..ee0d0a1ac9a 100644 --- a/homeassistant/components/mqtt/discovery.py +++ b/homeassistant/components/mqtt/discovery.py @@ -7,7 +7,6 @@ import functools import logging import re import time -from typing import TYPE_CHECKING from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.core import HomeAssistant @@ -29,12 +28,9 @@ from .const import ( ATTR_DISCOVERY_TOPIC, CONF_AVAILABILITY, CONF_TOPIC, - DATA_MQTT, DOMAIN, ) - -if TYPE_CHECKING: - from .mixins import MqttData +from .util import get_mqtt_data _LOGGER = logging.getLogger(__name__) @@ -98,7 +94,7 @@ async def async_start( # noqa: C901 hass: HomeAssistant, discovery_topic, config_entry=None ) -> None: """Start MQTT Discovery.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) mqtt_integrations = {} async def async_discovery_message_received(msg): diff --git a/homeassistant/components/mqtt/mixins.py b/homeassistant/components/mqtt/mixins.py index 477be399e26..141d93666c5 100644 --- a/homeassistant/components/mqtt/mixins.py +++ b/homeassistant/components/mqtt/mixins.py @@ -4,10 +4,9 @@ from __future__ import annotations from abc import abstractmethod import asyncio from collections.abc import Callable, Coroutine -from dataclasses import dataclass, field from functools import partial import logging -from typing import TYPE_CHECKING, Any, Protocol, cast, final +from typing import Any, Protocol, cast, final import voluptuous as vol @@ -29,13 +28,7 @@ from homeassistant.const import ( CONF_UNIQUE_ID, CONF_VALUE_TEMPLATE, ) -from homeassistant.core import ( - CALLBACK_TYPE, - Event, - HomeAssistant, - async_get_hass, - callback, -) +from homeassistant.core import Event, HomeAssistant, async_get_hass, callback from homeassistant.helpers import ( config_validation as cv, device_registry as dr, @@ -60,7 +53,7 @@ from homeassistant.helpers.json import json_loads from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import debug_info, subscription -from .client import MQTT, Subscription, async_publish +from .client import async_publish from .const import ( ATTR_DISCOVERY_HASH, ATTR_DISCOVERY_PAYLOAD, @@ -69,7 +62,6 @@ from .const import ( CONF_ENCODING, CONF_QOS, CONF_TOPIC, - DATA_MQTT, DEFAULT_ENCODING, DEFAULT_PAYLOAD_AVAILABLE, DEFAULT_PAYLOAD_NOT_AVAILABLE, @@ -91,10 +83,7 @@ from .subscription import ( async_subscribe_topics, async_unsubscribe_topics, ) -from .util import mqtt_config_entry_enabled, valid_subscribe_topic - -if TYPE_CHECKING: - from .device_trigger import Trigger +from .util import get_mqtt_data, mqtt_config_entry_enabled, valid_subscribe_topic _LOGGER = logging.getLogger(__name__) @@ -272,27 +261,6 @@ def warn_for_legacy_schema(domain: str) -> Callable: return validator -@dataclass -class MqttData: - """Keep the MQTT entry data.""" - - client: MQTT | None = None - config: ConfigType | None = None - device_triggers: dict[str, Trigger] = field(default_factory=dict) - discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field( - default_factory=dict - ) - last_discovery: float = 0.0 - reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list) - reload_entry: bool = False - reload_handlers: dict[str, Callable[[], Coroutine[Any, Any, None]]] = field( - default_factory=dict - ) - reload_needed: bool = False - subscriptions_to_restore: list[Subscription] = field(default_factory=list) - updated_config: ConfigType = field(default_factory=dict) - - class SetupEntity(Protocol): """Protocol type for async_setup_entities.""" @@ -313,8 +281,7 @@ async def async_get_platform_config_from_yaml( config_yaml: ConfigType | None = None, ) -> list[ConfigType]: """Return a list of validated configurations for the domain.""" - - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) if config_yaml is None: config_yaml = mqtt_data.config if not config_yaml: @@ -331,7 +298,7 @@ async def async_setup_entry_helper( discovery_schema: vol.Schema, ) -> None: """Set up entity, automation or tag creation dynamically through MQTT discovery.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) async def async_discover(discovery_payload): """Discover and add an MQTT entity, automation or tag.""" @@ -363,7 +330,7 @@ async def async_setup_entry_helper( async def _async_setup_entities() -> None: """Set up MQTT items from configuration.yaml.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) if mqtt_data.updated_config: # The platform has been reloaded config_yaml = mqtt_data.updated_config @@ -395,7 +362,7 @@ async def async_setup_platform_helper( async_setup_entities: SetupEntity, ) -> None: """Help to set up the platform for manual configured MQTT entities.""" - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) if mqtt_data.reload_entry: _LOGGER.debug( "MQTT integration is %s, skipping setup of manually configured MQTT items while unloading the config entry", @@ -621,7 +588,7 @@ class MqttAvailability(Entity): @property def available(self) -> bool: """Return if the device is available.""" - mqtt_data: MqttData = self.hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(self.hass) assert mqtt_data.client is not None client = mqtt_data.client if not client.connected and not self.hass.is_stopping: @@ -844,7 +811,7 @@ class MqttDiscoveryUpdate(Entity): self._removed_from_hass = False if discovery_data is None: return - mqtt_data: MqttData = hass.data[DATA_MQTT] + mqtt_data = get_mqtt_data(hass) self._registry_hooks = mqtt_data.discovery_registry_hooks discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH] if discovery_hash in self._registry_hooks: diff --git a/homeassistant/components/mqtt/models.py b/homeassistant/components/mqtt/models.py index d40b882d81b..2cff89f93a1 100644 --- a/homeassistant/components/mqtt/models.py +++ b/homeassistant/components/mqtt/models.py @@ -3,17 +3,22 @@ from __future__ import annotations from ast import literal_eval from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field import datetime as dt -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union import attr from homeassistant.const import ATTR_ENTITY_ID, ATTR_NAME -from homeassistant.core import HomeAssistant, callback +from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.helpers import template from homeassistant.helpers.entity import Entity from homeassistant.helpers.service_info.mqtt import ReceivePayloadType -from homeassistant.helpers.typing import TemplateVarsType +from homeassistant.helpers.typing import ConfigType, TemplateVarsType + +if TYPE_CHECKING: + from .client import MQTT, Subscription + from .device_trigger import Trigger _SENTINEL = object() @@ -174,3 +179,24 @@ class MqttValueTemplate: return self._value_template.async_render_with_possible_json_value( payload, default, variables=values ) + + +@dataclass +class MqttData: + """Keep the MQTT entry data.""" + + client: MQTT | None = None + config: ConfigType | None = None + device_triggers: dict[str, Trigger] = field(default_factory=dict) + discovery_registry_hooks: dict[tuple[str, str], CALLBACK_TYPE] = field( + default_factory=dict + ) + last_discovery: float = 0.0 + reload_dispatchers: list[CALLBACK_TYPE] = field(default_factory=list) + reload_entry: bool = False + reload_handlers: dict[str, Callable[[], Coroutine[Any, Any, None]]] = field( + default_factory=dict + ) + reload_needed: bool = False + subscriptions_to_restore: list[Subscription] = field(default_factory=list) + updated_config: ConfigType = field(default_factory=dict) diff --git a/homeassistant/components/mqtt/util.py b/homeassistant/components/mqtt/util.py index 9ef30da7f3b..43734872e14 100644 --- a/homeassistant/components/mqtt/util.py +++ b/homeassistant/components/mqtt/util.py @@ -15,10 +15,12 @@ from .const import ( ATTR_QOS, ATTR_RETAIN, ATTR_TOPIC, + DATA_MQTT, DEFAULT_QOS, DEFAULT_RETAIN, DOMAIN, ) +from .models import MqttData def mqtt_config_entry_enabled(hass: HomeAssistant) -> bool | None: @@ -111,3 +113,10 @@ MQTT_WILL_BIRTH_SCHEMA = vol.Schema( }, required=True, ) + + +def get_mqtt_data(hass: HomeAssistant, ensure_exists: bool = False) -> MqttData: + """Return typed MqttData from hass.data[DATA_MQTT].""" + if ensure_exists: + return hass.data.setdefault(DATA_MQTT, MqttData()) + return hass.data[DATA_MQTT]